├── pydmed
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── minimath.py
│   │   ├── multiproc.py
│   │   ├── output.py
│   │   └── data.py
│   ├── extensions
│   │   ├── __init__.py
│   │   ├── dl.py
│   │   └── wsi.py
│   ├── stat.py
│   ├── streamcollector.py
│   └── lightdl.py
├── howitworks.gif
├── sample_notebooks
│   ├── Sample3_Heatmap_for_WSIs
│   │   ├── Output
│   │   │   └── FromNdarraytoImage
│   │   │       └── .gitignore
│   │   └── sample3_heatmap_for_WSIs.ipynb
│   ├── Sample_2_Output
│   │   └── patient_1.eps.png
│   └── sample_1_train_classifier.ipynb
├── button_quickstart.png
├── button_quickstart2.png
├── .gitignore
├── LICENSE
└── README.md

/pydmed/__init__.py:
--------------------------------------------------------------------------------
 1 | 
 2 | 
--------------------------------------------------------------------------------
/pydmed/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/pydmed/extensions/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/howitworks.gif:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/amirakbarnejad/PyDmed/HEAD/howitworks.gif
--------------------------------------------------------------------------------
/sample_notebooks/Sample3_Heatmap_for_WSIs/Output/FromNdarraytoImage/.gitignore:
--------------------------------------------------------------------------------
 1 | *.npy
 2 | 
--------------------------------------------------------------------------------
/button_quickstart.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/amirakbarnejad/PyDmed/HEAD/button_quickstart.png
--------------------------------------------------------------------------------
/button_quickstart2.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/amirakbarnejad/PyDmed/HEAD/button_quickstart2.png
--------------------------------------------------------------------------------
/sample_notebooks/Sample_2_Output/patient_1.eps.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/amirakbarnejad/PyDmed/HEAD/sample_notebooks/Sample_2_Output/patient_1.eps.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | NonGit/
 2 | .ipynb_checkpoints
 3 | 
 4 | 
 5 | 
 6 | #-------------------------------------- git default ignored files for python project ---------------------
 7 | # Byte-compiled / optimized / DLL files
 8 | __pycache__/
 9 | *.py[cod]
10 | # C extensions
11 | *.so
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | # Translations
47 | *.mo
48 | *.pot
49 | 
50 | # Django stuff:
51 | *.log
52 | 
53 | # Sphinx documentation
54 | docs/_build/
55 | 
56 | # PyBuilder
57 | target/
58 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2020 amirakbarnejad
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/pydmed/utils/minimath.py:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | import numpy as np
 4 | import math
 5 | 
 6 | 
 7 | 
 8 | def lcm(list_numbers):
 9 |     '''
10 |     Computes the least common multiple (lcm) of the numbers in a list.
11 |     '''
12 |     lcm = list_numbers[0]
13 |     for idx_number in range(1, len(list_numbers)):
14 |         lcm = (lcm*list_numbers[idx_number])/math.gcd(lcm, list_numbers[idx_number])
15 |         lcm = int(lcm)
16 |     return lcm
17 | 
18 | 
19 | def multimode(list_input):
20 |     '''
21 |     Returns the most frequent element of `list_input`. Implemented here because
22 |     `statistics.multimode` does not exist in all Python versions.
23 |     '''
24 |     set_data = set(list_input)
25 |     dict_freqs = {val:0 for val in set_data}
26 |     for elem in list_input:
27 |         dict_freqs[elem] = dict_freqs[elem] + 1
28 |     mode = max((v, k) for k, v in dict_freqs.items())[1]
29 |     return mode
30 | 
31 | 
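# Illustrative usage of these helpers (hypothetical values; `multiminority`
# is defined right below):
#
#   >>> lcm([4, 6, 10])
#   60
#   >>> multimode([1, 1, 2, 3])            # most frequent element
#   1
#   >>> multiminority([1, 1, 2, 3, 3])     # least frequent element
#   2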
32 | def multiminority(list_input):
33 |     '''
34 |     Returns the least frequent element of a list (one of them, if several elements are tied for minority).
35 |     '''
36 |     set_data = set(list_input)
37 |     dict_freqs = {val:0 for val in set_data}
38 |     for elem in list_input:
39 |         dict_freqs[elem] = dict_freqs[elem] + 1
40 |     minority = min((v, k) for k, v in dict_freqs.items())[1]
41 |     return minority
42 | 
--------------------------------------------------------------------------------
/pydmed/utils/multiproc.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import matplotlib.pyplot as plt
 3 | import os, sys
 4 | import psutil
 5 | from pathlib import Path
 6 | import re
 7 | import time
 8 | import random
 9 | import multiprocessing as mp
10 | from abc import ABC, abstractmethod
11 | import openslide
12 | import torch
13 | import torchvision
14 | import torchvision.models as models
15 | from multiprocessing import Process, Queue
16 | 
17 | 
18 | def poplast_from_queue(queue):
19 |     '''
20 |     Pops the last element of a `multiprocessing.Queue`, draining the queue; returns None if the queue is empty.
21 |     '''
22 |     size_queue = queue.qsize()
23 |     if(size_queue == 0):
24 |         return None
25 |     elem = None
26 |     for count in range(size_queue):
27 |         try:
28 |             elem = queue.get_nowait()
29 |         except:
30 |             pass
31 |     return elem
32 | 
33 | 
34 | def set_nicemax():
35 |     '''
36 |     Sets the niceness of the process to its maximum value, i.e., gives the process the lowest scheduling priority.
37 |     '''
38 |     maxcount = 1000
39 |     N_old = os.nice(0)
40 |     count = 0
41 |     while(True):
42 |         count += 1
43 |         N_new = os.nice(N_old+1000)
44 |         if(N_new == N_old):
45 |             return
46 |         if(count > maxcount):
47 |             return
48 | 
49 | 
50 | 
51 | def terminaterecursively(pid):
52 |     print("=================================================================================")
53 |     parent = psutil.Process(pid) #TODO:copyright, https://www.reddit.com/r/learnpython/comments/7vwyez/how_to_kill_child_processes_when_using/
54 |     for child in parent.children(recursive=True):
55 |         try:
56 |             child.kill()
57 |         except:
58 |             pass
59 |         #print("    killed subprocess {}".format(child))
60 |     #if including_parent:
61 |     try:
62 |         parent.kill()
63 |     except:
64 |         pass
--------------------------------------------------------------------------------
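The queue helper above is easiest to see in isolation. Here is a minimal illustrative sketch (not from the repository) that uses `poplast_from_queue` to keep only the freshest message produced by a worker process:

```python
import multiprocessing as mp
from pydmed.utils.multiproc import poplast_from_queue

def worker(queue):
    #the worker keeps overwriting its status; only the last message matters
    for step in range(100):
        queue.put_nowait("status at step {}".format(step))

if __name__ == "__main__":
    queue = mp.Queue()
    proc = mp.Process(target=worker, args=(queue,))
    proc.start()
    proc.join()
    #drains the queue and returns the most recently queued element
    print(poplast_from_queue(queue))  #e.g., "status at step 99"
```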
/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | [![button](button_quickstart2.png)](https://amirakbarnejad.github.io/Tutorial/tutorial_section1.html)
 3 | 
 4 | # PyDmed (Python Dataloader for Medical Imaging)
 5 | 
 6 | ***Sample notebooks are available in the folder [sample notebooks](https://github.com/amirakbarnejad/PyDmed/tree/master/sample_notebooks/).***
 7 | 
 8 | The loading speed of hard drives is well below the processing speed of modern GPUs.
 9 | This is problematic for machine learning algorithms, especially for medical imaging datasets with large instances.
10 | 
11 | For example, consider the following case: we have a dataset containing 500
12 | [whole-slide-images](https://en.wikipedia.org/wiki/Digital_pathology)
13 | (WSIs), each of which is approximately 100000x100000 pixels.
14 | We want the dataloader to repeatedly do the following steps:
15 | 1. randomly select one of those huge images (i.e., WSIs).
16 | 2. crop and return a random 224x224 patch from the huge image.
17 | 
18 | 
19 | ***PyDmed solves this issue.***
20 | 
21 | 
22 | # How It Works
23 | The following two classes are pretty much the whole API of PyDmed.
24 | 1. `BigChunk`: a relatively big chunk from a patient. It can be, e.g., a 5000x5000 patch from a huge whole-slide image.
25 | 2. `SmallChunk`: a small data chunk collected from a big chunk. It can be, e.g., a 224x224 patch cropped from a 5000x5000 big chunk. In the figure below, `SmallChunk`s are the small blue patches.
26 | 
27 | The figure below illustrates the idea of PyDmed.
28 | As long as some `BigChunk`s are loaded into RAM, we can quickly collect some `SmallChunk`s and pass them to the GPU(s).
29 | As illustrated below, `BigChunk`s are loaded from disk (and replaced) from time to time.
30 | ![Alt Text](howitworks.gif)
31 | 
32 | [![button](button_quickstart2.png)](https://amirakbarnejad.github.io/Tutorial/tutorial_section1.html)
33 | 
34 | # Issues
35 | We regularly check for possible issues and update PyDmed. Please check out "Issues" if you face any problems running PyDmed.
36 | If you cannot find your issue there, please open a new one so we can improve PyDmed.
37 | 
38 | # Installation
39 | PyDmed is currently available as a local Python package. To use PyDmed, one needs a copy of the folder `PyDmed/` (obtained by, e.g., cloning this repo).
40 | Afterwards, the folder has to be added to `sys.path`, as done in [sample notebook 1](https://github.com/amirakbarnejad/PyDmed/blob/8ef0f6b9282815498bc50bf31a827a8a7eeb48a8/sample_notebooks/sample_1_train_classifier.ipynb)
41 | or in [the sample colab notebook](https://colab.research.google.com/drive/1WvntL-guv9JATJQWaS_Ww32DLBwGd9Ux?usp=sharing).
42 | 
43 | # Citation
44 | To cite PyDmed, please cite the following paper:
45 | 
46 | 
47 |     @inproceedings{akbarnejad2021deep,
48 |       title={Deep Fisher Vector Coding For Whole Slide Image Classification},
49 |       author={Akbarnejad, Amir and Ray, Nilanjan and Bigras, Gilbert},
50 |       booktitle={2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI)},
51 |       pages={243--246},
52 |       year={2021},
53 |       organization={IEEE}
54 |     }
55 | 
--------------------------------------------------------------------------------
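To make the `BigChunk`/`SmallChunk` idea of the README concrete, here is a sketch of the consuming side of the dataloader. It assumes `dl` is an already-configured `LightDL` instance and `device` a torch device, set up as in the sample notebooks; the loop itself relies only on `start()`, `get()`, and `pause_loading()`, and on the default collate output `x, list_patients, list_smallchunks`:

```python
#illustrative training loop; `dl` is a configured pydmed.lightdl.LightDL
dl.start()  #spawn the bigchunk-loader and smallchunk-collector processes
for iteration in range(1000):
    #by default, the collate function returns `x, list_patients, list_smallchunks`
    x, list_patients, list_smallchunks = dl.get()
    x = x.to(device)  #a batch of SmallChunks, ready for the GPU
    #... forward/backward pass on x ...
dl.pause_loading()  #stop the loader processes
```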
/pydmed/extensions/dl.py:
--------------------------------------------------------------------------------
  1 | 
  2 | 
  3 | '''
  4 | Extensions related to PyDmed's dataloader.
  5 | 
  6 | '''
  7 | 
  8 | import math
  9 | import numpy as np
 10 | from abc import ABC, abstractmethod
 11 | import random
 12 | import time
 13 | import openslide
 14 | import copy
 15 | import torchvision
 16 | import pydmed
 17 | import pydmed.lightdl
 18 | from pydmed import *
 19 | from pydmed.lightdl import *
 20 | 
 21 | 
 22 | 
 23 | class LabelBalancedDL(pydmed.lightdl.LightDL):
 24 |     '''
 25 |     This dataloader makes sure that the returned smallchunks have a balanced label
 26 |     frequency.
 27 |     Inputs.
 28 |         - func_getlabel_of_patient: a function that takes in a `Patient` and returns
 29 |           the corresponding label. The returned smallchunks are balanced in terms of
 30 |           this label.
 31 |         - ... other arguments, same as LightDL,
 32 |           https://github.com/amirakbarnejad/PyDmed/blob/8575ea991fe464b6e451d1a3381f9026581153da/pydmed/lightdl.py#L292
 33 |     '''
 34 |     def __init__(self, func_getlabel_of_patient, *args, **kwargs):
 35 |         '''
 36 |         Inputs.
 37 |             - func_getlabel_of_patient: a function that takes in a `Patient` and returns
 38 |               the corresponding label. The returned smallchunks are balanced in terms of
 39 |               this label.
 40 |             - ... other arguments, same as LightDL,
 41 |               https://github.com/amirakbarnejad/PyDmed/blob/8575ea991fe464b6e451d1a3381f9026581153da/pydmed/lightdl.py#L29
 42 |         '''
 43 |         super(LabelBalancedDL, self).__init__(*args, **kwargs)
 44 |         #grab privates
 45 |         self.func_getlabel_of_patient = func_getlabel_of_patient
 46 |         #make separate lists for different classes ====
 47 |         possible_labels = list(
 48 |             set(
 49 |                 [self.func_getlabel_of_patient(patient)\
 50 |                  for patient in self.dataset.list_patients]
 51 |             )
 52 |         )
 53 |         dict_label_to_listpatients = {label:[] for label in possible_labels}
 54 |         for patient in self.dataset.list_patients:
 55 |             label_of_patient = self.func_getlabel_of_patient(patient)
 56 |             dict_label_to_listpatients[label_of_patient].append(patient)
 57 |         self.possible_labels = possible_labels
 58 |         self.dict_label_to_listpatients = dict_label_to_listpatients
 59 | 
 60 |     def initial_schedule(self):
 61 |         # ~ print("override initsched called.")
 62 |         #split numbigchunks to lists of almost equal length ======
 63 |         avg_inbin = self.const_global_info["num_bigchunkloaders"]/len(self.possible_labels)
 64 |         avg_inbin = math.floor(avg_inbin)
 65 |         list_binsize = [avg_inbin for label in self.possible_labels]
 66 |         num_toadd = self.const_global_info["num_bigchunkloaders"]-\
 67 |                     avg_inbin*len(self.possible_labels)
 68 |         for n in range(num_toadd):
 69 |             list_binsize[n] += 1
 70 |         #randomly sample patients from different classes =====
 71 |         toret_list_patients = []
 72 |         for idx_bin, size_bin in enumerate(list_binsize):
 73 |             label = self.possible_labels[idx_bin]
 74 |             toret_list_patients = toret_list_patients +\
 75 |                 random.choices(self.dict_label_to_listpatients[label], k=size_bin)
 76 |         return toret_list_patients
 77 | 
 78 |     def schedule(self):
 79 |         # ~ print("override sched called.")
 80 |         #get initial fields ==============================
 81 |         list_loadedpatients = self.get_list_loadedpatients()
 82 |         list_waitingpatients = self.get_list_waitingpatients()
 83 |         schedcount_of_waitingpatients = [self.get_schedcount_of(patient)\
 84 |                                          for patient in list_waitingpatients]
 85 |         #patient_toremove is selected randomly =======================
 86 |         patient_toremove = random.choice(list_loadedpatients)
 87 |         #choose the patient to load ================
 88 |         minority_label = pydmed.utils.minimath.multiminority(
 89 |             [self.func_getlabel_of_patient(patient) for patient in list_loadedpatients]
 90 |         )
 91 |         toadd_candidates = self.dict_label_to_listpatients[minority_label]
 92 |         weights = 1.0/(1.0+np.array(
 93 |             [self.get_schedcount_of(patient) for patient in toadd_candidates]
 94 |         ))
 95 |         weights[weights==1.0] = 10000000.0 #if the patient has not been loaded so far, give it high priority
 96 |         patient_toload = random.choices(toadd_candidates,\
 97 |                                         weights = weights, k=1)[0]
 98 |         return patient_toremove, patient_toload
 99 | 
100 | 
101 | 
102 | 
103 | 
104 | 
105 | 
106 | 
--------------------------------------------------------------------------------
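A `LabelBalancedDL` is constructed like a plain `LightDL`, with the labeling function as the extra first argument. An illustrative sketch (`dataset`, `const_global_info`, and the chunk-loader classes are assumed to be defined as in the sample notebooks; `"HER2-status"` is just an example record key):

```python
#illustrative construction; not part of the original source
def func_getlabel_of_patient(patient):
    return patient.dict_records["HER2-status"]

dl = LabelBalancedDL(
    func_getlabel_of_patient,                       #balancing criterion
    dataset=dataset,                                #a pydmed.utils.data.Dataset
    type_bigchunkloader=MyBigChunkLoader,           #as for a plain LightDL
    type_smallchunkcollector=MySmallChunkCollector,
    const_global_info=const_global_info,
    batch_size=8,
)
```

Internally, `schedule()` always reloads a patient from whichever label is currently in the minority among the loaded patients, which keeps the stream of smallchunks balanced over time.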
/pydmed/utils/output.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import numpy as np
  3 | import matplotlib.pyplot as plt
  4 | import sys
  5 | import os
  6 | import multiprocessing as mp
  7 | import csv
  8 | import time
  9 | 
 10 | 
 11 | 
 12 | class StreamWriter(mp.Process):
 13 |     def __init__(self, list_patients=None, rootpath=None, fname_tosave=None,
 14 |                  waiting_time_before_flush = 3):
 15 |         '''
 16 |         StreamWriter works in two modes:
 17 |          1) one file is created for the whole dataset. In this case,
 18 |             only `fname_tosave` is used and the argument `rootpath` must be None.
 19 |          2) one file is created for each `Patient` in the directory `rootpath`.
 20 |             In this case, `fname_tosave` must be None.
 21 |         Inputs:
 22 |             - waiting_time_before_flush: the number of seconds to sleep before
 23 |               flushing the contents on close. Default is 3 seconds.
 24 |         '''
 25 |         super(StreamWriter, self).__init__()
 26 |         if(isinstance(rootpath, str) and isinstance(fname_tosave, str)):
 27 |             #both arguments were provided; exactly one of them must be None
 28 |             exception_msg = "One of the arguments `rootpath` and `fname_tosave`"+\
 29 |                             " must be set to None. For details, please refer to"+\
 30 |                             " the `StreamWriter` documentation."
 31 |             raise Exception(exception_msg)
 32 |         if(isinstance(fname_tosave, str)):
 33 |             #mode 1: a single .csv file for the whole dataset
 34 |             self.op_mode = 1
 35 |         if(isinstance(rootpath, str)):
 36 |             #mode 2: one file per patient, created under `rootpath`
 37 |             self.op_mode = 2
 38 |         if(hasattr(self, "op_mode") == False):
 39 |             exception_msg = "Exactly one of the arguments `rootpath` or `fname_tosave`"+\
 40 |                             " must be set to a string."+\
 41 |                             " For details, please refer to"+\
 42 |                             " the `StreamWriter` documentation."
 43 |             raise Exception(exception_msg)
 44 |         if(self.op_mode == 1):
 45 |             if(fname_tosave.endswith(".csv") == False):
 46 |                 raise Exception("The argument `fname_tosave` must end with .csv,"+\
 47 |                                 " because only the .csv format is supported.")
 48 |         if(self.op_mode == 2):
 49 |             if(len(list(os.listdir(rootpath))) > 0):
 50 |                 print(list(os.listdir(rootpath)))
 51 |                 raise Exception("The folder {} \n is not empty.".format(rootpath)+\
 52 |                                 " Delete its files before continuing.")
 53 |         #grab privates ================
 54 |         self.list_patients = list_patients
 55 |         self.rootpath = rootpath
 56 |         self.fname_tosave = fname_tosave
 57 |         self.waiting_time_before_flush = waiting_time_before_flush
 58 |         #make/open csv file(s) =======================
 59 |         if(self.op_mode == 1):
 60 |             self.list_files = [open(fname_tosave, mode='a+')]
 61 |             self.list_writers = [csv.writer(f, delimiter=',',\
 62 |                                  quotechar='"', quoting=csv.QUOTE_MINIMAL
 63 |                                 ) for f in self.list_files]
 64 |         elif(self.op_mode == 2):
 65 |             self.list_files = [open(os.path.join(rootpath,\
 66 |                                     "patient_{}.csv".format(patient.int_uniqueid))
 67 |                                     , mode='a+') for patient in list_patients]
 68 |             self.list_writers = [csv.writer(f, delimiter=',',\
 69 |                                  quotechar='"', quoting=csv.QUOTE_MINIMAL
 70 |                                 ) for f in self.list_files]
 71 |         #make mp stuff ========
 72 |         self.queue_towrite = mp.Queue() #there is one queue in both operating modes.
 73 |         self.queue_signal_end = mp.Queue() #this queue not empty means "close"
 74 |         self.flag_closecalled = False #once close is called, writing is disabled.
 75 | 
 76 |     def flush_and_close(self):
 77 |         self.flag_closecalled = True
 78 |         time.sleep(self.waiting_time_before_flush)
 79 |         self.queue_signal_end.put_nowait("stop")
 80 | 
 81 | 
 82 |     def run(self):
 83 |         while True:
 84 |             if(self.queue_signal_end.qsize()>0):
 85 |                 #execute flush_and_close ==========
 86 |                 self.flag_closecalled = True
 87 |                 self._wrt_onclose()
 88 |                 for f in self.list_files:
 89 |                     f.flush()
 90 |                     f.close()
 91 |                 break
 92 |             else:
 93 |                 #patrol the queue ==========
 94 |                 self._wrt_patrol()
 95 | 
 96 | 
 97 |     def write(self, patient, str_towrite):
 98 |         '''
 99 |         Writes to the file(s).
100 |         Inputs.
101 |             - patient: an instance of `Patient`. This argument is ignored
102 |               when operating in mode 1.
103 |             - str_towrite: the string to be written to the file.
104 | ''' 105 | if(self.flag_closecalled == False): 106 | self.queue_towrite.put_nowait({"patient": patient, "str_towrite":str_towrite}) 107 | else: 108 | print("`StreamWriter` cannot `write` after calling the `close` function.") 109 | 110 | def _wrt_patrol(self): 111 | ''' 112 | Pops/writes one element from the queue 113 | ''' 114 | if(self.queue_towrite.qsize() > 0): 115 | try: 116 | poped_elem = self.queue_towrite.get_nowait() 117 | 118 | if(self.op_mode == 1): 119 | self.list_files[0].write(poped_elem["str_towrite"]) 120 | elif(self.op_mode == 2): 121 | patient, str_towrite = poped_elem["patient"], poped_elem["str_towrite"] 122 | assert(patient in self.list_patients) 123 | idx_patient = self.list_patients.index(patient) 124 | self.list_files[idx_patient].write(str_towrite) 125 | except Exception as e: 126 | pass 127 | #print("\n\n\n\n*************") 128 | #print(str(e)) 129 | 130 | 131 | def _wrt_onclose(self): 132 | ''' 133 | Pops/writes all elements of the queue. 134 | ''' 135 | qsize = self.queue_towrite.qsize() 136 | if(qsize > 0): 137 | for idx_elem in range(qsize): 138 | try: 139 | poped_elem = self.queue_towrite.get_nowait() 140 | if(self.op_mode == 1): 141 | self.list_files[0].write(poped_elem["str_towrite"]) 142 | elif(self.op_mode == 2): 143 | patient, str_towrite = poped_elem["patient"], poped_elem["str_towrite"] 144 | assert(patient in self.list_patients) 145 | idx_patient = self.list_patients.index(patient) 146 | self.list_files[idx_patient].write(str_towrite) 147 | except: 148 | pass 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /pydmed/stat.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import numpy as np 5 | import os, sys 6 | import math 7 | from pathlib import Path 8 | import re 9 | import time 10 | import random 11 | import multiprocessing as mp 12 | import openslide 13 | from multiprocessing import Process, Queue 14 | from abc import ABC, abstractmethod 15 | import matplotlib.pyplot as plt 16 | import pydmed.utils.output 17 | from pydmed.utils.output import StreamWriter 18 | 19 | 20 | class ProcessedPiece: 21 | def __init__(self, source_smallchunk, data=None, stat=None): 22 | ''' 23 | Note: stat and data are the same. The two arguments are added to support previous versions. 24 | Inputs: 25 | - data: the data part of the processed piece. 26 | - source_smallchunk: the `SmallChunk` from which the stat is collected. 27 | ''' 28 | #grab privates 29 | if(isinstance(data, np.ndarray) == True): 30 | self.stat = data 31 | else: 32 | self.stat = stat 33 | self.source_smallchunk = source_smallchunk 34 | self.source_smallchunk.data = "None, to avoid memory leak" 35 | 36 | 37 | class StreamCollector(object): 38 | def __init__(self, lightdl, str_collectortype, flag_visualizestats=False, kwargs_streamwriter=None): 39 | ''' 40 | TODO:adddoc. 
str_collectortype can be "accum" or "saveall" or "stream_to_file" 41 | ''' 42 | #grab initargs 43 | self.lightdl = lightdl 44 | self.str_collectortype = str_collectortype 45 | self.flag_visualizestats = flag_visualizestats 46 | self.kwargs_streamwriter = kwargs_streamwriter 47 | #make internals 48 | self.dict_patient_to_liststats = {patient:[] for patient in self.lightdl.dataset.list_patients} 49 | self.dict_patient_to_accumstat = {patient:None for patient in self.lightdl.dataset.list_patients} 50 | self._queue_onfinish_collectedstats = mp.Queue() 51 | if(self.str_collectortype.startswith("stream_to_file")): 52 | self.streamwriter = StreamWriter(lightdl.dataset.list_patients, **kwargs_streamwriter) 53 | 54 | 55 | 56 | def start_collecting(self): 57 | 58 | 59 | #make the following line tunable 60 | os.system("taskset -a -cp {} {}".format("0,1,2,3,4", os.getpid())) 61 | self.lightdl.start() 62 | if(self.str_collectortype.startswith("stream_to_file")): 63 | self.streamwriter.start() 64 | # ~ print(" statcollector.streamwriter started ") 65 | 66 | 67 | #TODO:make the following lines tunable. 68 | time.sleep(1*10) 69 | os.system("taskset -cp {} {}".format("5,6,7", os.getpid())) 70 | time.sleep(1*10) 71 | 72 | time.sleep(2) #TODO:remove 73 | time_lastcheck = time.time()+2 #TODO:make tunable 74 | time_lastupdate_visstats = time.time()+5 #TODO:make tunable 75 | count = 0 76 | #plot the visstats if needed, 77 | # ~ if(self.flag_visualizestats == True): 78 | # ~ pass 79 | # ~ self.logfile = open(r"lotstat.txt","w+") 80 | #TODO:handle visualizestats 81 | # ~ fig = plt.figure() 82 | # ~ ax = fig.add_subplot(111) 83 | # ~ x = [u for u in range(len(list(self.dict_patient_to_liststats.keys())))] 84 | # ~ y = [len(self.dict_patient_to_liststats[pat]) for pat in self.dict_patient_to_liststats.keys()] 85 | # ~ line1, = ax.plot(x, y, 'r-') 86 | # ~ plt.show() 87 | while True: 88 | count +=1 89 | 90 | #get an instance from the DL ============ 91 | retval_dl = self.lightdl.get() 92 | flag_invalid_retvaldl = False 93 | if(isinstance(retval_dl, str)): 94 | if(retval_dl == pydmed.lightdl.PYDMEDRESERVED_DLRETURNEDLASTINSTANCE): 95 | flag_invalid_retvaldl = True 96 | 97 | #collect stat only if the retval is valid ===== 98 | if(flag_invalid_retvaldl == False): 99 | list_collectedstats = self.process_pieceofstream(retval_dl) 100 | list_patients = [st.source_smallchunk.patient for st in list_collectedstats] 101 | self._manage_stats(list_collectedstats, list_patients) 102 | 103 | 104 | #stop collecting if needed ====== 105 | if((time.time()-time_lastcheck) > 5):#TODO:make tunable 106 | time_lastcheck = time.time() 107 | if(self.get_flag_finishcollecting() == True): 108 | toret_onfinish_collectedstats = {} 109 | #colllate all statistics 110 | if(self.str_collectortype == "saveall"): 111 | for patient in self.lightdl.dataset.list_patients: 112 | toret_onfinish_collectedstats[patient] = self.collate_stats_onfinishcollecting(patient, self.dict_patient_to_liststats[patient]) 113 | elif(self.str_collectortype == "accum"): 114 | for patient in self.lightdl.dataset.list_patients: 115 | toret_onfinish_collectedstats[patient] = \ 116 | self.dict_patient_to_accumstat[patient] 117 | elif(self.str_collectortype.startswith("stream_to_file")): 118 | self.streamwriter.flush_and_close() 119 | 120 | self._onfinish_collectedstats = toret_onfinish_collectedstats 121 | if(self.str_collectortype.startswith("stream_to_file") == False): 122 | pass #self._queue_onfinish_collectedstats.put_nowait(toret_onfinish_collectedstats) 123 | 
time.sleep(3) #TODO:make tunable
124 |                     #stop the lightdl
125 |                     self.lightdl.pause_loading()
126 |                     break
127 | 
128 |     def get_finalstats(self):
129 |         try:
130 |             if(self.str_collectortype.startswith("stream_to_file") == False):
131 |                 toret = self._onfinish_collectedstats #self._queue_onfinish_collectedstats.get()
132 |                 return toret
133 | 
134 |             # ~ return self._onfinish_collectedstats
135 |         except:
136 |             print("Error in getting the final collected stats. Had the StreamCollector finished when you called `StreamCollector.get_finalstats`?")
137 | 
138 |     def _manage_stats(self, list_collectedstats, list_patients):
139 |         for n in range(len(list_patients)):
140 |             patient = list_patients[n]
141 |             if(self.str_collectortype == "saveall"):
142 |                 self.dict_patient_to_liststats[patient].append(list_collectedstats[n])
143 |             elif(self.str_collectortype == "accum"):
144 |                 self.dict_patient_to_accumstat[patient] = self.accum_statistics(self.dict_patient_to_accumstat[patient],
145 |                                                                                 list_collectedstats[n],
146 |                                                                                 patient)
147 |             elif(self.str_collectortype.startswith("stream_to_file")):
148 |                 self.streamwriter.write(patient, list_collectedstats[n].stat)
149 | 
150 | 
151 | 
152 |     @abstractmethod
153 |     def accum_statistics(self, prev_accum, new_stat, patient):
154 |         '''
155 |         TODO:adddoc
156 |         Outputs.
157 |             - new_accum: TODO:adddoc.
158 |         '''
159 |         pass
160 | 
161 | 
162 |     @abstractmethod
163 |     def get_statistics(self, returnvalue_of_collatefunction):
164 |         '''
165 |         Same as `process_pieceofstream`. Kept to support previous versions.
166 |         '''
167 |         pass #return self.process_pieceofstream(returnvalue_of_collatefunction)
168 | 
169 | 
170 |     @abstractmethod
171 |     def process_pieceofstream(self, returnvalue_of_collatefunction):
172 |         '''
173 |         The abstract method that specifies the stat.collector's behaviour.
174 |         Inputs.
175 |             - TODO:adddoc, retval of collatefunc is by default, `x, list_patients, list_smallchunks`. But if you have overridden that output, ...
176 |         Outputs.
177 |             - list_liststats: a list of the same length as list_patients. Each element of the list is an instance of `Statistic`.
178 |         '''
179 |         return self.get_statistics(returnvalue_of_collatefunction) #return self.process_pieceofstream(returnvalue_of_collatefunction)
180 | 
181 | 
182 |     @abstractmethod
183 |     def get_flag_finishcollecting(self): #TODO: make this method patient-wise, so that user would not need to work with self.dict_patient_to_liststats
184 |         pass
185 | 
186 |     @abstractmethod
187 |     def collate_stats_onfinishcollecting(self, patient, list_collectedstats):
188 |         '''
189 |         This function is called when collecting is finished.
190 |         This function should collate and return the collected stats for the input patient and collected stats.
191 |         Inputs.
192 |             - patient: the patient, an instance of utils.data.Patient.
193 |             - list_collectedstats: the collected statistics, a list of objects as returned by `StreamCollector.process_pieceofstream`.
194 |         '''
195 |         pass
196 | 
197 | 
198 | 
199 | StatCollector = StreamCollector #to support previous versions.
200 | Statistic = ProcessedPiece #to support previous versions.
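A concrete collector only needs to implement the abstract hooks above. The following minimal sketch (illustrative; `MeanStatCollector` is not part of the package) keeps a running per-patient mean in `"accum"` mode, relying on the default collate output `x, list_patients, list_smallchunks`:

```python
class MeanStatCollector(StreamCollector):
    #illustrative subclass of pydmed.stat.StreamCollector
    def process_pieceofstream(self, retval_collatefunc):
        x, list_patients, list_smallchunks = retval_collatefunc
        return [ProcessedPiece(stat=float(x[n].mean()),
                               source_smallchunk=list_smallchunks[n])
                for n in range(len(list_smallchunks))]

    def accum_statistics(self, prev_accum, new_stat, patient):
        #accumulate a (sum, count) pair per patient
        if(prev_accum is None):
            return (new_stat.stat, 1)
        return (prev_accum[0]+new_stat.stat, prev_accum[1]+1)

    def get_flag_finishcollecting(self):
        return not self.lightdl.is_dl_running()

#usage (assuming `dl` is a configured LightDL):
#   collector = MeanStatCollector(lightdl=dl, str_collectortype="accum")
#   collector.start_collecting()
#   dict_patient_to_sumcount = collector.get_finalstats()
```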
201 | 
202 | 
--------------------------------------------------------------------------------
/pydmed/streamcollector.py:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | 
 4 | import numpy as np
 5 | import os, sys
 6 | import math
 7 | from pathlib import Path
 8 | import re
 9 | import time
10 | import random
11 | import multiprocessing as mp
12 | import openslide
13 | from multiprocessing import Process, Queue
14 | from abc import ABC, abstractmethod
15 | import matplotlib.pyplot as plt
16 | import pydmed.utils.output
17 | from pydmed.utils.output import StreamWriter
18 | 
19 | 
20 | class ProcessedPiece:
21 |     def __init__(self, source_smallchunk, data=None, stat=None):
22 |         '''
23 |         Note: `stat` and `data` are the same; both arguments are kept to support previous versions.
24 |         Inputs:
25 |             - data: the data part of the processed piece.
26 |             - source_smallchunk: the `SmallChunk` from which the stat is collected.
27 |         '''
28 |         #grab privates
29 |         if(stat is None):
30 |             self.stat = data
31 |         elif(data is None):
32 |             self.stat = stat
33 |         else:
34 |             raise Exception("When creating a `ProcessedPiece`, either the argument `data` or the argument `stat` must be left as None (they cannot both be provided).")
35 |         self.source_smallchunk = source_smallchunk
36 |         self.source_smallchunk.data = "None, to avoid memory leak"
37 | 
38 | 
39 | class StreamCollector(object):
40 |     def __init__(self, lightdl, str_collectortype, flag_visualizestats=False, kwargs_streamwriter=None):
41 |         '''
42 |         TODO:adddoc. `str_collectortype` can be "accum", "saveall", or "stream_to_file".
43 |         '''
44 |         #grab initargs
45 |         self.lightdl = lightdl
46 |         self.str_collectortype = str_collectortype
47 |         self.flag_visualizestats = flag_visualizestats
48 |         self.kwargs_streamwriter = kwargs_streamwriter
49 |         #make internals
50 |         self.dict_patient_to_liststats = {patient:[] for patient in self.lightdl.dataset.list_patients}
51 |         self.dict_patient_to_accumstat = {patient:None for patient in self.lightdl.dataset.list_patients}
52 |         self._queue_onfinish_collectedstats = mp.Queue()
53 |         if(self.str_collectortype.startswith("stream_to_file")):
54 |             self.streamwriter = StreamWriter(lightdl.dataset.list_patients, **kwargs_streamwriter)
55 | 
56 | 
57 | 
58 |     def start_collecting(self):
59 | 
60 | 
61 |         #make the following line tunable
62 |         os.system("taskset -a -cp {} {}".format("0,1,2,3,4", os.getpid()))
63 |         self.lightdl.start()
64 |         if(self.str_collectortype.startswith("stream_to_file")):
65 |             self.streamwriter.start()
66 |             # ~ print(" statcollector.streamwriter started ")
67 | 
68 | 
69 |         #TODO:make the following lines tunable.
70 | time.sleep(1*10) 71 | os.system("taskset -cp {} {}".format("5,6,7", os.getpid())) 72 | time.sleep(1*10) 73 | 74 | time.sleep(2) #TODO:remove 75 | time_lastcheck = time.time()+2 #TODO:make tunable 76 | time_lastupdate_visstats = time.time()+5 #TODO:make tunable 77 | count = 0 78 | #plot the visstats if needed, 79 | # ~ if(self.flag_visualizestats == True): 80 | # ~ pass 81 | # ~ self.logfile = open(r"lotstat.txt","w+") 82 | #TODO:handle visualizestats 83 | # ~ fig = plt.figure() 84 | # ~ ax = fig.add_subplot(111) 85 | # ~ x = [u for u in range(len(list(self.dict_patient_to_liststats.keys())))] 86 | # ~ y = [len(self.dict_patient_to_liststats[pat]) for pat in self.dict_patient_to_liststats.keys()] 87 | # ~ line1, = ax.plot(x, y, 'r-') 88 | # ~ plt.show() 89 | while True: 90 | count +=1 91 | 92 | #get an instance from the DL ============ 93 | retval_dl = self.lightdl.get() 94 | flag_invalid_retvaldl = False 95 | if(isinstance(retval_dl, str)): 96 | if(retval_dl == pydmed.lightdl.PYDMEDRESERVED_DLRETURNEDLASTINSTANCE): 97 | flag_invalid_retvaldl = True 98 | 99 | #collect stat only if the retval is valid ===== 100 | if(flag_invalid_retvaldl == False): 101 | list_collectedstats = self.process_pieceofstream(retval_dl) 102 | list_patients = [st.source_smallchunk.patient for st in list_collectedstats] 103 | self._manage_stats(list_collectedstats, list_patients) 104 | 105 | 106 | #stop collecting if needed ====== 107 | if((time.time()-time_lastcheck) > 5):#TODO:make tunable 108 | time_lastcheck = time.time() 109 | if(self.get_flag_finishcollecting() == True): 110 | toret_onfinish_collectedstats = {} 111 | #colllate all statistics 112 | if(self.str_collectortype == "saveall"): 113 | for patient in self.lightdl.dataset.list_patients: 114 | toret_onfinish_collectedstats[patient] = self.collate_stats_onfinishcollecting(patient, self.dict_patient_to_liststats[patient]) 115 | elif(self.str_collectortype == "accum"): 116 | for patient in self.lightdl.dataset.list_patients: 117 | toret_onfinish_collectedstats[patient] = \ 118 | self.dict_patient_to_accumstat[patient] 119 | elif(self.str_collectortype.startswith("stream_to_file")): 120 | self.streamwriter.flush_and_close() 121 | 122 | self._onfinish_collectedstats = toret_onfinish_collectedstats 123 | if(self.str_collectortype.startswith("stream_to_file") == False): 124 | pass #self._queue_onfinish_collectedstats.put_nowait(toret_onfinish_collectedstats) 125 | time.sleep(3) #TODO:make tunable 126 | #stop the lightdl 127 | self.lightdl.pause_loading() 128 | break 129 | 130 | def get_finalstats(self): 131 | try: 132 | if(self.str_collectortype.startswith("stream_to_file") == False): 133 | toret = self._onfinish_collectedstats #self._queue_onfinish_collectedstats.get() 134 | return toret 135 | 136 | # ~ return self._onfinish_collectedstats 137 | except: 138 | print("Error in getting the final collected stats. 
Is the StreamCollector finished when you called `StreamCollector.get_finalstats`?") 139 | 140 | def _manage_stats(self, list_collectedstats, list_patients): 141 | for n in range(len(list_patients)): 142 | patient = list_patients[n] 143 | if(self.str_collectortype == "saveall"): 144 | self.dict_patient_to_liststats[patient].append(list_collectedstats[n]) 145 | elif(self.str_collectortype == "accum"): 146 | self.dict_patient_to_accumstat[patient] = self.accum_statistics(self.dict_patient_to_accumstat[patient], 147 | list_collectedstats[n], 148 | patient) 149 | elif(self.str_collectortype.startswith("stream_to_file")): 150 | self.streamwriter.write(patient, list_collectedstats[n].stat) 151 | 152 | 153 | 154 | @abstractmethod 155 | def accum_statistics(self, prev_accum, new_stat, patient): 156 | ''' 157 | TODO:adddoc 158 | Outputs. 159 | - new_accum: TODO:adddoc. 160 | ''' 161 | pass 162 | 163 | 164 | @abstractmethod 165 | def get_statistics(self, returnvalue_of_collatefunction): 166 | ''' 167 | Same as `process_pieceofstream`. To support previous version. 168 | ''' 169 | pass #return self.process_pieceofstream(returnvalue_of_collatefunction) 170 | 171 | 172 | @abstractmethod 173 | def process_pieceofstream(self, returnvalue_of_collatefunction): 174 | ''' 175 | The abstract method that specifies the stat.collector's behaviour. 176 | Inputs. 177 | - TODO:adddoc, retval of collatefunc is by default, `x, list_patients, list_smallchunks`. But if you have overriden that output, ... 178 | Outputs. 179 | - list_liststats: a list of the same lenght as list_patients. Each element of the list is an instance of `Statistic`. 180 | ''' 181 | return self.get_statistics(returnvalue_of_collatefunction) #return self.process_pieceofstream(returnvalue_of_collatefunction) 182 | 183 | 184 | @abstractmethod 185 | def get_flag_finishcollecting(self): #TODO: make this method patient-wise, so that user would not need to work with self.dict_patient_to_liststats 186 | pass 187 | 188 | @abstractmethod 189 | def collate_stats_onfinishcollecting(self, patient, list_collectedstats): 190 | ''' 191 | This function is called when collecting is finished. 192 | This fucntion should collate and return the collected stats for the input patient and collected stats. 193 | Inputs. 194 | - patient: the patient, and instance of of utils.data.Patient. 195 | - list_collectedstats: the collected statistics, a list of objects as returned by `StreamCollector.process_pieceofstream`. 196 | ''' 197 | pass 198 | 199 | 200 | 201 | StatCollector = StreamCollector #to support previous versions. 202 | Statistic = ProcessedPiece #to support previous versions. 
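The `"stream_to_file"` mode above delegates all disk writes to `pydmed.utils.output.StreamWriter`, which runs as a separate process. A standalone sketch of its two operating modes (illustrative paths; `dataset` is assumed to be a `pydmed.utils.data.Dataset`):

```python
from pydmed.utils.output import StreamWriter

#mode 1: a single .csv file for the whole dataset.
writer = StreamWriter(fname_tosave="all_patients.csv")
writer.start()                     #StreamWriter is a multiprocessing.Process
writer.write(None, "x,y,value\n")  #the `patient` argument is ignored in mode 1
writer.flush_and_close()

#mode 2: one file per patient, inside an existing *empty* directory.
writer2 = StreamWriter(list_patients=dataset.list_patients,
                       rootpath="./Output/PerPatient/")
writer2.start()
writer2.write(dataset.list_patients[0], "x,y,value\n")
writer2.flush_and_close()
```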
203 | -------------------------------------------------------------------------------- /sample_notebooks/Sample3_Heatmap_for_WSIs/sample3_heatmap_for_WSIs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "import copy\n", 12 | "from abc import ABC, abstractmethod\n", 13 | "import math\n", 14 | "import copy\n", 15 | "from copy import deepcopy\n", 16 | "import PIL\n", 17 | "from skimage.color import rgb2gray\n", 18 | "from skimage.filters import threshold_otsu\n", 19 | "import torchvision\n", 20 | "import torchvision.models as torchmodels\n", 21 | "import torch.nn.functional as F\n", 22 | "import openslide\n", 23 | "import torch.utils.data\n", 24 | "\n", 25 | "list_pathstoadd = [\"../../\"]\n", 26 | "for path in list_pathstoadd:\n", 27 | " if(path not in sys.path):\n", 28 | " sys.path.append(path)\n", 29 | "import pydmed\n", 30 | "from pydmed.utils.data import *\n", 31 | "import pydmed.lightdl\n", 32 | "from pydmed.lightdl import *\n", 33 | "import pydmed.extensions.wsi\n", 34 | "import pydmed.streamcollector" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#settings ====\n", 44 | "kernel_size = 2000" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 54 | "print(device)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "#make the model ====\n", 64 | "module_resnet = torchvision.models.resnet18(pretrained=True)\n", 65 | "list_modules = list(module_resnet.children())[0:-2]\n", 66 | "model = torch.nn.Sequential(*list_modules)\n", 67 | "\n", 68 | "def func_setpaddingmodes_for_conv2dlayers(module_input, str_paddingmode):\n", 69 | " '''\n", 70 | " Sets the padding mode of all conv2d modules, either on imediate children or non-imediate children.\n", 71 | " '''\n", 72 | " #get num_children\n", 73 | " num_children = 0\n", 74 | " for child in module_input.children():\n", 75 | " num_children += 1\n", 76 | " \n", 77 | " #base case, the module has no children ====\n", 78 | " if(num_children == 0):\n", 79 | " if(isinstance(module_input, torch.nn.Conv2d)):\n", 80 | " module_input.padding_mode = str_paddingmode\n", 81 | " return module_input\n", 82 | " \n", 83 | " #non-base case, loop over children ====\n", 84 | " for child in module_input.children():\n", 85 | " func_setpaddingmodes_for_conv2dlayers(child, str_paddingmode)\n", 86 | " return module_input\n", 87 | "\n", 88 | "model = func_setpaddingmodes_for_conv2dlayers(model, \"reflect\") #set paddingmode to \"reflect\"" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "#make dataset ===================\n", 98 | "rootdir = \"../../NonGit/Data/\"\n", 99 | "list_relativedirs = [\"1.svs\", \"2.svs\", \"3.svs\", \"4.svs\", \"5.svs\",\\\n", 100 | " \"6.svs\", \"7.svs\", \"8.svs\", \"9.svs\", \"10.svs\"]\n", 101 | "list_relativedirs.sort()\n", 102 | "#make a list of patients\n", 103 | "list_patients = []\n", 104 | "for fname in list_relativedirs:\n", 105 | " new_patient = Patient(\\\n", 106 | " int_uniqueid 
= list_relativedirs.index(fname),\n", 107 | " dict_records = {\n", 108 | " \"H&E\":Record(rootdir, fname, {\"resolution\":\"40x\"}),\n", 109 | " \"somelabel\": np.random.randint(0,4)\n", 110 | " }\n", 111 | " )\n", 112 | " list_patients.append(new_patient)\n", 113 | "#make the dataset\n", 114 | "dataset = pydmed.utils.data.Dataset(\"dataset_sample3Heatmap\", list_patients)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "tfms_onsmallchunkcollection = None \n", 124 | "#`tfms_onsmallchunkcollection` has to be None: https://github.com/amirakbarnejad/PyDmed/issues/3\n", 125 | "tfms_oncolate = torchvision.transforms.Compose([\n", 126 | " torchvision.transforms.ToPILImage(),\n", 127 | " torchvision.transforms.ToTensor(),\n", 128 | " torchvision.transforms.Normalize(\n", 129 | " mean=[0.485, 0.456, 0.406],\n", 130 | " std=[0.229, 0.224, 0.225]\n", 131 | " )\n", 132 | "])\n", 133 | "const_global_info = {\n", 134 | " \"attention_levelidx\":1,\n", 135 | " \"num_bigchunkloaders\":5,\n", 136 | " \"maxlength_queue_smallchunk\":np.inf,\n", 137 | " \"maxlength_queue_lightdl\":np.inf,\n", 138 | " \"interval_resched\": 2,\n", 139 | " \"core-assignment\":{\n", 140 | " \"lightdl\":None,\n", 141 | " \"smallchunkloaders\":None,\n", 142 | " \"bigchunkloaders\":None\n", 143 | " }\n", 144 | " }" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "dl_forheatmap = pydmed.extensions.wsi.SlidingWindowDL(\n", 154 | " intorfunc_opslevel = 1,\n", 155 | " kernel_size = kernel_size,\n", 156 | " stride = kernel_size,\n", 157 | " mininterval_loadnewbigchunk = 15,\n", 158 | " dataset = dataset,\\\n", 159 | " type_bigchunkloader=pydmed.extensions.wsi.SlidingWindowBigChunkLoader,\\\n", 160 | " type_smallchunkcollector=pydmed.extensions.wsi.SlidingWindowSmallChunkCollector,\\\n", 161 | " const_global_info=const_global_info,\\\n", 162 | " batch_size=1,\\\n", 163 | " tfms_onsmallchunkcollection=tfms_onsmallchunkcollection,\\\n", 164 | " tfms = tfms_oncolate,\n", 165 | " flag_grabqueue_onunsched = True\n", 166 | " )" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "import pydmed.streamcollector\n", 176 | "from pydmed.streamcollector import *\n", 177 | "\n", 178 | "class HeatmapStreamCollector(StreamCollector):\n", 179 | " def __init__(self, module_pipeline, device, *args, **kwargs):\n", 180 | " #grab privates\n", 181 | " self.module_pipeline = module_pipeline\n", 182 | " self.device = device\n", 183 | " #make other initial operations\n", 184 | " self.module_pipeline.to(device)\n", 185 | " self.module_pipeline.eval()\n", 186 | " self.num_calls_to_getflagfinished = 0\n", 187 | " super(HeatmapStreamCollector, self).__init__(*args, **kwargs)\n", 188 | " \n", 189 | " \n", 190 | " @abstractmethod\n", 191 | " def process_pieceofstream(self, retval_collatefunc):\n", 192 | " x, list_patients, list_smallchunks = retval_collatefunc\n", 193 | " with torch.no_grad():\n", 194 | " netout = \\\n", 195 | " self.module_pipeline(x.to(self.device))#[32x1x7x7]\n", 196 | " list_processedpiece = []\n", 197 | " for n in range(netout.shape[0]):\n", 198 | " tensor_piecen = netout[n,0,:,:].unsqueeze(0)\n", 199 | " str_piecen = pydmed.extensions.wsi.Tensor3DtoPdmcsvrow(\n", 200 | " tensor_piecen.detach().cpu().numpy(), list_smallchunks[n]\n", 201 | " )\n", 202 
| " list_processedpiece.append(\n", 203 | " ProcessedPiece(\n", 204 | " data = str_piecen,\\\n", 205 | " source_smallchunk = list_smallchunks[n]\n", 206 | " )\n", 207 | " )\n", 208 | " return list_processedpiece\n", 209 | " \n", 210 | " @abstractmethod\n", 211 | " def get_flag_finishcollecting(self):\n", 212 | " self.num_calls_to_getflagfinished += 1\n", 213 | " is_dl_running = self.lightdl.is_dl_running()\n", 214 | " if((is_dl_running==False) and (self.lightdl.queue_lightdl.qsize()==0)):\n", 215 | " return True\n", 216 | " else:\n", 217 | " return False" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "statcollector = HeatmapStreamCollector(\n", 227 | " module_pipeline=model,\n", 228 | " device = device,\n", 229 | " lightdl = dl_forheatmap,\n", 230 | " str_collectortype = \"stream_to_file\",\n", 231 | " flag_visualizestats= False,\n", 232 | " kwargs_streamwriter = {\n", 233 | " \"rootpath\": \"./Output/GeneratedHeatmaps/\",\n", 234 | " \"fname_tosave\":None, \n", 235 | " \"waiting_time_before_flush\":3\n", 236 | " }\n", 237 | " )" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": { 244 | "scrolled": true 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "statcollector.start_collecting()" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": { 255 | "scrolled": true 256 | }, 257 | "outputs": [], 258 | "source": [ 259 | "#make the default raster to WSI\n", 260 | "default_wsitoraster = pydmed.extensions.wsi.DefaultWSIxyWHvaltoRasterPoints()\n", 261 | "\n", 262 | "list_paddedregions = []\n", 263 | "for patient in dataset.list_patients:\n", 264 | " print(\"idx_patient = {}\".format(patient.int_uniqueid))\n", 265 | " fname_pdmcsv = \"Output/GeneratedHeatmaps/patient_{}.csv\".format(\n", 266 | " patient.int_uniqueid\n", 267 | " )\n", 268 | " np_heatmap = pydmed.extensions.wsi.pdmcsvtoarray(\n", 269 | " fname_pdmcsv,\n", 270 | " default_wsitoraster.func_WSIxyWHval_to_rasterpoints,\n", 271 | " scale_upsampleraster = 1.0\n", 272 | " )\n", 273 | " plt.figure(figsize=(10,10))\n", 274 | " plt.imshow(np_heatmap[:,:,0], cmap=\"jet\")\n", 275 | " plt.colorbar()\n", 276 | " plt.axis(\"off\")\n", 277 | " plt.savefig(\n", 278 | " \"Output/FromNdarraytoImage/patient_{}.png\".format(patient.int_uniqueid),\n", 279 | " dpi=100, pad_inches=0, bbox_inches=\"tight\"\n", 280 | " )\n", 281 | " plt.show()\n", 282 | " \n", 283 | " np.save(\n", 284 | " \"Output/FromNdarraytoImage/patient_{}\".format(patient.int_uniqueid),\n", 285 | " np_heatmap\n", 286 | " )\n", 287 | " print(\"\\n\")\n", 288 | " " 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python 3", 302 | "language": "python", 303 | "name": "python3" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.7.10" 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 4 320 | } 321 | -------------------------------------------------------------------------------- /pydmed/utils/data.py: 
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | import numpy as np
 4 | import os, sys
 5 | import math
 6 | from pathlib import Path
 7 | import re
 8 | import time
 9 | import random
10 | import copy
11 | import multiprocessing as mp
12 | from multiprocessing import Process, Queue
13 | import pydmed.utils.minimath
14 | 
15 | 
16 | class Patient:
17 |     def __init__(self, int_uniqueid, dict_records):
18 |         '''
19 |         - int_uniqueid: the unique id of the patient, an integer.
20 |         - dict_records: a dict of objects, where each object can be, e.g., "H&E":WSI().
21 |         '''
22 |         self.int_uniqueid = int_uniqueid
23 |         self.dict_records = dict_records
24 | 
25 |     def __hash__(self):
26 |         return self.int_uniqueid #TODO:adding datasetinfo to the patient's uniqueid is safer, because two datasets may have patients with the same ids.
27 | 
28 |     def __repr__(self):
29 |         return "utils.data.Patient with unique id: {}".format(self.int_uniqueid)
30 | 
31 |     def __eq__(self, other):
32 |         return (self.int_uniqueid == other.int_uniqueid)
33 | 
34 |     def __lt__(self, other):
35 |         return self.int_uniqueid < other.int_uniqueid
36 | 
37 |     def __le__(self, other):
38 |         return self.int_uniqueid <= other.int_uniqueid
39 | 
40 |     def __gt__(self, other):
41 |         return self.int_uniqueid > other.int_uniqueid
42 | 
43 |     def __ge__(self, other):
44 |         return self.int_uniqueid >= other.int_uniqueid
45 | 
46 | class Record:
47 |     def __init__(self, rootdir, relativedir, dict_infos):
48 |         '''
49 |         - rootdir: the root directory of the dataset, a string, e.g., "/usr/Dataset1/"
50 |         - relativedir: the relative dir with respect to the rootdir, a string like "1010.svs"
51 |         - dict_infos: a dictionary containing information about the WSI, e.g., the zooming level ("40x",
52 |           "20x", "10x"), the date the WSI was scanned, etc.
53 |         '''
54 |         # ~ rootdir = "/media/user1/9894F11594F0F69A/Ak/Data/CCI_RecurrenceScore/"
55 |         # ~ relativedir = "Gilbert2020-03-24/10101010.svs"
56 | 
57 |         self.rootdir = rootdir
58 |         self.relativedir = relativedir
59 |         self.dict_infos = dict_infos
60 | 
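# Illustrative construction of the classes above, following the pattern of the
# sample notebooks (paths and labels are hypothetical; `Dataset` is defined below):
#
#   record = Record(rootdir="/data/WSIs/", relativedir="1.svs",
#                   dict_infos={"resolution": "40x"})
#   patient = Patient(int_uniqueid=0,
#                     dict_records={"H&E": record, "somelabel": 1})
#   dataset = Dataset("my_dataset", [patient])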
 61 | class Dataset:
 62 |     def __init__(self, str_dsname, list_patients):
 63 |         '''
 64 |         - str_dsname: the name of the dataset, a string.
 65 |         - list_patients: a list whose elements are an instance of `Patient`.
 66 |         '''
 67 |         self.str_dsname = str_dsname
 68 |         self.list_patients = list_patients
 69 |         for pat in self.list_patients:
 70 |             if(isinstance(pat, Patient) == False):
 71 |                 raise Exception("The second argument of Dataset.__init__, i.e., list_patients "+\
 72 |                                 " contains an object which is not an instance of Patient.")
 73 | 
 74 |     @staticmethod
 75 |     def balance_by_repeat(ds, func_getlabel_of_patient, newlen_each_class=None):
 76 |         '''
 77 |         Repeats `Patient`s in the dataset to make the labels balanced.
 78 |         Inputs:
 79 |             - ds: the dataset to be balanced, an instance of `Dataset`.
 80 |             - func_getlabel_of_patient: a function that takes in a `Patient` and returns its label.
 81 |             - newlen_each_class: the target number of patients per class; if None is passed, the lcm of the class frequencies is used.
 82 |         '''
 83 |         #make dict_patient_to_label ====
 84 |         dict_patient_to_label = {patient:func_getlabel_of_patient(patient) for patient in ds.list_patients}
 85 |         #make dict_label_to_freq ====
 86 |         numdigits_old_idx = len(str(max([patient.int_uniqueid for patient in ds.list_patients])))
 87 |         list_labels = set([dict_patient_to_label[patient] for patient in dict_patient_to_label.keys()])
 88 |         dict_label_to_freq = {label:0 for label in list_labels}
 89 |         for patient in dict_patient_to_label.keys():
 90 |             label = dict_patient_to_label[patient]
 91 |             dict_label_to_freq[label] = dict_label_to_freq[label] + 1
 92 |         #if needed, set newlen_each_class to lcm of frequencies =========
 93 |         if(newlen_each_class == None):
 94 |             list_freqs = list(set([dict_label_to_freq[label] for label in dict_label_to_freq.keys()]))
 95 |             newlen_each_class = pydmed.utils.minimath.lcm(list_freqs)
 96 |         #repeat patients to make newds ===================
 97 |         list_patients_of_newds = []
 98 |         for patient in dict_patient_to_label.keys():
 99 |             label = dict_patient_to_label[patient]
100 |             freq_of_label = dict_label_to_freq[label]
101 |             repeatcount = int(newlen_each_class/freq_of_label)
102 |             for idx_patient_copy in range(repeatcount):
103 |                 new_dict_records = copy.deepcopy(patient.dict_records)
104 |                 new_dict_records["TODO:packagename reserved, original patient"] = patient
105 |                 copy_of_patient = Patient(int_uniqueid = idx_patient_copy*(10**numdigits_old_idx)+patient.int_uniqueid,\
106 |                                           dict_records = new_dict_records)
107 |                 list_patients_of_newds.append(copy_of_patient)
108 |         newds = Dataset(ds.str_dsname, list_patients_of_newds)
109 |         return newds
110 | 
111 |     @staticmethod
112 |     def splits_from(dataset, percentage_partitions):
113 |         '''
114 |         Splits a dataset into different datasets, e.g., [training, validation, test].
115 |         Inputs:
116 |             - dataset: the dataset, an instance of Dataset.
117 |             - percentage_partitions: the percentage of each partition, a list (must sum to 100).
118 |         '''
119 |         #get constants/values
120 |         if(np.sum(percentage_partitions) != 100):
121 |             raise Exception("The elements of `percentage_partitions` must sum up to 100.")
122 |         num_chunks = len(percentage_partitions)
123 |         list_patients = dataset.list_patients
124 |         N = len(list_patients)
125 |         #make random splits
126 |         random.shuffle(list_patients)
127 |         toret_list_patients = []
128 |         for percentage in percentage_partitions:
129 |             picked_so_far = sum([len(u) for u in toret_list_patients])
130 |             size_partition = math.floor(percentage*N/100.0)
131 |             idx_begin = picked_so_far
132 |             idx_end = min(picked_so_far+size_partition, N)
133 |             toret_list_patients.append(list_patients[idx_begin:idx_end])
134 |         #make datasets from list_patients
135 |         toret = [Dataset(dataset.str_dsname, u) for u in toret_list_patients]
136 |         return toret
137 | 
138 | 
139 |     @staticmethod
140 |     def _split_list(list_input, percentage_partitions):
141 |         list_toret = []
142 |         for idx_percentage, percentage in enumerate(percentage_partitions):
143 |             picked_so_far = sum([len(u) for u in list_toret])
144 |             size_partition = math.floor(percentage* len(list_input)/100.0)
145 |             idx_begin = picked_so_far
146 |             idx_end = min(picked_so_far+size_partition, len(list_input))
147 |             if(idx_percentage == (len(percentage_partitions)-1)):
148 |                 idx_end = len(list_input)
149 |             list_toret.append(list_input[idx_begin:idx_end])
150 |         assert(len(list_input) == sum([len(u) for u in list_toret]))
151 |         assert(set(list_input) == set.union(*[set(u) for u in list_toret]))
152 |         return list_toret
153 | 
154 | 
155 |     @staticmethod
156 |     def labelbalanced_splits_from(dataset, percentage_partitions,\
157 |                                   func_getlabel_of_patient, verbose=False):
158 |         '''
159 |         Splits a dataset into different datasets, e.g., [training, validation, test], such that all partitions
160 |         have an equal share of the different classes.
161 |         Inputs:
162 |             - dataset: the dataset, an instance of Dataset.
163 |             - percentage_partitions: the percentage of each partition, a list (must sum to 100).
164 |             - func_getlabel_of_patient: the labeling function for which the split is balanced.
165 | ''' 166 | #get some constants/values ==== 167 | if(np.sum(percentage_partitions) != 100): 168 | raise Exception("The elements of `percentage_partitions` must sum up to 100.") 169 | num_chunks = len(percentage_partitions) 170 | list_patients = dataset.list_patients 171 | N = len(list_patients) 172 | #split patients based on label ====== 173 | possible_labels = list( 174 | set( 175 | [func_getlabel_of_patient(patient)\ 176 | for patient in dataset.list_patients] 177 | ) 178 | ) 179 | # ~ print("possible_labels = {}".format(possible_labels)) 180 | dict_label_to_listpatients = {label:[] for label in possible_labels} 181 | for patient in dataset.list_patients: 182 | label_of_patient = func_getlabel_of_patient(patient) 183 | dict_label_to_listpatients[label_of_patient].append(patient) 184 | #make splits from each class ======= 185 | dict_label_to_listpartitions = {label:None for label in possible_labels} 186 | for label in possible_labels: 187 | patients_of_class = dict_label_to_listpatients[label] 188 | random.shuffle(patients_of_class) 189 | dict_label_to_listpartitions[label] = Dataset._split_list(patients_of_class, percentage_partitions) 190 | #aggregate the splits of each class ==== 191 | list_toret = [[] for n in range(len(percentage_partitions))] 192 | for label in possible_labels: 193 | partitions_of_label = dict_label_to_listpartitions[label] 194 | for idx_partition in range(len(percentage_partitions)): 195 | list_toret[idx_partition] = list_toret[idx_partition] + partitions_of_label[idx_partition] 196 | toret = [Dataset(dataset.str_dsname, u) for u in list_toret] 197 | #make some assertions ===== 198 | set_union_of_splits = set.union(*[set(dataset.list_patients) for dataset in toret]) 199 | assert(set_union_of_splits == set(dataset.list_patients)) 200 | for i in range(len(toret)): 201 | for j in range(len(toret)): 202 | if(i != j): 203 | set_i = set(toret[i].list_patients) 204 | set_j = set(toret[j].list_patients) 205 | assert(set_i.isdisjoint(set_j)) 206 | assert(set_j.isdisjoint(set_i)) 207 | #report the label frequencies, in verbose mode ==== 208 | if(verbose == True): 209 | #TODO:HERE 210 | assert False 211 | return toret 212 | 213 | 214 | 215 | 216 | 217 | @staticmethod 218 | def create_onetoone(str_dsname, rootdir, imgsprefix,\ 219 | func_get_patientrecords, func_get_wsiinfos): 220 | ''' 221 | If there is a one to one mapping between patients and images (i.e. one image per patient) 222 | this function can create the dataset. 223 | Inputs. 224 | - str_dsname: name of the str_dsname, a string. 225 | - rootdir: rootdir of the dataset, a string. 226 | - imgsprefix: prefix of the images, e.g., "svs", "ndpi", ... . 227 | - func_get_patientrecords: a function that takes in the file name, and has to return 228 | dict_patientrecords (excluding the WSI). 229 | - func_get_wsiinfos: a function that takes in the file name, and has to return 230 | dict_wsiinfos. 
231 | ''' 232 | #initial checks ========================== 233 | if(rootdir[-1]!="/"): 234 | raise Exception("Arguement: \n {} \n does not end with `/`") 235 | #get all file-names========================= 236 | #get the absolute fnames 237 | list_fnames = [] 238 | for fname in Path(rootdir).rglob("*.{}".format(imgsprefix)): 239 | list_fnames.append(os.path.abspath(fname)) 240 | #remove the rootdir from the beginning 241 | for idx_fname in range(len(list_fnames)): 242 | list_fnames[idx_fname] = list_fnames[idx_fname][len(rootdir)::] 243 | #sort fnames (to get consistent patient_names in different machines) 244 | list_fnames.sort() 245 | #make list_patients ================================= 246 | list_patients = [] 247 | count_createdpatients = 0 248 | for fname in list_fnames: 249 | new_record = Record(rootdir=rootdir,\ 250 | relativedir=fname,\ 251 | dict_infos=func_get_wsiinfos(fname)) 252 | dict_patientrecord = func_get_patientrecords(fname) 253 | dict_patientrecord["wsi"] = new_record 254 | new_patient = Patient(int_uniqueid = count_createdpatients,\ 255 | dict_records = dict_patientrecord) 256 | count_createdpatients += 1 257 | list_patients.append(new_patient) 258 | #make the Dataset ======== 259 | dataset = Dataset(str_dsname, list_patients) 260 | return dataset 261 | 262 | 263 | 264 | 265 | -------------------------------------------------------------------------------- /sample_notebooks/sample_1_train_classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import sys\n", 11 | "import os\n", 12 | "import copy\n", 13 | "from abc import ABC, abstractmethod\n", 14 | "import math\n", 15 | "import copy\n", 16 | "from copy import deepcopy\n", 17 | "import PIL\n", 18 | "from skimage.color import rgb2gray\n", 19 | "from skimage.filters import threshold_otsu\n", 20 | "import torchvision\n", 21 | "import torchvision.models as torchmodels\n", 22 | "import torch.nn.functional as F\n", 23 | "import openslide\n", 24 | "import torch.utils.data\n", 25 | "\n", 26 | "list_pathstoadd = [\"../\"]\n", 27 | "for path in list_pathstoadd:\n", 28 | " if(path not in sys.path):\n", 29 | " sys.path.append(path)\n", 30 | "import pydmed\n", 31 | "from pydmed.utils.data import *\n", 32 | "import pydmed.lightdl\n", 33 | "from pydmed.lightdl import *" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "cuda:0\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 51 | "print(device)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "#make dataset (section 1 of tutorial) ===================\n", 61 | "rootdir = \"../NonGit/Data/\"\n", 62 | "list_relativedirs = [\n", 63 | " \"1.svs\", \"2.svs\", \"3.svs\", \"4.svs\", \"5.svs\",\n", 64 | " \"6.svs\", \"7.svs\", \"8.svs\", \"9.svs\", \"10.svs\"\n", 65 | "]\n", 66 | "list_relativedirs.sort()\n", 67 | "#make a list of patients\n", 68 | "list_patients = []\n", 69 | "for fname in list_relativedirs:\n", 70 | " new_patient = Patient(\\\n", 71 | " int_uniqueid = list_relativedirs.index(fname),\n", 72 | " dict_records = \\\n", 73 | " {\"H&E\":Record(rootdir, fname, 
{\"resolution\":\"40x\"}),\\\n", 74 | " \"HER2-status\": np.random.randint(0,4)}) #TODO:set real labels\n", 75 | " list_patients.append(new_patient)\n", 76 | "#make the dataset\n", 77 | "dataset = pydmed.utils.data.Dataset(\"myHER2dataset\", list_patients)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "def otsu_getpoint_from_foreground(fname_wsi):\n", 87 | " #settings =======\n", 88 | " scale_thumbnail = 0.01\n", 89 | " width_targetpatch = 5000 \n", 90 | " #extract the foreground =========================\n", 91 | " osimage = openslide.OpenSlide(fname_wsi)\n", 92 | " W, H = osimage.dimensions\n", 93 | " size_thumbnail = (int(scale_thumbnail*W), int(scale_thumbnail*H))\n", 94 | " pil_thumbnail = osimage.get_thumbnail(size_thumbnail)\n", 95 | " np_thumbnail = np.array(pil_thumbnail)\n", 96 | " np_thumbnail = np_thumbnail[:,:,0:3]\n", 97 | " np_thumbnail = rgb2gray(np_thumbnail)\n", 98 | " thresh = threshold_otsu(np_thumbnail)\n", 99 | " background = (np_thumbnail > thresh) + 0.0\n", 100 | " foreground = 1.0 - background\n", 101 | " #apply the padding on foreground\n", 102 | " w_padding_of_thumbnail = int(width_targetpatch * scale_thumbnail)\n", 103 | " foreground[0:w_padding_of_thumbnail, :] = 0\n", 104 | " foreground[-w_padding_of_thumbnail::, :] = 0\n", 105 | " foreground[: , 0:w_padding_of_thumbnail] = 0\n", 106 | " foreground[: , -w_padding_of_thumbnail::] = 0\n", 107 | " #select a random point =========================\n", 108 | " one_indices = np.where(foreground==1.0)\n", 109 | " i_oneindices, j_oneindices = one_indices[0].tolist(), one_indices[1].tolist()\n", 110 | " n = random.choice(range(len(i_oneindices)))\n", 111 | " i_selected, j_selected = i_oneindices[n], j_oneindices[n]\n", 112 | " assert(foreground[i_selected, j_selected] == 1)\n", 113 | " i_selected_realscale, j_selected_realscale =\\\n", 114 | " int(i_selected/scale_thumbnail), int(j_selected/scale_thumbnail)\n", 115 | " x, y = j_selected_realscale, i_selected_realscale\n", 116 | " return x,y \n", 117 | " \n", 118 | "class WSIRandomBigchunkLoader(BigChunkLoader):\n", 119 | " @abstractmethod\n", 120 | " def extract_bigchunk(self, last_message_fromroot):\n", 121 | " '''\n", 122 | " Extract and return a bigchunk. 
\n", 123 | " Please note that in this function you have access to\n", 124 | " self.patient and self.const_global_info.\n", 125 | " '''\n", 126 | " self.log(\"in time {} a BigChunk loaded.\\n\".format(time.time()))\n", 127 | " list_bigchunks = []\n", 128 | " for idx_bigpatch in range(5):\n", 129 | " #settings ==== \n", 130 | " flag_use_otsu = True\n", 131 | " #===\n", 132 | " wsi = self.patient.dict_records[\"H&E\"]\n", 133 | " fname_wsi = wsi.rootdir + wsi.relativedir\n", 134 | " osimage = openslide.OpenSlide(fname_wsi)\n", 135 | " w, h = 1000, 1000\n", 136 | " W, H = osimage.dimensions\n", 137 | " if(flag_use_otsu == True):\n", 138 | " rand_x, rand_y = otsu_getpoint_from_foreground(fname_wsi)\n", 139 | " rand_x, rand_y = int(rand_x-(w*0.5)), int(rand_y-(h*0.5))\n", 140 | " else:\n", 141 | " rand_x, rand_y = np.random.randint(0, W-w), np.random.randint(0, H-h)\n", 142 | " pil_bigchunk = osimage.read_region([rand_x, rand_y], 0, [w,h])\n", 143 | " np_bigchunk = np.array(pil_bigchunk)[:,:,0:3]\n", 144 | " bigchunk = BigChunk(data=np_bigchunk,\\\n", 145 | " dict_info_of_bigchunk={\"x\":rand_x, \"y\":rand_y},\\\n", 146 | " patient=self.patient)\n", 147 | " list_bigchunks.append(bigchunk)\n", 148 | " return list_bigchunks\n", 149 | "\n", 150 | "class WSIRandomSmallchunkCollector(SmallChunkCollector):\n", 151 | " def __init__(self, *args, **kwargs):\n", 152 | " super(WSIRandomSmallchunkCollector, self).__init__(*args, **kwargs)\n", 153 | " \n", 154 | " \n", 155 | " @abstractmethod \n", 156 | " def extract_smallchunk(self, call_count, list_bigchunks, last_message_fromroot):\n", 157 | " '''\n", 158 | " Extract and return a smallchunk. Please note that in this function you have access to \n", 159 | " self.bigchunk, self.patient, self.const_global_info.\n", 160 | " Inputs:\n", 161 | " - list_bigchunks: the list of extracted bigchunks.\n", 162 | " - Other arguemtns are not needed in this sample notebook.\n", 163 | " '''\n", 164 | " bigchunk = random.choice(list_bigchunks)\n", 165 | " W, H = bigchunk.data.shape[1], bigchunk.data.shape[0]\n", 166 | " w, h = 224, 224\n", 167 | " rand_x, rand_y = np.random.randint(0, W-w), np.random.randint(0, H-h)\n", 168 | " np_smallchunk = bigchunk.data[rand_y:rand_y+h, rand_x:rand_x+w, :]\n", 169 | " \n", 170 | " #wrap in SmallChunk\n", 171 | " smallchunk = SmallChunk(\n", 172 | " data=np_smallchunk,\\\n", 173 | " dict_info_of_smallchunk={\"x\":rand_x, \"y\":rand_y},\\\n", 174 | " dict_info_of_bigchunk = bigchunk.dict_info_of_bigchunk,\\\n", 175 | " patient=bigchunk.patient\n", 176 | " )\n", 177 | " return smallchunk " 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 5, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "#make dataloader ================== \n", 187 | "tfms = torchvision.transforms.Compose([\n", 188 | " torchvision.transforms.ToPILImage(),\n", 189 | " torchvision.transforms.Resize((224,224)),\n", 190 | " torchvision.transforms.ColorJitter(\n", 191 | " brightness=0,\n", 192 | " contrast=0,\n", 193 | " saturation=0.5,\n", 194 | " hue=[-0.1, 0.1]\n", 195 | " ),\n", 196 | " torchvision.transforms.ToTensor(),\\\n", 197 | " torchvision.transforms.Normalize(\n", 198 | " mean=[0.485, 0.456, 0.406],\n", 199 | " std=[0.229, 0.224, 0.225]\n", 200 | " )\n", 201 | "])\n", 202 | "const_global_info = {\n", 203 | " \"num_bigchunkloaders\":5,\n", 204 | " \"maxlength_queue_smallchunk\":40,\n", 205 | " \"maxlength_queue_lightdl\":200,\n", 206 | " \"interval_resched\": 10,\n", 207 | " \"core-assignment\":{\n", 208 | " 
\"lightdl\":None,\n", 209 | " \"smallchunkloaders\":None,\n", 210 | " \"bigchunkloaders\":None\n", 211 | " }\n", 212 | "}\n", 213 | "dataloader = LightDL(\n", 214 | " dataset = dataset,\n", 215 | " type_bigchunkloader = WSIRandomBigchunkLoader,\\\n", 216 | " type_smallchunkcollector = WSIRandomSmallchunkCollector,\\\n", 217 | " const_global_info = const_global_info,\\\n", 218 | " batch_size = 10,\n", 219 | " tfms = tfms,\n", 220 | " flag_grabqueue_onunsched = False\n", 221 | ")" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 6, 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "name": "stdout", 231 | "output_type": "stream", 232 | "text": [ 233 | "\n" 234 | ] 235 | } 236 | ], 237 | "source": [ 238 | "#build the model and optimizer====================\n", 239 | "model = torchmodels.resnet50(pretrained=True)\n", 240 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", 241 | "criterion = torch.nn.CrossEntropyLoss()\n", 242 | "model.to(device)\n", 243 | "model.train()\n", 244 | "print(\"\")" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 7, 250 | "metadata": { 251 | "scrolled": false 252 | }, 253 | "outputs": [ 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | " loading initial bigchunks, please wait ....\n", 259 | " bigchunk 0 from 5, please wait ...\n", 260 | "\n", 261 | " bigchunk 1 from 5, please wait ...\n", 262 | "\n", 263 | " bigchunk 2 from 5, please wait ...\n", 264 | "\n", 265 | " bigchunk 3 from 5, please wait ...\n", 266 | "\n", 267 | " bigchunk 4 from 5, please wait ...\n", 268 | "\n", 269 | "The initial loading of bigchunks took 16.85529065132141 seconds.\n", 270 | "************* batchcount = 10 ************\n", 271 | "************* batchcount = 20 ************\n", 272 | "************* batchcount = 30 ************\n", 273 | "************* batchcount = 40 ************\n", 274 | "************* batchcount = 50 ************\n", 275 | "************* batchcount = 60 ************\n", 276 | "************* batchcount = 70 ************\n", 277 | "************* batchcount = 80 ************\n", 278 | "************* batchcount = 90 ************\n", 279 | "************* batchcount = 100 ************\n", 280 | "************* batchcount = 110 ************\n", 281 | "************* batchcount = 120 ************\n", 282 | "************* batchcount = 130 ************\n", 283 | "************* batchcount = 140 ************\n", 284 | "************* batchcount = 150 ************\n", 285 | "************* batchcount = 160 ************\n", 286 | "************* batchcount = 170 ************\n", 287 | "************* batchcount = 180 ************\n", 288 | "************* batchcount = 190 ************\n", 289 | "************* batchcount = 200 ************\n", 290 | "************* batchcount = 210 ************\n", 291 | "************* batchcount = 220 ************\n", 292 | "************* batchcount = 230 ************\n", 293 | "************* batchcount = 240 ************\n", 294 | "************* batchcount = 250 ************\n", 295 | "************* batchcount = 260 ************\n", 296 | "************* batchcount = 270 ************\n", 297 | "************* batchcount = 280 ************\n", 298 | "************* batchcount = 290 ************\n", 299 | "************* batchcount = 300 ************\n", 300 | "************* batchcount = 310 ************\n", 301 | "************* batchcount = 320 ************\n", 302 | "************* batchcount = 330 ************\n", 303 | "************* batchcount = 340 
************\n", 304 | "************* batchcount = 350 ************\n", 305 | "************* batchcount = 360 ************\n", 306 | "************* batchcount = 370 ************\n", 307 | "************* batchcount = 380 ************\n", 308 | "************* batchcount = 390 ************\n", 309 | "************* batchcount = 400 ************\n", 310 | "************* batchcount = 410 ************\n", 311 | "************* batchcount = 420 ************\n", 312 | "************* batchcount = 430 ************\n", 313 | "************* batchcount = 440 ************\n", 314 | "************* batchcount = 450 ************\n", 315 | "************* batchcount = 460 ************\n", 316 | "************* batchcount = 470 ************\n", 317 | "************* batchcount = 480 ************\n", 318 | "************* batchcount = 490 ************\n", 319 | "************* batchcount = 500 ************\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "#train the model ============================\n", 325 | "dataloader.start()\n", 326 | "time.sleep(20)\n", 327 | "tstart = time.time()\n", 328 | "batchcount = 0\n", 329 | "while True:\n", 330 | " x, list_patients, list_smallchunks = dataloader.get()\n", 331 | " y = torch.from_numpy(np.array([patient.dict_records['HER2-status']\n", 332 | " for patient in list_patients])).to(device)\n", 333 | " batchcount += 1\n", 334 | " optimizer.zero_grad()\n", 335 | " netout = model(x.to(device))\n", 336 | " loss = criterion(netout, y)\n", 337 | " loss.backward()\n", 338 | " if((batchcount%10)==0):\n", 339 | " print(\"************* batchcount = {} ************\".format(batchcount))\n", 340 | " if(batchcount>500): \n", 341 | " dataloader.pause_loading()\n", 342 | " break" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [] 351 | } 352 | ], 353 | "metadata": { 354 | "kernelspec": { 355 | "display_name": "Python 3", 356 | "language": "python", 357 | "name": "python3" 358 | }, 359 | "language_info": { 360 | "codemirror_mode": { 361 | "name": "ipython", 362 | "version": 3 363 | }, 364 | "file_extension": ".py", 365 | "mimetype": "text/x-python", 366 | "name": "python", 367 | "nbconvert_exporter": "python", 368 | "pygments_lexer": "ipython3", 369 | "version": "3.7.10" 370 | } 371 | }, 372 | "nbformat": 4, 373 | "nbformat_minor": 4 374 | } 375 | -------------------------------------------------------------------------------- /pydmed/extensions/wsi.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import numpy as np 4 | from abc import ABC, abstractmethod 5 | import random 6 | import time 7 | import openslide 8 | import copy 9 | import torchvision 10 | import pydmed 11 | import pydmed.lightdl 12 | from pydmed import * 13 | from pydmed.lightdl import * 14 | from datetime import datetime 15 | 16 | 17 | def default_func_patient_to_fnameimage(patient_input): 18 | fname_wsi = os.path.join( 19 | patient_input.dict_records["H&E"].rootdir, 20 | patient_input.dict_records["H&E"].relativedir 21 | ) 22 | return fname_wsi 23 | 24 | 25 | def Tensor3DtoPdmcsvrow(np_input, smalchunk_input): 26 | ''' 27 | Converts a Tensor of shape [C x H x W] to pdmcsv format. 28 | Inputs. 29 | - np_input: a numpy array of shape [CxHxW]. 30 | - smallchunk_input: an instnace of SmallChunk, 31 | the smallchunk that that the tensor corresponds to. 
32 | ''' 33 | chw = list(np_input.shape) 34 | str_toret = str([ 35 | smalchunk_input.dict_info_of_bigchunk["y"],\ 36 | smalchunk_input.dict_info_of_smallchunk["x"],\ 37 | smalchunk_input.dict_info_of_bigchunk["H"],\ 38 | smalchunk_input.dict_info_of_bigchunk["W"],\ 39 | smalchunk_input.dict_info_of_smallchunk["patch_levelidx"],\ 40 | smalchunk_input.dict_info_of_smallchunk["kernel_size"], 41 | smalchunk_input.dict_info_of_bigchunk["downsample_of_patchlevel"], 42 | chw[0], chw[1], chw[2] 43 | ])[1:-1]+ "," +\ 44 | str(np_input.flatten().tolist())[1:-1] + "\n" 45 | return str_toret 46 | 47 | def _float_or_nan(x): 48 | try: 49 | return float(x) 50 | except: 51 | return np.nan 52 | 53 | 54 | def pdmcsvtoarray(fname_pdmcsv, func_WSIxyWHval_to_rasterpoints, scale_upsampleraster=1.0): 55 | ''' 56 | Converts a pdmcsv file to an array. 57 | Inputs. 58 | - fname_pdmcsv: a string, the path-filename to the pdmcsv file. 59 | - outputsize: a float, the scale of output. The default value is 1.0 meaning the output 60 | array is not scaled. 61 | - func_WSIxyval_to_rasterpoints: a function. 62 | - Inputs. 63 | x: a number, as in one line of the pdm.csv file. 64 | y: a number, as in one line of the pdm.csv file. 65 | W: an integer. 66 | H: an integer. 67 | val: list of values. 68 | -Outputs. 69 | - list_x_onraster: 70 | - list_y_onraster: 71 | - list_val_onraster: 72 | ''' 73 | #read the file line-by-line ===== 74 | file_pdmcsv = open(fname_pdmcsv, 'r') 75 | count_line = 0 76 | dict_raster = {} 77 | while True: 78 | count_line += 1 79 | line = file_pdmcsv.readline() 80 | 81 | if not line: 82 | break 83 | 84 | list_numbers = line.split(",") 85 | for idx, u in enumerate(list_numbers): 86 | if(isinstance(list_numbers[idx], str)): 87 | if("None" in list_numbers[idx]): 88 | list_numbers[idx] = np.nan 89 | list_numbers = [_float_or_nan(u) for u in list_numbers] 90 | 91 | if(count_line == 1): 92 | H, W = list_numbers[2], list_numbers[3] 93 | H, W = int(H), int(W) 94 | 95 | #order: y,x,H,W,.... 
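        #(full row layout, matching Tensor3DtoPdmcsvrow above:
        #   y, x, H, W, patch_levelidx, kernel_size, downsample_of_patchlevel, c, h, w, val...)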
96 | y, x = list_numbers[0], list_numbers[1] 97 | patch_levelidx = list_numbers[4] 98 | kernel_size = list_numbers[5] 99 | downsample_of_patchlevel = list_numbers[6] 100 | c = int(list_numbers[7]) 101 | h = int(list_numbers[8]) 102 | w = int(list_numbers[9]) 103 | val = list_numbers[10:] #np.mean(np.array([list_numbers[4:]])) 104 | 105 | print("W={} , H = {}".format(W, H)) 106 | 107 | #convert the points to raster space using the function 108 | list_x_onraster, list_y_onraster, val = func_WSIxyWHval_to_rasterpoints( 109 | x, y, W, H, 110 | patch_levelidx, 111 | kernel_size, 112 | downsample_of_patchlevel, 113 | c, h, w, val 114 | ) 115 | #np_x_onraster, np_y_onraster = np.array(list_x_onraster), np.array(list_y_onraster) 116 | for idx_rasterpoint in range(len(list_x_onraster)): 117 | dict_raster["({},{})".format( 118 | math.floor(list_x_onraster[idx_rasterpoint]), 119 | math.floor(list_y_onraster[idx_rasterpoint]) 120 | ) 121 | ] = val[idx_rasterpoint] 122 | #print("set raster at [{} , {}]".format(math.floor(list_x_onraster[idx_rasterpoint]), 123 | # math.floor(list_y_onraster[idx_rasterpoint]))) 124 | 125 | for temp_xy in dict_raster.keys(): 126 | print(temp_xy) 127 | 128 | 129 | #convert dict_raster to np.ndarray ===== 130 | list_allrasterx, list_allrastery = [], [] 131 | for u in dict_raster.keys(): 132 | x, y = u[1:-1].split(',') 133 | x, y = float(x), float(y) 134 | if(scale_upsampleraster > 1.0): 135 | x, y = scale_upsampleraster*x, scale_upsampleraster*y 136 | x, y = math.floor(x), math.floor(y) 137 | list_allrasterx.append(x); list_allrastery.append(y) 138 | list_allrasterx = list(set(list_allrasterx)) 139 | list_allrastery = list(set(list_allrastery)) 140 | list_allrasterx.sort(); list_allrastery.sort() 141 | max_x, max_y = np.max(list_allrasterx), np.max(list_allrastery) 142 | output_raster = np.zeros((len(list_allrastery), len(list_allrasterx), c)) 143 | num_totalloops = len(list(dict_raster.keys())) 144 | count = 0 145 | for u in dict_raster.keys(): 146 | count += 1 147 | if((count%10000) == 0): 148 | print(" >>>>>> Interpolation in progress: point {} out of {}. Please wait .... 
.".format(count, num_totalloops), end="\r") 149 | x,y = u[1:-1].split(',') 150 | x, y = float(x), float(y) 151 | if(scale_upsampleraster > 1.0): 152 | x, y = scale_upsampleraster*x, scale_upsampleraster*y 153 | x, y = math.floor(x), math.floor(y) 154 | output_raster[list_allrastery.index(y), list_allrasterx.index(x),:] = dict_raster[u] 155 | #fill-in the zeros if scale_upsample>1.0 156 | if(scale_upsampleraster > 1.0): 157 | list_output_scaled = [] 158 | for count_c in range(c): 159 | f = interp2d( 160 | np.array(list_allrasterx), 161 | np.array(list_allrastery), 162 | output_raster[:,:,count_c], kind='cubic' 163 | ) 164 | output_raster_scaled_forchannel = f( 165 | np.array([j for j in range(max_x)]), 166 | np.array([i for i in range(max_y)]) 167 | ) 168 | list_output_scaled.append(output_raster_scaled_forchannel) 169 | return np.stack(list_output_scaled, 2) 170 | return output_raster 171 | 172 | 173 | class DefaultWSIxyWHvaltoRasterPoints: 174 | def __init__(self): 175 | pass 176 | 177 | def func_WSIxyWHval_to_rasterpoints( 178 | self, x, y, W, H, 179 | patch_levelidx, kernel_size, 180 | downsample_of_patchlevel, 181 | c, h, w, val): 182 | try: 183 | 184 | if((c*h*w) != len(val)): 185 | print("Warning: the line in the csv file is of length {} which is not equal to CxHxW.".format(len(val))) 186 | print(" [c,h,w] = [{},{},{}]".format(c,h,w)) 187 | else: 188 | pass 189 | #print("OK: ...") 190 | 191 | assert(isinstance(val, list)) 192 | assert((c*h*w)== len(val)) 193 | np_val = np.reshape(val, [c,h,w]) 194 | 195 | size_blockonraster = h #np.sqrt(len(val)) 196 | scale_wsi_to_raster = kernel_size/size_blockonraster 197 | x_onraster = (x+0.0)/scale_wsi_to_raster 198 | y_onraster = (y+0.0)/scale_wsi_to_raster 199 | #make list_x_onraster and list_y_onraster ====== 200 | np_x_onraster = np.array([[j for j in range(int(size_blockonraster))]\ 201 | for i in range(int(size_blockonraster))]).flatten()+x_onraster 202 | np_y_onraster = np.array([[i for j in range(int(size_blockonraster))]\ 203 | for i in range(int(size_blockonraster))]).flatten()+y_onraster 204 | list_x_onraster = np_x_onraster.tolist() 205 | list_y_onraster = np_y_onraster.tolist() 206 | toret_val = [] 207 | for i in range(h): 208 | for j in range(w): 209 | toret_val.append(np_val[:,i,j]) 210 | return list_x_onraster, list_y_onraster, toret_val 211 | except: 212 | size_blockonraster = h #np.sqrt(len(val)) 213 | scale_wsi_to_raster = kernel_size/size_blockonraster 214 | x_onraster = (x+0.0)/scale_wsi_to_raster 215 | y_onraster = (y+0.0)/scale_wsi_to_raster 216 | #make list_x_onraster and list_y_onraster ====== 217 | np_x_onraster = np.array([[j for j in range(int(size_blockonraster))]\ 218 | for i in range(int(size_blockonraster))]).flatten()+x_onraster 219 | np_y_onraster = np.array([[i for j in range(int(size_blockonraster))]\ 220 | for i in range(int(size_blockonraster))]).flatten()+y_onraster 221 | list_x_onraster = np_x_onraster.tolist() 222 | list_y_onraster = np_y_onraster.tolist() 223 | toret_val = [] 224 | for i in range(h): 225 | for j in range(w): 226 | toret_val.append(0.0) 227 | return list_x_onraster, list_y_onraster, toret_val 228 | 229 | class SlidingWindowSmallChunkCollector(pydmed.lightdl.SmallChunkCollector): 230 | def __init__(self, *args, **kwargs): 231 | ''' 232 | Inputs: 233 | - mode_trmodainortest (in const_global_info): a strings in {"train" and "test"}. 234 | We need this mode because, e.g., colorjitter is different in training and testing phase. 
235 | ''' 236 | super(SlidingWindowSmallChunkCollector, self).__init__(*args, **kwargs) 237 | if("mode_trainortest" in kwargs["const_global_info"].keys()): 238 | self.mode_trainortest = kwargs["const_global_info"]["mode_trainortest"] 239 | else: 240 | self.mode_trainortest = "test" 241 | assert(self.mode_trainortest in ["train", "test"]) 242 | self.flag_unschedme = False 243 | #grab privates 244 | self.tfms_onsmallchunkcollection = self.const_global_info["pdmreserved_tfms_onsmallchunkcollection"] 245 | # ~ \ 246 | # ~ torchvision.transforms.Compose([ 247 | # ~ torchvision.transforms.ToPILImage(),\ 248 | # ~ torchvision.transforms.ToTensor(),\ 249 | # ~ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],\ 250 | # ~ std=[0.229, 0.224, 0.225]) 251 | # ~ ]) 252 | 253 | def slice_by_slidingwindow(self, W, kernel_size, stride): 254 | ''' 255 | Slices the length `W` by `kernel_size` and `stride`. 256 | Outputs the number of shifts. 257 | ''' 258 | if((stride%(W-kernel_size)) == 0): 259 | toret = int((W-kernel_size)/stride) + 1 260 | else: 261 | toret = math.floor((W-kernel_size)/stride) + 2 262 | return toret 263 | 264 | @abstractmethod 265 | def extract_smallchunk(self, call_count, bigchunk, last_message_fromroot): 266 | ''' 267 | Extract and return a smallchunk. Please note that in this function you have access to 268 | self.bigchunk, self.patient, self.const_global_info. 269 | Inputs: 270 | - list_bigchunks: the list of extracted bigchunks. 271 | '''"list_polygons" 272 | #list of statuses ====== 273 | status_busy = "busy" 274 | status_idle = "idle" 275 | status_idlefinished = "idlefinished" #========= 276 | 277 | 278 | try: 279 | #exit the call if needed. 280 | if(self.flag_unschedme == True): 281 | return None 282 | 283 | #handle the case where the returned BigChunk is None. 
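            #(the loader signals "no rows left for this patient" by returning the string
            # "None-Bigchunk" instead of a BigChunk; see SlidingWindowBigChunkLoader below)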
284 | if(isinstance(bigchunk, str)): 285 | assert(bigchunk == "None-Bigchunk") 286 | if(call_count == 0): 287 | self.set_status(status_idlefinished) 288 | self.set_checkpoint({"idx_bigrow":np.inf}) 289 | return None 290 | 291 | #if callcount is zero, increase the checkpoint by 1 292 | if(call_count == 0): 293 | checkpoint = self.get_checkpoint() 294 | if(checkpoint == None): 295 | self.set_checkpoint({"idx_bigrow":1}) 296 | else: 297 | self.set_checkpoint({"idx_bigrow":checkpoint["idx_bigrow"]+1}) 298 | self.set_status(status_busy) 299 | self.flag_unschedme = False 300 | 301 | #extract fields from const_global_info ==== 302 | intorfunc_opslevel = self.const_global_info["pdmreserved_intorfunc_opslevel"] 303 | kernel_size = self.const_global_info["pdmreserved_kernel_size"] 304 | stride = self.const_global_info["pdmreserved_stride"] 305 | func_patient_to_fnameimage = \ 306 | self.const_global_info["pdmreserved_func_patient_to_fnameimage"] 307 | if(isinstance(intorfunc_opslevel, int) == True): 308 | attention_levelidx = intorfunc_opslevel 309 | else: 310 | attention_levelidx = intorfunc_opslevel(self.patient) 311 | 312 | 313 | #attention_levelidx = attention_levelidx 314 | W, H = bigchunk.data.shape[1], bigchunk.data.shape[0] 315 | #osimage.level_dimensions[self.const_global_info["attention_levelidx"]] 316 | w, h = H+0, H+0 317 | x_begin = int(call_count*w) 318 | x_end = x_begin + w 319 | num_cols = self.slice_by_slidingwindow(W, kernel_size, stride) 320 | flag_auxlastcol = False 321 | vertbar_overlaptheprevpatch = "None" 322 | if(x_end > W): 323 | prev_x_end = x_end-stride 324 | x_end = W 325 | x_begin = W-w 326 | vertbar_overlaptheprevpatch = kernel_size-(x_end-prev_x_end) 327 | flag_auxlastcol = True 328 | 329 | WSI_H = bigchunk.dict_info_of_bigchunk["WSI_H"] 330 | bigchunk_numbigrows = bigchunk.dict_info_of_bigchunk["num_bigrows"] 331 | bigchunk_idxbigrow = bigchunk.dict_info_of_bigchunk["idx_bigrow"] 332 | flag_lastbigchunk = (bigchunk_idxbigrow == (bigchunk_numbigrows-1)) 333 | #(bigchunk.dict_info_of_bigchunk["y"]+2*h) > WSI_H 334 | if(np.random.rand() < 0.2): 335 | str_now = datetime.now().strftime("%d/%m/%Y %H:%M:%S").replace(" ", "-") 336 | str_now = str_now.replace("/","-") 337 | print("Please wait. SlidingWindowDL is still working ..... (printed on {})".format(str_now), end="\r") 338 | pass 339 | 340 | if(call_count > (num_cols-1)): 341 | #x out of boundary 342 | if(flag_lastbigchunk == False): 343 | self.flag_unschedme = True #next calls will return immediately. 344 | self.set_status(status_idle) 345 | return None 346 | elif(flag_lastbigchunk == True): 347 | self.flag_unschedme = True #next calls will return immediately. 
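                    #this was the last big row of the patient's last bigchunk, so report
                    #"idlefinished" (rather than "idle") and let the scheduler retire the patient.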
348 |                     self.set_status(status_idlefinished)
349 |                     return None
350 |             else:
351 |                 #x within boundary ====
352 |                 np_smallchunk = bigchunk.data[:, x_begin:x_end, :]
353 |                 #apply the transformation ===========
354 |                 if(self.tfms_onsmallchunkcollection != None):
355 |                     toret = self.tfms_onsmallchunkcollection(np_smallchunk)
356 |                     toret = toret.cpu().detach().numpy() #[3 x 224 x 224]
357 |                     toret = np.transpose(toret, [1,2,0]) #[224 x 224 x 3]
358 |                 else:
359 |                     toret = np_smallchunk
360 |                 #wrap in SmallChunk
361 |                 smallchunk = SmallChunk(data=toret,\
362 |                     dict_info_of_smallchunk={
363 |                         "x":x_begin, "y":0,\
364 |                         "patch_levelidx":attention_levelidx,
365 |                         "kernel_size":kernel_size,
366 |                         "flag_auxlastcol":flag_auxlastcol,
367 |                         "vertbar_overlaptheprevpatch":vertbar_overlaptheprevpatch
368 |                     },\
369 |                     dict_info_of_bigchunk = bigchunk.dict_info_of_bigchunk,\
370 |                     patient=bigchunk.patient
371 |                 )
372 |                 return smallchunk
373 |         except Exception as e:
374 |             print("An exception occurred when collecting smallchunk.")
375 |             print(str(e))
376 |             return None
377 | 
378 | 
379 | class SlidingWindowBigChunkLoader(pydmed.lightdl.BigChunkLoader):
380 |     def slice_by_slidingwindow(self, W, kernel_size, stride):
381 |         '''
382 |         Slices the length `W` by `kernel_size` and `stride`.
383 |         Outputs the number of window positions (one extra, overlapping window is counted when `W-kernel_size` is not divisible by `stride`).
384 |         '''
385 |         if(((W-kernel_size)%stride) == 0): #the remainder of (W-kernel_size) modulo stride, not the other way around
386 |             toret = int((W-kernel_size)/stride) + 1
387 |         else:
388 |             toret = math.floor((W-kernel_size)/stride) + 2
389 |         return toret
390 | 
391 |     @abstractmethod
392 |     def extract_bigchunk(self, last_message_fromroot):
393 |         '''
394 |         Extract and return a bigchunk.
395 |         Please note that in this function you have access to
396 |         self.patient and self.const_global_info.
397 |         '''
398 |         try:
399 |             #get `idx_bigrow` to be extracted =====
400 |             checkpoint = self.get_checkpoint()
401 |             if(checkpoint == None):
402 |                 idx_bigrow = 0
403 |             else:
404 |                 idx_bigrow = checkpoint["idx_bigrow"]
405 | 
406 |             #extract fields from const_global_info ====
407 |             intorfunc_opslevel = self.const_global_info["pdmreserved_intorfunc_opslevel"]
408 |             kernel_size = self.const_global_info["pdmreserved_kernel_size"]
409 |             stride = self.const_global_info["pdmreserved_stride"]
410 |             func_patient_to_fnameimage = self.const_global_info["pdmreserved_func_patient_to_fnameimage"]
411 | 
412 |             #compute some constants ====
413 |             fname_wsi = func_patient_to_fnameimage(self.patient) #os.path.join(wsi.rootdir, wsi.relativedir)
414 |             osimage = openslide.OpenSlide(fname_wsi)
415 |             if(isinstance(intorfunc_opslevel, int) == True):
416 |                 attention_levelidx = intorfunc_opslevel
417 |             else:
418 |                 attention_levelidx = intorfunc_opslevel(self.patient)
419 |             w, h = kernel_size, kernel_size #in the target level
420 |             W, H = osimage.level_dimensions[attention_levelidx] #size in the target level
421 |             downsample_of_patchlevel = osimage.level_downsamples[attention_levelidx]
422 |             num_bigrows = self.slice_by_slidingwindow(H, kernel_size, stride)
423 | 
424 |             #extract the target row ====
425 |             y_begin = int(stride*idx_bigrow) #size in the target level
426 |             y_begin_at_level0 = int(y_begin*osimage.level_downsamples[attention_levelidx])
427 |             if(y_begin_at_level0 < 0):
428 |                 y_begin_at_level0 = 0
429 |             if((num_bigrows-1) < idx_bigrow):
430 |                 return "None-Bigchunk" #it happens when a done case is loaded by the scheduler.
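            #if this row would run past the bottom of the slide, it is shifted up below so that
            #the last big row still has full kernel_size height; the resulting overlap with the
            #previous row is recorded in `horizbar_overlaptheprevpatch`.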
431 | y_end = y_begin + h 432 | flag_from_auxbigrow = False 433 | horizbar_overlaptheprevpatch = "None" 434 | if(y_end > H): 435 | #refine the range for the last bigrow 436 | old_y_end, old_y_begin = y_end+0.0, y_begin+0.0 437 | prev_y_end, prev_y_begin = old_y_end-stride+0.0, old_y_begin-stride+0.0 438 | y_end = H-1 439 | y_begin = H-kernel_size-1 440 | y_begin_at_level0 = int(y_begin*osimage.level_downsamples[attention_levelidx]) 441 | flag_from_auxbigrow = True 442 | horizbar_overlaptheprevpatch = kernel_size - (y_end-prev_y_end) 443 | pil_bigchunk = osimage.read_region( 444 | [0, y_begin_at_level0], 445 | attention_levelidx, 446 | [W,h] 447 | ) 448 | np_bigchunk = np.array(pil_bigchunk)[:,:,0:3] 449 | self.patient.dict_records["precomputed_opsimage"] = "none" 450 | patient_without_foregroundmask = copy.deepcopy(self.patient) 451 | for k in patient_without_foregroundmask.dict_records.keys(): 452 | if(k.startswith("precomputed")): 453 | patient_without_foregroundmask.dict_records[k] = None 454 | bigchunk = BigChunk(data=np_bigchunk,\ 455 | dict_info_of_bigchunk={ 456 | "W":W, "H":H, "x":0, "y":y_begin, 457 | "WSI_W":W, "WSI_H":H, 458 | "downsample_of_patchlevel":downsample_of_patchlevel, 459 | "num_bigrows":num_bigrows, 460 | "idx_bigrow":idx_bigrow, 461 | "flag_from_auxbigrow":flag_from_auxbigrow, 462 | "horizbar_overlaptheprevpatch": horizbar_overlaptheprevpatch 463 | },\ 464 | patient=patient_without_foregroundmask 465 | ) 466 | return bigchunk 467 | except Exception as exception: 468 | print("extractbigchunk failed for patient {}.".format(self.patient)) 469 | print(exception) 470 | return "None-Bigchunk" 471 | 472 | class SlidingWindowDL(pydmed.lightdl.LightDL): 473 | def __init__( 474 | self, intorfunc_opslevel, kernel_size, 475 | stride, mininterval_loadnewbigchunk, 476 | tfms_onsmallchunkcollection, func_patient_to_fnameimage = None, 477 | *args, **kwargs): 478 | ''' 479 | Inputs. 480 | - intorfunc_opslevel: it can be either an integer, or a function. 481 | This argument specifies the level of the openslide image 482 | from which the patches are extracted. 483 | If it is an integer, e.g., 0, the DL will return from level 0. 484 | If it is a function, it has to take in a patient and return 485 | the intended level based on the input patient. 486 | - kernel_size: an integer, the width of the sliding windon. 487 | - stride: stride of the sliding window, an integer. 488 | - func_patient_to_fnameimage: a function. 489 | This function has to take in a `Patient` and return the aboslute path 490 | of the image (or WSI). 491 | - mininterval_loadnewbigchunk: a floating point number, minimum time (in seconds) 492 | between loading two bigchunks. 493 | This number depends on how big each `BigChunk` is as well as system specs. 494 | - tfms_onsmallchunkcollection: a callable object, the transformations to be applied to each SmallChunk (i.e. each tile). 495 | 496 | 497 | ''' 498 | super(SlidingWindowDL, self).__init__(*args, **kwargs) 499 | #grab privates ==== 500 | kwargs["type_bigchunkloader"] = SlidingWindowBigChunkLoader 501 | kwargs["type_smallchunkcollector"] = SlidingWindowSmallChunkCollector 502 | self.time_lasteffective_sched = None 503 | self.list_itwaslastbigchunk = [] 504 | self._dict_patient_to_lastschedtime = { 505 | patient:None for patient in self.dataset.list_patients 506 | } #to avoid unscheduling right after scheduling. 
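        #keys prefixed with "pdmreserved_" are reserved by pydmed; they are read back by
        #SlidingWindowBigChunkLoader and SlidingWindowSmallChunkCollector via const_global_info.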
507 | #place the input arguments within `const_global_info` 508 | self.const_global_info["pdmreserved_intorfunc_opslevel"] = intorfunc_opslevel 509 | self.const_global_info["pdmreserved_kernel_size"] = kernel_size 510 | self.const_global_info["pdmreserved_stride"] = stride 511 | if(func_patient_to_fnameimage != None): 512 | self.const_global_info["pdmreserved_func_patient_to_fnameimage"] = func_patient_to_fnameimage 513 | else: 514 | self.const_global_info["pdmreserved_func_patient_to_fnameimage"] = default_func_patient_to_fnameimage 515 | self.const_global_info["pdmreserved_mininterval_loadnewbigchunk"] = mininterval_loadnewbigchunk 516 | self.const_global_info["pdmreserved_tfms_onsmallchunkcollection"] = tfms_onsmallchunkcollection 517 | 518 | def initial_schedule(self): 519 | #Default is to choose randomly from dataset. 520 | toret = random.choices( 521 | self.dataset.list_patients,\ 522 | k=self.const_global_info["num_bigchunkloaders"] 523 | ) 524 | for patient in toret: 525 | self._dict_patient_to_lastschedtime[patient] = time.time() 526 | return toret 527 | 528 | def schedule(self): 529 | ''' 530 | This function is called when schedulling a new patient, i.e., loading a new BigChunk. 531 | This function has to return: 532 | - patient_toremove: the patient to remove, an instance of `utils.data.Patient`. 533 | - patient_toload: the patient to load, an instance of `utils.data.Patient`. 534 | In this function, you have access to the following fields: 535 | - self.dict_patient_to_schedcount: given a patient, returns the number of times the patients has been schedulled in dl, a dictionary. 536 | - self.list_loadedpatients: 537 | - self.list_waitingpatients: 538 | - TODO: add more fields here to provide more flexibility. For instance, total time that the patient have been loaded on DL. 539 | ''' 540 | #list of statuses ====== 541 | status_busy = "busy" 542 | status_idle = "idle" 543 | status_idlefinished = "idlefinished" #========= 544 | try: 545 | #see if any BigChunkLoader process is still running ==== 546 | list_activesubprocs = list(self.active_subprocesses) 547 | for subproc in list_activesubprocs: 548 | if(subproc.get_flag_bigchunkloader_terminated() == False): 549 | return None, None 550 | 551 | #see if all patients are done ==== 552 | #TODO:print? print("set of it was last bigchunk = {}\n".format(set(self.list_itwaslastbigchunk))) 553 | if(set(self.list_itwaslastbigchunk) == set(self.dataset.list_patients)): 554 | print("\n =================== dl's job is done! 
========================") 555 | # self._queue_imdone.put_nowait("imdone") 556 | #print(" returned None, None (case imdone)") 557 | return pydmed.lightdl.PYDMEDRESERVED_HALTDL, None 558 | 559 | #see if current time is not too close to last effective schedule ====== 560 | if(self.time_lasteffective_sched == None): 561 | self.time_lasteffective_sched = time.time() 562 | mininterval_loadnewbigchunk = \ 563 | self.const_global_info["pdmreserved_mininterval_loadnewbigchunk"] 564 | if((time.time()-self.time_lasteffective_sched) < mininterval_loadnewbigchunk): 565 | return None, None #let the loaded subprocesses continue working 566 | 567 | #get initial fields ============================== 568 | list_loadedpatients = self.get_list_loadedpatients() 569 | list_activesubprocs = list(self.active_subprocesses) 570 | list_statuses = [subproc.get_status()\ 571 | for subproc in list_activesubprocs] 572 | 573 | #update list_itwaslastbigchunk 574 | for idx_status, status in enumerate(list_statuses): 575 | if(status == status_idlefinished): 576 | patient_itwaslastbigchunk = list_activesubprocs[idx_status].patient 577 | if((patient_itwaslastbigchunk in self.list_itwaslastbigchunk) == False): 578 | self.list_itwaslastbigchunk.append(patient_itwaslastbigchunk) 579 | 580 | #find a candiate to unschedule (with priority to idle+unfinished) ====== 581 | flag_exists_idlebutnotfinished = False 582 | for idx_status, status in enumerate(list_statuses): 583 | if(status == status_idle): 584 | flag_exists_idlebutnotfinished = True 585 | flag_foundcandidate_unsched = False 586 | for idx_status, status in enumerate(list_statuses): 587 | if(flag_exists_idlebutnotfinished == True): 588 | criteria_unsched = (status == status_idle) 589 | else: 590 | criteria_unsched = (status==status_idlefinished) 591 | 592 | if(criteria_unsched == True): 593 | patient_unschedcandidate = list_activesubprocs[idx_status].patient 594 | if((time.time()-\ 595 | self._dict_patient_to_lastschedtime[patient_unschedcandidate])>1.0): 596 | flag_foundcandidate_unsched = True 597 | idx_subproc_toremove = idx_status 598 | break 599 | 600 | if(flag_foundcandidate_unsched == False): 601 | #if no unsched candidates were found, return None, None 602 | return None, None #let the loaded subprocesses continue working 603 | elif(flag_foundcandidate_unsched == True): 604 | #find a candidate to schedule ===== 605 | list_waitingpatients = self.get_list_waitingpatients() 606 | for patient in self.list_itwaslastbigchunk: 607 | if(patient in list_waitingpatients): 608 | list_waitingpatients.remove(patient) 609 | 610 | patient_toremove = list_loadedpatients[idx_subproc_toremove] 611 | waitingpatients_schedcount = [self.get_schedcount_of(patient)\ 612 | for patient in list_waitingpatients] 613 | if(len(list_waitingpatients)>0): 614 | patient_toload = random.choice(list_waitingpatients) 615 | else: 616 | #load a new patient, which indeed won't return any big/small chunk. 617 | list_waitingpatients = self.get_list_waitingpatients() 618 | patient_toload = random.choice(list_waitingpatients) 619 | 620 | #check if the sched/unsched is useful (i.e. 
not swapping two finished patients) 621 | flag_usefule_sched = not( 622 | (patient_toload in self.list_itwaslastbigchunk) and\ 623 | (patient_toremove in self.list_itwaslastbigchunk) 624 | ) 625 | if(flag_usefule_sched == True): 626 | self.time_lasteffective_sched = time.time() 627 | self._dict_patient_to_lastschedtime[patient_toload] = time.time() 628 | return patient_toremove, patient_toload 629 | else: 630 | return None, None 631 | except Exception as e: 632 | print("exception in schedule.") 633 | print(str(e)) 634 | print(" but returned None, None (case 3)") 635 | return None, None 636 | -------------------------------------------------------------------------------- /pydmed/lightdl.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ''' 5 | General TODO:s 6 | - replace os.nice and taskset with cross-platform counterparts. 7 | ''' 8 | 9 | 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import os, sys 13 | import psutil 14 | from pathlib import Path 15 | import re 16 | import time 17 | import random 18 | import multiprocessing as mp 19 | import subprocess 20 | from abc import ABC, abstractmethod 21 | import openslide 22 | import torch 23 | import torchvision 24 | import torchvision.models as models 25 | from multiprocessing import Process, Queue 26 | import pydmed.utils.multiproc 27 | from pydmed.utils.multiproc import * 28 | 29 | 30 | ''' 31 | Global enumerations. 32 | As enum is only supported for python 3.4+, pydmed uses some global variables. 33 | ''' 34 | PYDMEDRESERVED_HALTDL = "PYDMEDRESERVED_HALTDL" 35 | PYDMEDRESERVED_DLRETURNEDLASTINSTANCE = "PYDMEDRESERVED_DL_RETURNED_LAST_INSTANCE" 36 | 37 | def get_default_constglobinf(): 38 | toret = { 39 | "num_bigchunkloaders":10, 40 | "maxlength_queue_smallchunk":100, 41 | "maxlength_queue_lightdl":10000, 42 | "interval_resched": 10, 43 | "core-assignment":{"lightdl":None, 44 | "smallchunkloaders":None, 45 | "bigchunkloaders":None} 46 | } 47 | return toret 48 | 49 | 50 | class BigChunk: 51 | def __init__(self, data, dict_info_of_bigchunk, patient): 52 | ''' 53 | Implementation of a "Big Data Chunk". 54 | Inputs: 55 | - data: the data part (e.g., a patch of size 1000x1000), an instance of numpy.ndarray. 56 | - dict_info_of_bigchunk: a dictionary containing information about the bigchunk. 57 | It may include, e.g., the (left,top) position of the big patch, etc. 58 | - patient: an instance of `utils.data.Patient`. 59 | ''' 60 | #grab privates 61 | self.data = data 62 | self.dict_info_of_bigchunk = dict_info_of_bigchunk 63 | self.patient = patient 64 | 65 | class SmallChunk: 66 | def __init__(self, data, dict_info_of_smallchunk, dict_info_of_bigchunk, patient): 67 | ''' 68 | Implementation of a "Big Data Chunk". 69 | Inputs: 70 | - data: the data part (e.g., a patch of size 1000x1000), an instance of numpy.ndarray. 71 | - dict_info_of_smallchunk: a dictionary containing information about the bigchunk. 72 | It may include, e.g., the (left,top) position of the small patch, etc. 73 | - dict_info_of_bigchunk: a dictionary containing information about the bigchunk. 74 | It may include, e.g., the (left,top) position of the big patch, etc. 75 | - patient: an instance of `utils.data.Patient`. 
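        Example (a minimal sketch; the field values are hypothetical):
            chunk = SmallChunk(data=np_patch,
                               dict_info_of_smallchunk={"x":0, "y":0},
                               dict_info_of_bigchunk={"x":0, "y":0},
                               patient=some_patient)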
76 | ''' 77 | #grab privates 78 | self.data = data 79 | self.dict_info_of_smallchunk = dict_info_of_smallchunk 80 | self.dict_info_of_bigchunk = dict_info_of_bigchunk 81 | self.patient = patient 82 | 83 | 84 | class BigChunkLoader(mp.Process): 85 | def __init__(self, patient, queue_bigchunk, const_global_info, queue_logs, old_checkpoint, last_message_from_root): 86 | ''' 87 | Inputs: 88 | - patient: an instance of `Patient`. 89 | - queue_bigchunk: the queue to place the extracted bigchunk, 90 | an instance of multiprocessing.Queue. 91 | - const_global_info: global information visible by all subprocesses, 92 | a dictionary. It can contain, e.g., the lenght of the queues, 93 | waiting times. etc. 94 | - path_logfiles: a path where the logfiles are saved, a string. 95 | If provided, you can log to the file with the funciton self.log(" some string "). 96 | - last_message_from_root: TODO:adddoc for last_message_from_root. 97 | ''' 98 | super(BigChunkLoader, self).__init__() 99 | self.patient = patient 100 | self.queue_bigchunk = queue_bigchunk 101 | self.const_global_info = const_global_info 102 | self._queue_logs = queue_logs 103 | self.old_checkpoint = old_checkpoint 104 | self.last_message_from_root = last_message_from_root 105 | #make internals 106 | # ~ if(self.path_logfiles != None): 107 | # ~ self.logfile = open(self.path_logfiles + "patient_{}.txt".format(self.patient.int_uniqueid),"a") 108 | 109 | def get_checkpoint(self): 110 | return self.old_checkpoint 111 | 112 | def run(self): 113 | """Loads a bigchunk and waits for the patchcollector to enqueue the bigchunk.""" 114 | 115 | #set random seed using time ==== 116 | np.random.seed(int(time.time())) 117 | 118 | #assign the bigchunkloader to core 119 | if(self.const_global_info["core-assignment"]["bigchunkloaders"] != None): 120 | os.system("taskset -cp {} {}".format(self.const_global_info["core-assignment"]["bigchunkloaders"], os.getpid())) 121 | print(" taskset called for bigchunkloader") 122 | 123 | #extract a bigchunk ======= 124 | bigchunk = self.extract_bigchunk(self.last_message_from_root) 125 | #place the bigchunk in the queue 126 | self.queue_bigchunk.put_nowait(bigchunk) 127 | 128 | #the bigchunk loader is kept running (to be quitted by smallchunk collector). 129 | flag_enteredthewhileloop = False 130 | while(True): 131 | if(flag_enteredthewhileloop == False): 132 | flag_enteredthewhileloop = True 133 | #akdump("entered_the_whileafterputnowait") 134 | pass 135 | time.sleep(1) 136 | 137 | 138 | def log(self, str_input): 139 | ''' 140 | Logs to the log file, i.e. the file with `fname_logfile`. 141 | ''' 142 | self._queue_logs.put_nowait(str_input) 143 | 144 | @abstractmethod 145 | def extract_bigchunk(self, last_message_fromroot): 146 | ''' 147 | Extract and return a bigchunk. 148 | Inputs: 149 | - `last_message_fromroot`: The last messsage sent to this patient. Indeed, this is the message sent by calling the function 150 | `lightdl.send_message`. 151 | Note that in this function you have access to `self.patient` and `self.const_global_info`. 152 | ''' 153 | pass 154 | 155 | 156 | 157 | 158 | class SmallChunkCollector(mp.Process): 159 | def __init__(self, patient, queue_smallchunks, const_global_info,\ 160 | type_bigchunkloader, queue_logs, old_checkpoint, queue_checkpoint, last_message_from_root): 161 | ''' 162 | Inputs: 163 | - patient: an instance of `Patient`. 164 | - queue_smallchunks: the queue to place the extracted small chunks, 165 | an instance of multiprocessing.Queue. 
166 | - const_global_info: global information visible by all subprocesses, 167 | a dictionary. It can contain, e.g., the lenght of the queues, 168 | waiting times. etc. 169 | - type_bigchunkloader: the type (i.e. Class) of bigchunkloader to instantiate from, 170 | a subclass of BigChunkLoader. 171 | - queue_logs: the queue in which logs are going to placed. 172 | - last_message_from_root: TODO:adddoc for last_message_from_root. 173 | ''' 174 | super(SmallChunkCollector, self).__init__() 175 | self.patient = patient 176 | self.queue_smallchunks = queue_smallchunks 177 | self.const_global_info = const_global_info 178 | self.type_bigchunkloader = type_bigchunkloader 179 | self._queue_logs = queue_logs 180 | self.old_checkpoint = old_checkpoint 181 | self.queue_checkpoint = queue_checkpoint 182 | self.last_message_from_root = last_message_from_root 183 | #make internals ===== 184 | self._cached_checkpoint = "TODO:packagename reserverd: empty cache" 185 | self._queue_status = mp.Queue() 186 | self._cached_status = "TODO:packagename reserverd: empty cache" 187 | self._queue_bigchunkloader_terminated = mp.Queue() 188 | 189 | def log(self, str_input): 190 | ''' 191 | Logs to the log file, i.e. the file with `fname_logfile`. 192 | ''' 193 | self._queue_logs.put_nowait(str_input) 194 | 195 | def set_status(self, status): 196 | ''' 197 | This function sets the status of smallchunkloader. 198 | The dataloader can read the status of the SmallChunkCollector by calling `SmallChunkCollector.get_status`. 199 | Input: 200 | - status: can be any pickleable object, e.g., a string, a dictionary, etc. 201 | ''' 202 | self._queue_status.put_nowait(status) 203 | 204 | def get_status(self): 205 | ''' 206 | This function returns the last status of `SmallChunkCollector`, which is previously set 207 | by calling the function `SmallChunkCollector.set_status`. 208 | ''' 209 | # ~ print("get status called") 210 | qsize_status = self._queue_status.qsize() 211 | # ~ print("qsize_status = {}".format(qsize_status)) 212 | if(qsize_status > 0): 213 | last_status = None 214 | for count in range(qsize_status): 215 | try: 216 | last_status = self._queue_status.get_nowait() 217 | except Exception as e: 218 | pass 219 | # ~ print("an excection occured in get_status function") 220 | # ~ print(str(e)) 221 | # ~ print(" get status returned") 222 | # ~ print(" get status returned") 223 | self._cached_status = last_status 224 | return last_status 225 | else: 226 | # ~ print("return cached status") 227 | return self._cached_status 228 | 229 | def get_checkpoint(self): 230 | if(isinstance(self._cached_checkpoint, str)): 231 | if(self._cached_checkpoint == "TODO:packagename reserverd: empty cache"): 232 | return self.old_checkpoint 233 | else: 234 | return self._cached_checkpoint 235 | else: 236 | return self._cached_checkpoint 237 | 238 | def set_checkpoint(self, checkpoint): 239 | self.queue_checkpoint.put_nowait(checkpoint) 240 | self._cached_checkpoint = checkpoint 241 | 242 | def run(self): 243 | ''' 244 | Loads a bigchunk, waits for the bigchunk to be loaded, and then makes calls to 245 | self.extract_smallchunk. 
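        Each returned SmallChunk is placed on `queue_smallchunks`; extraction pauses whenever the
        queue already holds `const_global_info["maxlength_queue_smallchunk"]` instances.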
246 | ''' 247 | #set random seed using time ==== 248 | np.random.seed(int(time.time())) 249 | 250 | # ~ os.nice(1000) #TODO:make tunable 251 | # ~ print(" pid of smallchunkcollector: {}".format(os.getpid())) 252 | #print("in smallchunkloader, pid is: {}".format(os.getpid())) 253 | if(self.const_global_info["core-assignment"]["smallchunkloaders"] != None): 254 | # ~ p = psutil.Process() 255 | # ~ idx_cores = [int(u) for u in self.const_global_info["core-assignment"]["smallchunkloaders"].split(",")] 256 | # ~ p.cpu_affinity(idx_cores) 257 | # ~ print("in smallchunkloaders, idx_cores={}".format(idx_cores)) 258 | os.system("taskset -cp {} {}".format(self.const_global_info["core-assignment"]["smallchunkloaders"], os.getpid())) 259 | print(" taskset called for smallchunkloader") 260 | #print(" subprocess pinded to cores") 261 | 262 | # ~ print("reached here 1") 263 | #Load a bigchunk in a subprocess 264 | queue_bc = mp.Queue() 265 | # ~ print("reached here 2") 266 | proc_bcloader = self.type_bigchunkloader(self.patient, queue_bc,\ 267 | self.const_global_info, self._queue_logs, self.old_checkpoint, self.last_message_from_root) 268 | # ~ print("reached here 3") 269 | proc_bcloader.start() 270 | # ~ print("reached here 4") 271 | while(queue_bc.empty()): 272 | pass 273 | #print("Smallchunk collector is waiting to collect the bigchunk.") 274 | #wait for the sploader to load the superpatch 275 | #TODO:make the waiting more efficient 276 | #collect the bigchunk and start extracting patches from it 277 | # ~ print("reached here 5") 278 | self._queue_bigchunkloader_terminated.put_nowait("Finished loading a bigchunk") 279 | 280 | bigchunk = queue_bc.get() 281 | # ~ print("reached here 6") 282 | 283 | proc_bcloader.terminate() 284 | 285 | call_count = 0 286 | while(True): 287 | if(self.queue_smallchunks.qsize() < self.const_global_info["maxlength_queue_smallchunk"]): 288 | # ~ print(" ----------------- reached here 7") 289 | #print(" smallchunkcollector saw emtpy place in queue.") 290 | smallchunk = self.extract_smallchunk(call_count, bigchunk, self.last_message_from_root) 291 | call_count += 1 292 | #print("... and extracted a smallchunk.") 293 | if(isinstance(smallchunk, np.ndarray) == False): 294 | if(smallchunk == None): 295 | pass 296 | else: 297 | self.queue_smallchunks.put_nowait(smallchunk) 298 | else: 299 | self.queue_smallchunks.put_nowait(smallchunk) 300 | #print(" placed a smallchunk in queue.") 301 | 302 | def get_flag_bigchunkloader_terminated(self): 303 | try: 304 | if(self._queue_bigchunkloader_terminated.qsize() > 0): 305 | return True 306 | else: 307 | return False 308 | except: 309 | print("Exception occured in get_flag_bigchunkloader_terminated") 310 | return False 311 | 312 | @abstractmethod 313 | def extract_smallchunk(self, call_count, bigchunk, last_message_fromroot): 314 | ''' 315 | Extract and return a smallchunk. Please note that in this function you have access to 316 | self.bigchunk, self.patient, self.const_global_info. 317 | Inputs: 318 | - `call_count`: an integer. While the `SmallChunkCollector` is collecting `SmallChunk`s, 319 | the function `extract_smallchunk` is called several times. 320 | The argument `count_calls` is the number of times the `extract_smallchunk` is called 321 | since the `SmallChunkCollector` (and its child `BigChunkLoader`) has started working. 322 | - bigchunk: the extracted bigchunk. 323 | - `last_message_fromroot`: The last messsage sent to this patient. Indeed, this is the message sent by calling the function 324 | `lightdl.send_message`. 
325 | Output: 326 | - smallchunk: has to be either an instance of `SmallChunk` or None. 327 | Returning None means the `SmallChunkCollector` is no longer willing to extract `SmallChunk`s for, e.g., 328 | it has sufficiently explored the patient's records. 329 | ''' 330 | pass 331 | 332 | 333 | class LightDL(mp.Process): 334 | def __init__(self, dataset, type_bigchunkloader, type_smallchunkcollector,\ 335 | const_global_info, batch_size, tfms, flag_grabqueue_onunsched=True, collate_func=None, fname_logfile=None, 336 | flag_enable_sendgetmessage = True, flag_enable_setgetcheckpoint = True): 337 | ''' 338 | Inputs: 339 | - dataset: an instance of pydmed.utils.Dataset. 340 | - type_bigchunkloader: the type (Class) of bigchunkloader, 341 | a subplcass of BigChunkLoader. 342 | - type_smallchunkcollector: the type (Class) of smallchunkloader, 343 | a subplcass of SmallChunkLoader. 344 | - const_global_info: global information visible by all subprocesses, 345 | a dictionary. It can contain, e.g., the lenght of the queues, 346 | waiting times. etc. 347 | - batch_size: the size of each batch, an integer. 348 | - fname_logfile: the name of the file to which `.log(str)` function will write. 349 | ''' 350 | #grab privates ==== 351 | super(LightDL, self).__init__() 352 | self.dataset = dataset 353 | self.type_bigchunkloader = type_bigchunkloader 354 | self.type_smallchunkcollector = type_smallchunkcollector 355 | self.const_global_info = const_global_info 356 | self.queue_lightdl = mp.Queue() 357 | self.batch_size = batch_size 358 | self.flag_grabqueue_onunsched = flag_grabqueue_onunsched 359 | self.fname_logfile = fname_logfile 360 | if(collate_func == None): 361 | self.collate_func = LightDL.default_collate 362 | else: 363 | self.collate_func = collate_func 364 | self.tfms = tfms 365 | self.flag_enable_setgetcheckpoint = flag_enable_setgetcheckpoint 366 | self.flag_enable_sendgetmessage = flag_enable_sendgetmessage 367 | #make internals ==== 368 | self.active_subprocesses = set() #set of currently active processes 369 | self._queue_pid_of_lightdl = mp.Queue() 370 | self._queue_message_lightdlfinished = mp.Queue() 371 | self.dict_patient_to_schedcount = {patient:0 for patient in self.dataset.list_patients} 372 | #self.list_poped_entities = [] 373 | self.list_smallchunksforvis = [] #smallchunks without data and only for visualization. 374 | if(flag_enable_setgetcheckpoint == True): 375 | self.dict_patient_to_checkpoint = {patient:None for patient in self.dataset.list_patients} 376 | self._dict_patient_to_queueckpoint = {patient:mp.Queue() for patient in self.dataset.list_patients} 377 | else: 378 | self.dict_patient_to_checkpoint = None 379 | self._dict_patient_to_queueckpoint = None 380 | if(flag_enable_sendgetmessage == True): 381 | self._queue_messages_to_subprocs = {patient:mp.Queue() for patient in self.dataset.list_patients} 382 | else: 383 | self._queue_messages_to_subprocs = None 384 | self._queue_logs = mp.Queue() 385 | if(self.fname_logfile != None): 386 | self.logfile = open(self.fname_logfile, "a") 387 | 388 | def flush_log(self): 389 | if(self.fname_logfile == None): 390 | return #DO nothing 391 | size_queue_log = self._queue_logs.qsize() 392 | for count in range(size_queue_log): 393 | try: 394 | elem = self._queue_logs.get_nowait() 395 | self.logfile.write(elem) 396 | except: 397 | pass 398 | self.logfile.flush() 399 | self.logfile.close() 400 | 401 | def log(self, str_input): 402 | ''' 403 | Logs to the log file, i.e. the file with `fname_logfile`. 
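        Example: `dataloader.log("epoch done\n")`; the queued strings are written out when
        `flush_log` is called (e.g., from `pause_loading`).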

    def flush_log(self):
        if(self.fname_logfile is None):
            return #do nothing
        size_queue_log = self._queue_logs.qsize()
        for count in range(size_queue_log):
            try:
                elem = self._queue_logs.get_nowait()
                self.logfile.write(elem)
            except Exception:
                pass
        #NOTE: flushing also closes the log file.
        self.logfile.flush()
        self.logfile.close()

    def log(self, str_input):
        '''
        Logs to the log file, i.e., the file named `fname_logfile`.
        '''
        self._queue_logs.put_nowait(str_input)

    def send_message(self, patient, message):
        '''
        Sends a message to the subprocess corresponding to the patient.
        Once the subprocess is scheduled to run, it will receive the last sent message.
        You can access the last received message in `SmallChunkCollector.extract_smallchunk`
        and `BigChunkLoader.extract_bigchunk`.
        '''
        self._queue_messages_to_subprocs[patient].put_nowait(message)
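
    # Illustrative sketch: steering a collector from the training loop with
    # `send_message`. The message payload, a dict with a hypothetical
    # "target_region" key, is entirely up to you; the collector reads it back
    # through the `last_message_fromroot` argument of `extract_smallchunk`.
    #
    #   dl.send_message(patient, {"target_region": (row, col)})
    #
    #   # ... later, inside your extract_smallchunk override:
    #   if(last_message_fromroot is not None):
    #       row, col = last_message_fromroot["target_region"]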

    @staticmethod
    def _terminaterecursively(pid):
        #TODO:copyright, https://www.reddit.com/r/learnpython/comments/7vwyez/how_to_kill_child_processes_when_using/
        parent = psutil.Process(pid)
        for child in parent.children(recursive=True):
            try:
                child.kill()
            except Exception:
                pass
        try:
            parent.kill()
        except Exception:
            pass

    def pause_loading(self):
        lightdl_pid = self._queue_pid_of_lightdl.get()
        self.flush_log()
        #kill the lightdl process and, recursively, all of its children.
        LightDL._terminaterecursively(lightdl_pid)

    def is_dl_running(self):
        '''
        Returns True if the DL is still working, and False otherwise.
        Warning: executing this function may take a lot of time.
        Avoid making frequent calls to this function.
        '''
        #try to get the qsize of the finish-message queue.
        try:
            qsize_finishmessage = self._queue_message_lightdlfinished.qsize()
        except Exception:
            return True
        return (qsize_finishmessage == 0)

    def visualize(self, func_visualize_one_patient):
        '''
        To visualize the instances collected by the lightdl, call this function.
        You should pass in a function `func_visualize_one_patient` that, given all
        smallchunks collected for a specific patient, visualizes that patient.
        `func_visualize_one_patient` takes the following inputs:
            - patient: the patient under consideration, an instance of `utils.data.Patient`.
            - list_smallchunks: the list of all small chunks collected for the patient,
              a list whose elements are instances of `lightdl.SmallChunk`.
        '''
        #separate the collected smallchunks based on patient ======
        dict_patient_to_listsmallchunks = {patient:[] for patient in self.dataset.list_patients}
        for smallchunk in self.list_smallchunksforvis:
            dict_patient_to_listsmallchunks[smallchunk.patient].append(smallchunk)
        #call the visualization function ======
        for patient in dict_patient_to_listsmallchunks.keys():
            func_visualize_one_patient(patient, dict_patient_to_listsmallchunks[patient])

    @staticmethod
    def default_collate(list_smallchunks, tfms):
        list_data = [smallchunk.data for smallchunk in list_smallchunks]
        if(tfms is not None):
            for n in range(len(list_data)): #apply the transforms
                list_data[n] = tfms(list_data[n])
        x = torch.stack(list_data, dim=0)
        list_patients = [smallchunk.patient for smallchunk in list_smallchunks]
        for smallchunk in list_smallchunks:
            smallchunk.data = "None to avoid memory leak" #free the data to avoid a memory leak
        return x, list_patients, list_smallchunks
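
    # Illustrative sketch: a custom `collate_func`. It must accept
    # `(list_smallchunks, tfms)` and may return anything; the return value is
    # handed back verbatim by `get`. The label lookup below
    # (`patient.dict_records["label"]`) is hypothetical and depends on how
    # your dataset stores labels.
    #
    #   def my_collate(list_smallchunks, tfms):
    #       x, list_patients, list_smallchunks = LightDL.default_collate(list_smallchunks, tfms)
    #       y = torch.tensor([patient.dict_records["label"] for patient in list_patients])
    #       return x, y, list_patients
    #
    #   dl = LightDL(..., collate_func=my_collate, ...)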

    def get(self):
        #make the values to return =================
        list_poped_smallchunks = []
        flag_dl_running = self.is_dl_running()
        if(flag_dl_running == True):
            while(len(list_poped_smallchunks) < self.batch_size):
                #try to get a new instance ====
                if(self.queue_lightdl.qsize() > 0):
                    try:
                        smallchunk = self.queue_lightdl.get_nowait()
                        list_poped_smallchunks.append(smallchunk)
                    except Exception:
                        pass
                #if the dl is finished and the queue is empty, exit the while loop
                if((self.is_dl_running() == False) and (self.queue_lightdl.qsize() == 0)):
                    #in this case, the `get` function has no chance to collect more instances
                    if(len(list_poped_smallchunks) == 0):
                        #the list is empty, so no more instances are to be collected.
                        return PYDMEDRESERVED_DLRETURNEDLASTINSTANCE
                    else:
                        #some instances have been collected; break the while loop and continue as before
                        break
        elif(flag_dl_running == False):
            #in this case, `get` returns the queued instances one by one, regardless of `batch_size`.
            if(self.queue_lightdl.qsize() > 0):
                while(len(list_poped_smallchunks) < 1):
                    #try to get a new instance ====
                    if(self.queue_lightdl.qsize() > 0):
                        try:
                            smallchunk = self.queue_lightdl.get_nowait()
                            list_poped_smallchunks.append(smallchunk)
                        except Exception:
                            pass
            else:
                #in this case, the dl is finished and the queue is empty.
                return PYDMEDRESERVED_DLRETURNEDLASTINSTANCE

        returnvalue_of_collatefunc = self.collate_func(list_poped_smallchunks, self.tfms)
        #grab the visualization info ============
        for smallchunk in list_poped_smallchunks:
            smallchunk_datafree = SmallChunk(data="None to avoid memory leak",
                                             dict_info_of_smallchunk=smallchunk.dict_info_of_smallchunk,
                                             dict_info_of_bigchunk=smallchunk.dict_info_of_bigchunk,
                                             patient=smallchunk.patient)
            self.list_smallchunksforvis.append(smallchunk_datafree)
        #return whatever the collate function has built
        return returnvalue_of_collatefunc
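
    # Illustrative sketch: a typical consumer loop. `start` (inherited from
    # `mp.Process`) launches the dataloader, `get` returns whatever the
    # collate function built (here the default `(x, list_patients,
    # list_smallchunks)` triple), and `pause_loading` tears the subprocesses
    # down when training is done.
    #
    #   dl.start()
    #   while(True):
    #       retval = dl.get()
    #       if(isinstance(retval, str) and (retval == PYDMEDRESERVED_DLRETURNEDLASTINSTANCE)):
    #           break #the dataloader has returned its last instance
    #       x, list_patients, list_smallchunks = retval
    #       # ... feed x to the model ...
    #   dl.pause_loading()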

    def get_list_loadedpatients(self):
        '''
        Returns the list of `Patient`s that are loaded,
        i.e., those for whom a `SmallChunkCollector` is collecting `SmallChunk`s.
        '''
        list_loadedpatients = [subproc.patient
                               for subproc in list(self.active_subprocesses)]
        return list_loadedpatients

    def get_list_waitingpatients(self):
        '''
        Returns the list of `Patient`s that are not loaded,
        i.e., those for whom no `SmallChunkCollector` is collecting `SmallChunk`s.
        '''
        set_running_patients = set(self.get_list_loadedpatients())
        set_waiting_patients = set(self.dataset.list_patients).difference(
                                        set_running_patients)
        return list(set_waiting_patients)

    def get_schedcount_of(self, patient):
        '''
        Returns the number of times that a specific `Patient` has
        been scheduled by the scheduler.
        '''
        return self.dict_patient_to_schedcount[patient]

    def initial_schedule(self):
        '''
        Used for selecting the initial BigChunks.
        This function has to return:
            - `list_initial_patients`: a list containing the `Patient`s who are initially loaded.
              The length of the list must be equal to `self.const_global_info["num_bigchunkloaders"]`.
        '''
        #The default is to choose randomly from the dataset.
        return random.choices(self.dataset.list_patients,
                              k=self.const_global_info["num_bigchunkloaders"])

    def schedule(self):
        '''
        This function is called when scheduling a new patient, i.e., loading a new BigChunk.
        This function has to return:
            - patient_toremove: the patient to remove, an instance of `utils.data.Patient`.
            - patient_toload: the patient to load, an instance of `utils.data.Patient`.
        In this function, you have access to the following fields:
            - self.dict_patient_to_schedcount: a dictionary that, given a patient, returns
              the number of times the patient has been scheduled by the dl.
            - self.get_list_loadedpatients()
            - self.get_list_waitingpatients()
            - TODO: add more fields here to provide more flexibility, e.g., the total time
              that a patient has been loaded on the DL.
        '''
        #get the initial fields ==============================
        list_loadedpatients = self.get_list_loadedpatients()
        list_waitingpatients = self.get_list_waitingpatients()
        waitingpatients_schedcount = [self.get_schedcount_of(patient)
                                      for patient in list_waitingpatients]

        #patient_toremove is selected randomly =======================
        patient_toremove = random.choice(list_loadedpatients)

        #when choosing a patient to load, give a huge weight to the patients that have not been scheduled so far.
        weights = 1.0/(1.0 + np.array(waitingpatients_schedcount))
        weights[weights == 1.0] = 10000000.0 #the places where schedcount is zero
        patient_toload = random.choices(list_waitingpatients,
                                        weights=weights, k=1)[0]
        return patient_toremove, patient_toload
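
    # Illustrative sketch: a custom scheduler. Returning the reserved string
    # `PYDMEDRESERVED_HALTDL` in place of a patient to remove makes the
    # dataloader announce that it is finished (see `run` below). The class
    # name and the 100-schedulings budget are hypothetical.
    #
    #   class MyLightDL(LightDL):
    #       def schedule(self):
    #           schedcounts = [self.get_schedcount_of(p) for p in self.dataset.list_patients]
    #           if(min(schedcounts) >= 100):
    #               return PYDMEDRESERVED_HALTDL, None
    #           return super(MyLightDL, self).schedule()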

    def run(self):
        try:
            #assign the dl to its core(s), if requested
            if(self.const_global_info["core-assignment"]["lightdl"] is not None):
                os.system("taskset -a -cp {} {}".format(
                    self.const_global_info["core-assignment"]["lightdl"], os.getpid()))
                print("     taskset called for lightdl")
            #save the pid of the lightdl (to do a recursive kill on finish)
            self._queue_pid_of_lightdl.put_nowait(os.getpid())
            #initially fill the pool of subprocesses ========
            patients_forinitialload = self.initial_schedule()
            print(" loading initial bigchunks, please wait ....")
            t1 = time.time()
            for i in range(len(patients_forinitialload)):
                print("   bigchunk {} of {}, please wait ...\n".format(i, len(patients_forinitialload)))
                if(self._queue_messages_to_subprocs is not None):
                    last_message_from_root = pydmed.utils.multiproc.poplast_from_queue(
                                self._queue_messages_to_subprocs[patients_forinitialload[i]]
                            )
                else:
                    last_message_from_root = None
                #NOTE: this guard previously tested `_queue_messages_to_subprocs` (the
                #messaging flag) by mistake; the checkpoint structures are the relevant ones.
                if(self._dict_patient_to_queueckpoint is not None):
                    old_checkpoint = self.dict_patient_to_checkpoint[patients_forinitialload[i]]
                    queue_checkpoint = self._dict_patient_to_queueckpoint[patients_forinitialload[i]]
                else:
                    old_checkpoint = None
                    queue_checkpoint = None
                subproc = self.type_smallchunkcollector(
                            patient=patients_forinitialload[i],
                            queue_smallchunks=mp.Queue(),
                            const_global_info=self.const_global_info,
                            type_bigchunkloader=self.type_bigchunkloader,
                            queue_logs=self._queue_logs,
                            old_checkpoint=old_checkpoint,
                            queue_checkpoint=queue_checkpoint,
                            last_message_from_root=last_message_from_root
                        )
                self.active_subprocesses.add(subproc)
                subproc.start()
                self.dict_patient_to_schedcount[subproc.patient] = self.dict_patient_to_schedcount[subproc.patient] + 1
                #wait until the bigchunk is loaded and the smallchunkcollector collects its first smallchunk.
                while(subproc.queue_smallchunks.qsize() == 0):
                    pass
            t2 = time.time()
            print("The initial loading of bigchunks took {} seconds.".format(t2-t1))
            #patrol the subprocesses ======================
            time_lastresched = time.time() + 1*self.const_global_info["interval_resched"]
            while(True):
                #collect smallchunks from the subprocesses ============
                while(self.queue_lightdl.qsize() >=
                      self.const_global_info["maxlength_queue_lightdl"]):
                    pass #wait until queue_lightdl becomes less heavy.
                for subproc in list(self.active_subprocesses):
                    if(subproc.queue_smallchunks.empty() == False):
                        try:
                            smallchunk = subproc.queue_smallchunks.get_nowait()
                            self.queue_lightdl.put_nowait(smallchunk)
                        except Exception:
                            pass
                #replace a subprocess if needed =======================
                tnow = time.time()
                time_from_lastresched = tnow - time_lastresched
                if(time_from_lastresched > self.const_global_info["interval_resched"]):
                    time_lastresched = time.time()
                    #select a patient to add and a subprocess to remove
                    patient_toremove, patient_toadd = self.schedule()
                    if(isinstance(patient_toremove, str)):
                        if(patient_toremove == PYDMEDRESERVED_HALTDL):
                            self._queue_message_lightdlfinished.put_nowait("DL-Finished")
                            #the DL process has to be terminated by the parent process,
                            #e.g., by LightDL._terminaterecursively(self.pid).
                            while(True): pass

                    flag_sched_returnednan = True
                    if(isinstance(patient_toremove, pydmed.utils.data.Patient)):
                        flag_sched_returnednan = False
                        assert(isinstance(patient_toadd, pydmed.utils.data.Patient))
                    if(flag_sched_returnednan == True):
                        pass #do not reschedule
                    else:
                        subproc_toremove = None
                        for subproc in list(self.active_subprocesses):
                            if(subproc.patient == patient_toremove):
                                subproc_toremove = subproc
                                break
                        if(subproc_toremove is None):
                            print("patient_toremove not found in the list of loaded patients.")
                            LightDL._terminaterecursively(self.pid)

                        #add the smallchunks of subproc_toremove to the lightdl's queue ===============
                        if(self.flag_grabqueue_onunsched == True):
                            size_queueof_subproctoremove = subproc_toremove.queue_smallchunks.qsize()
                            for count in range(size_queueof_subproctoremove):
                                try:
                                    smallchunk = subproc_toremove.queue_smallchunks.get_nowait()
                                    self.queue_lightdl.put_nowait(smallchunk)
                                except Exception as e:
                                    print("Warning: some smallchunks may have been lost. If not, you can safely ignore this warning.")
                        #grab the last checkpoint of the subproc =======================
                        if(self.flag_enable_setgetcheckpoint == True):
                            numcheckpoints_subproctoremove = subproc_toremove.queue_checkpoint.qsize()
                            last_checkpoint = None
                            for count in range(numcheckpoints_subproctoremove):
                                try:
                                    last_checkpoint = subproc_toremove.queue_checkpoint.get_nowait()
                                except Exception as e:
                                    print("an exception occurred while grabbing the last checkpoint")
                                    print(str(e))
                            self.dict_patient_to_checkpoint[patient_toremove] = last_checkpoint
                        #remove the subprocess
                        self.active_subprocesses.remove(subproc_toremove)
                        LightDL._terminaterecursively(subproc_toremove.pid)
                        #add a new subprocess for patient_toadd
                        if(self._queue_messages_to_subprocs is not None):
                            #NOTE: this previously popped from the queue of
                            #`patients_forinitialload[i]` by mistake; the messages of
                            #`patient_toadd` are the relevant ones here.
                            last_message_from_root = pydmed.utils.multiproc.poplast_from_queue(
                                        self._queue_messages_to_subprocs[patient_toadd]
                                    )
                        else:
                            last_message_from_root = None
                        if(self._dict_patient_to_queueckpoint is not None):
                            old_checkpoint = self.dict_patient_to_checkpoint[patient_toadd]
                            queue_checkpoint = self._dict_patient_to_queueckpoint[patient_toadd]
                        else:
                            old_checkpoint = None
                            queue_checkpoint = None
                        new_subproc = self.type_smallchunkcollector(
                                    patient=patient_toadd,
                                    queue_smallchunks=mp.Queue(),
                                    const_global_info=self.const_global_info,
                                    type_bigchunkloader=self.type_bigchunkloader,
                                    queue_logs=self._queue_logs,
                                    old_checkpoint=old_checkpoint,
                                    queue_checkpoint=queue_checkpoint,
                                    last_message_from_root=last_message_from_root
                                )
                        self.active_subprocesses.add(new_subproc)
                        self.dict_patient_to_schedcount[new_subproc.patient] = self.dict_patient_to_schedcount[new_subproc.patient] + 1
                        new_subproc.start()
        except Exception as e:
            print("An exception occurred in LightDL.")
            print(str(e))
            exc_type, exc_obj, exc_tb = sys.exc_info()
            fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
            print(exc_type, fname, exc_tb.tb_lineno)
            print("\n\n\n")
            LightDL._terminaterecursively(self.pid)
--------------------------------------------------------------------------------