├── skip.py
├── cfdna.model
├── util.py
├── LICENSE
├── .gitignore
├── predict.py
├── fastq.py
├── feature.py
├── README.md
├── draw.py
└── train.py

/skip.py:
--------------------------------------------------------------------------------
# Files listed here are excluded from processing. Each entry is a pair of
# substrings, and a path is skipped when it contains both (see is_to_skip()
# in util.py). Hypothetical example entry: ("160220_NS500713", "gdna-002")
file_to_skip = []
--------------------------------------------------------------------------------
/cfdna.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGene/CfdnaPattern/HEAD/cfdna.model
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import sys, os
import fastq
import skip

def is_to_skip(path):
    for item in skip.file_to_skip:
        if item[0] in path and item[1] in path:
            return True
    return False

def get_arg_files():
    files = []
    for f in sys.argv:
        if fastq.is_fastq(f) and os.path.exists(f):
            path = os.path.join(os.getcwd(), f)
            if not is_to_skip(path):
                files.append(path)
    return files

def has_adapter_sequenced(data):
    # workaround to skip data with a 6bp index sequenced in an 8bp index
    # setting: the adapter then leaks into the first cycles, so some base
    # ratio among the first 8 feature values becomes abnormally high
    count = 0
    for i in range(min(len(data), 8)):
        if data[i] > 0.7:
            count += 1
    return count >= 1
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2016 OpenGene - Open Source Genetics Toolbox

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# project-specific
skip.py
.DS_Store
*_fig
*figures
cache.json

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import sys, os
from optparse import OptionParser
from util import *
from draw import *
from feature import *
import numpy as np
from sklearn import svm, neighbors
import pickle

def parseCommand():
    usage = "predict whether each given fastq file is from cfdna or not, using a pre-trained model. \n\npython predict.py [-m model_file] [-q] <fastq_files>"
    version = "0.0.1"
    parser = OptionParser(usage = usage, version = version)
    parser.add_option("-m", "--model", dest = "model_file", default = "cfdna.model",
        help = "specify the file that stores the built model.")
    parser.add_option("-q", "--quite", dest = "quite", action='store_true', default = False,
        help = "only print predictions that conflict with the filename (quiet mode)")
    return parser.parse_args()

def preprocess(options):
    data = []
    samples = []
    fq_files = get_arg_files()

    for fq in fq_files:
        extractor = FeatureExtractor(fq)
        extractor.extract()
        feature = extractor.feature()

        if feature is None:
            # the file is too short to produce a full-length feature
            continue

        data.append(feature)
        samples.append(fq)

    return data, samples

def get_type_name(label):
    if label == 1:
        return "cfdna"
    else:
        return "not-cfdna"

def load_model(options):
    filename = options.model_file
    if not os.path.exists(filename):
        # fall back to the model file shipped alongside this script
        filename = os.path.join(os.path.dirname(sys.argv[0]), options.model_file)
        if not os.path.exists(filename):
            print("Error: model file not found: " + options.model_file)
            sys.exit(1)
    f = open(filename, "rb")
    model = pickle.load(f)
    f.close()
    return model

def main():
    if sys.version_info.major > 2:
        print('python3 is not supported yet, please use python2')
        sys.exit(1)

    (options, args) = parseCommand()

    data, samples = preprocess(options)

    model = load_model(options)

    labels = model.predict(data)

    for i in xrange(len(samples)):
        # in quiet mode, report only predictions that conflict with the filename
        if not options.quite or (labels[i] == 0 and "cfdna" in samples[i].lower()) or (labels[i] == 1 and "cfdna" not in samples[i].lower()):
            print(get_type_name(labels[i]) + ": " + samples[i])
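    # an explanatory note (based on plot_data_list in draw.py): this writes one
    # base-ratio figure per input file into ./predict_fig/ when matplotlib is
    # available, and prints a notice and skips plotting otherwise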
    plot_data_list(samples, data, "predict_fig")

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/fastq.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import gzip
import os,sys

def is_fastq(f):
    fqext = (".fq", ".fastq", ".fq.gz", ".fastq.gz")
    for ext in fqext:
        if f.endswith(ext) and not os.path.isdir(f):
            return True
    return False

################################
# fastq.Reader

class Reader:

    def __init__(self, fname):
        self.__file = None
        self.__gz = False
        self.__eof = False
        self.filename = fname
        try:
            if self.filename.endswith(".gz"):
                self.__gz = True
                self.__file = gzip.open(self.filename, "r")
            else:
                self.__gz = False
                self.__file = open(self.filename, "r")
        except IOError:
            print("Failed to open file " + self.filename)
            sys.exit(1)

    def __del__(self):
        if self.__file != None:
            self.__file.close()

    def nextRead(self):
        if self.__eof or self.__file == None:
            return None

        lines = []
        # read 4 lines (name, sequence, strand, quality)
        for i in xrange(0, 4):
            line = self.__file.readline().rstrip()
            if len(line) == 0:
                self.__eof = True
                return None
            lines.append(line)
        return lines

    def isEOF(self):
        return self.__eof

################################
# fastq.Writer

class Writer:

    def __init__(self, fname):
        self.__file = None
        self.__gz = False
        self.filename = fname
        try:
            if self.filename.endswith(".gz"):
                self.__gz = True
                self.__file = gzip.open(self.filename, "w")
            else:
                self.__gz = False
                self.__file = open(self.filename, "w")
        except IOError:
            print("Failed to open file " + self.filename + " to write")
            sys.exit(1)

    def __del__(self):
        if self.__file != None:
            self.__file.flush()
            self.__file.close()

    def flush(self):
        if self.__file != None:
            self.__file.flush()

    def writeLines(self, lines):
        if self.__file == None:
            return False

        for line in lines:
            self.__file.write(line + "\n")
        return True

    def writeRead(self, name, sequence, strand, quality):
        if self.__file == None:
            return False

        self.__file.write(name + "\n")
        self.__file.write(sequence + "\n")
        self.__file.write(strand + "\n")
        self.__file.write(quality + "\n")

        return True
--------------------------------------------------------------------------------
/feature.py:
--------------------------------------------------------------------------------
import sys, os
import fastq

STAT_LEN_LIMIT = 10     # number of leading cycles used for the feature
READ_TO_SKIP = 1000     # leading reads to skip before sampling
ALL_BASES = ("A", "T", "C", "G")

class FeatureExtractor:
    def __init__(self, filename, sample_limit=10000):
        self.sample_limit = sample_limit
        self.filename = filename
        self.base_counts = {}
        self.percents = {}
        self.read_count = 0
        self.stat_len = 0
        self.total_num = [0 for x in xrange(STAT_LEN_LIMIT)]
        for base in ALL_BASES:
            self.base_counts[base] = [0 for x in xrange(STAT_LEN_LIMIT)]
            self.percents[base] = [0.0 for x in xrange(STAT_LEN_LIMIT)]
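    # layout note: base_counts[b][i] counts occurrences of base b at cycle i,
    # and total_num[i] counts all sampled reads that reach cycle i; percents
    # holds the per-cycle ratios derived from these in calc_percents()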
    def stat_read(self, read):
        seq = read[1]
        seqlen = len(seq)
        for i in xrange(min(seqlen, STAT_LEN_LIMIT)):
            self.total_num[i] += 1
            b = seq[i]
            if b in ALL_BASES:
                self.base_counts[b][i] += 1

    def extract(self):
        reader = fastq.Reader(self.filename)
        stat_reads_num = 0
        skipped_reads = []
        # sample up to sample_limit reads for the stat
        while True:
            read = reader.nextRead()
            if read is None:
                break
            self.read_count += 1
            # skip the first READ_TO_SKIP reads, since they are usually not stable
            if self.read_count < READ_TO_SKIP:
                skipped_reads.append(read)
                continue
            stat_reads_num += 1
            if stat_reads_num > self.sample_limit and self.sample_limit > 0:
                break
            self.stat_read(read)

        # if the fastq file is small, fall back to the reads skipped above
        if stat_reads_num < READ_TO_SKIP:
            for read in skipped_reads:
                self.stat_read(read)

        self.calc_read_len()
        self.calc_percents()

    def calc_read_len(self):
        for pos in xrange(STAT_LEN_LIMIT):
            has_data = False
            for base in ALL_BASES:
                if self.base_counts[base][pos] > 0:
                    has_data = True
            if not has_data:
                self.stat_len = pos
                return
        if has_data:
            self.stat_len = STAT_LEN_LIMIT

    def calc_percents(self):
        # calc the percentage of each base at each cycle
        for pos in xrange(self.stat_len):
            total = 0
            for base in ALL_BASES:
                total += self.base_counts[base][pos]
            for base in ALL_BASES:
                self.percents[base][pos] = float(self.base_counts[base][pos]) / float(total)

    def feature(self):
        # a full-length stat is required; otherwise the feature is unusable
        if self.stat_len < STAT_LEN_LIMIT:
            return None
        feature_vector = []
        # the vector is packed cycle by cycle as ATCGATCG...ATCG
        for pos in xrange(self.stat_len):
            for base in ALL_BASES:
                feature_vector.append(self.percents[base][pos])
        return feature_vector
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CfdnaPattern
Pattern Recognition for Cell-free DNA

# Predict whether a fastq is cfdna or not
```shell
# predict a single file
python predict.py <fastq_file>

# predict multiple files
python predict.py <fastq_file1> <fastq_file2> ...

# predict files with a wildcard
python predict.py *.fq
```

***warning: this tool doesn't work for trimmed fastq***

## prediction output
For each file given on the command line, this tool outputs a line `<type>: <filename>`, like
```
cfdna: /fq/160220_NS500713_0040_AHVNG2BGXX/20160220-cfdna-001_S1_R1_001.fastq.gz
cfdna: /fq/160220_NS500713_0040_AHVNG2BGXX/20160220-cfdna-001_S1_R2_001.fastq.gz
not-cfdna: /fq/160220_NS500713_0040_AHVNG2BGXX/20160220-gdna-002_S2_R1_001.fastq.gz
not-cfdna: /fq/160220_NS500713_0040_AHVNG2BGXX/20160220-gdna-002_S2_R2_001.fastq.gz
```
Add `-q` or `--quite` to enable quiet output mode, in which only conflicts are reported:
* a file with `cfdna` in its name, but predicted as `not-cfdna`
* a file without `cfdna` in its name, but predicted as `cfdna`

# Train a model
This tool ships with a pre-trained model (`cfdna.model`) that is used for prediction by default, but you can also train a model yourself, following the steps below.
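Under the hood, `cfdna.model` is a pickled scikit-learn classifier whose input is the 40-value base-ratio feature built by `feature.py` (A/T/C/G ratios for the first 10 cycles). A minimal sketch of using it directly from Python, assuming scikit-learn is installed and the script runs from the repository root (the fastq filename is hypothetical):
```python
import pickle
from feature import FeatureExtractor

with open("cfdna.model", "rb") as f:
    model = pickle.load(f)                  # classifier built by train.py (KNN by default)

extractor = FeatureExtractor("sample_R1.fastq.gz")   # hypothetical input file
extractor.extract()
feature = extractor.feature()               # None if the reads are shorter than 10 cycles
if feature is not None:
    label = model.predict([feature])[0]     # 1 = cfdna, 0 = not-cfdna
    print("cfdna" if label == 1 else "not-cfdna")
```

To train your own model: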
* prepare/link all your fastq files in one folder
* for files from `cfdna`, include `cfdna` (case-insensitive) in the filename, like `20160220-cfdna-015_S15_R1_001.fq`
* for files from `genomic DNA`, include `gdna` (case-insensitive) in the filename, like `20160220-gdna-002_S2_R1_001.fq`
* for files from `FFPE DNA`, include `ffpe` (case-insensitive) in the filename, like `20160123-ffpe-040_S0_R1_001.fq`
* run:
```shell
python train.py /fastq_folder/*.fq
```

# Citation
If you use CfdnaPattern in your publication, please cite: https://doi.org/10.1109/TCBB.2017.2723388

Full options:
```shell
python train.py [options] <fastq_files>

Options:
  --version             show program's version number and exit
  -h, --help            show this help message and exit
  -m MODEL_FILE, --model=MODEL_FILE
                        specify which file to store the built model.
  -a ALGORITHM, --algorithm=ALGORITHM
                        specify which algorithm to use for classification,
                        candidates are svm/knn/rbf/rf/gnb/benchmark, rbf means
                        svm using rbf kernel, rf means random forest, gnb
                        means Gaussian Naive Bayes, benchmark will try every
                        algorithm and plot the score figure, default is knn.
  -c CFDNA_FLAG, --cfdna_flag=CFDNA_FLAG
                        specify the filename flag of cfdna files, separated by
                        semicolon. default is: cfdna
  -o OTHER_FLAG, --other_flag=OTHER_FLAG
                        specify the filename flag of other files, separated by
                        semicolon. default is: gdna;ffpe
  -p PASSES, --passes=PASSES
                        specify how many passes to do training and validating,
                        default is 100.
  -n, --no_cache_check  if the cache file exists, use it without checking the
                        identity with input files
```
--------------------------------------------------------------------------------
/draw.py:
--------------------------------------------------------------------------------
import sys, os
from feature import *

# try to load matplotlib
HAVE_MATPLOTLIB = True

try:
    import matplotlib
    # fix matplotlib DISPLAY issue on headless machines
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    matplotlib.rc('axes', edgecolor='#AAAAAA', labelcolor='#666666')
    matplotlib.rc('xtick', color='#666666')
    matplotlib.rc('ytick', color='#666666')
except Exception:
    HAVE_MATPLOTLIB = False

def plot_data_list(wrong_files, wrong_data, figure_dir):
    if not HAVE_MATPLOTLIB:
        print("\nmatplotlib not installed, skip plotting figures for files with wrong predictions")
        return

    if not os.path.exists(figure_dir):
        try:
            os.mkdir(figure_dir)
        except Exception:
            print("failed to create the folder to store figures")
            return
    for i in xrange(len(wrong_files)):
        filename = wrong_files[i]
        f = os.path.join(figure_dir, filename.strip('/').replace("/", "-") + ".png")
        plot_data(wrong_data[i], f, filename[filename.rfind('/') + 1:])

def plot_data(data, filename, title):
    # data is packed cycle by cycle as ATCGATCG...ATCG
    colors = {'A': 'red', 'T': 'purple', 'C': 'blue', 'G': 'green'}
    base_num = len(ALL_BASES)
    cycles = len(data) / base_num
    percents = {}
    for b in xrange(base_num):
        percents[ALL_BASES[b]] = [0.0 for c in xrange(cycles)]

    for c in xrange(cycles):
        total = 0
        for b in xrange(base_num):
            total += data[c * base_num + b]
        for b in xrange(base_num):
            percents[ALL_BASES[b]][c] = float(data[c * base_num + b]) / float(total)
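    # note (an observation, not a behavior change): the feature values are
    # already per-cycle ratios, so the division above only re-normalizes them;
    # percents[base] now holds one curve per base, ready to plot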
    x = range(1, cycles + 1)
    plt.figure(1, figsize=(5.5, 3), edgecolor='#cccccc')
    plt.title(title[0:title.find('.')], size=10)
    plt.xlim(1, cycles)
    max_y = 0.35
    min_y = 0.15
    for base in ALL_BASES:
        max_of_base = max(percents[base][0:cycles])
        max_y = max(max_y, max_of_base + 0.05)
        min_of_base = min(percents[base][0:cycles])
        min_y = min(min_y, min_of_base - 0.05)
    plt.ylim(min_y, max_y)
    plt.ylabel('Ratio')
    #plt.xlabel('Cycle')
    for base in ALL_BASES:
        plt.plot(x, percents[base][0:cycles], color=colors[base], label=base, alpha=0.5, linewidth=2, marker='o', markeredgewidth=0.0, markersize=4)
    #plt.legend(loc='upper right', ncol=5)
    plt.savefig(filename)
    plt.close(1)

def plot_benchmark(scores_arr, algorithms_arr, filename):
    colors = ['#FF6600', '#009933', '#2244AA', '#552299', '#11BBDD']
    linestyles = ['-', '--', ':']
    passes = len(scores_arr[0])

    x = range(1, passes + 1)
    title = "Benchmark Result"
    plt.figure(1, figsize=(8, 8))
    plt.title(title, size=20, color='#333333')
    plt.xlim(1, passes)
    plt.ylim(0.97, 1.001)
    plt.ylabel('Score', size=16, color='#333333')
    plt.xlabel('Validation pass (sorted by score)', size=16, color='#333333')
    for i in xrange(len(scores_arr)):
        plt.plot(x, scores_arr[i], color=colors[i % 5], label=algorithms_arr[i], alpha=0.5, linewidth=2, linestyle=linestyles[i % 3])
    plt.legend(loc='lower left')
    plt.savefig(filename)
    plt.close(1)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import sys, os
import fastq
from optparse import OptionParser
import time
from util import *
from draw import *
from feature import *
import numpy as np
from sklearn import svm, neighbors
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
import json
import pickle

def parseCommand():
    usage = "extract the features, and train the model, from the training set of fastq files. \n\npython train.py [options] <fastq_files>"
    version = "0.0.1"
    parser = OptionParser(usage = usage, version = version)
    parser.add_option("-m", "--model", dest = "model_file", default = "cfdna.model",
        help = "specify which file to store the built model.")
    parser.add_option("-a", "--algorithm", dest = "algorithm", default = "knn",
        help = "specify which algorithm to use for classification, candidates are svm/knn/rbf/rf/gnb/benchmark, rbf means svm using rbf kernel, rf means random forest, gnb means Gaussian Naive Bayes, benchmark will try every algorithm and plot the score figure, default is knn.")
    parser.add_option("-c", "--cfdna_flag", dest = "cfdna_flag", default = "cfdna",
        help = "specify the filename flag of cfdna files, separated by semicolon. default is: cfdna")
    parser.add_option("-o", "--other_flag", dest = "other_flag", default = "gdna;ffpe",
        help = "specify the filename flag of other files, separated by semicolon. default is: gdna;ffpe")
    parser.add_option("-p", "--passes", dest = "passes", type="int", default = 100,
        help = "specify how many passes to do training and validating, default is 100.")
    parser.add_option("-n", "--no_cache_check", dest = "no_cache_check", action='store_true', default = False,
        help = "if the cache file exists, use it without checking the identity with input files")
    return parser.parse_args()

def is_file_type(filename, file_flags):
    for flag in file_flags:
        if flag.lower().strip() in filename.lower():
            return True
    return False

def preprocess(options):
    cfdna_flags = options.cfdna_flag.split(";")
    other_flags = options.other_flag.split(";")
    print("cfdna file flags (-c): " + ";".join(cfdna_flags))
    print("other file flags (-o): " + ";".join(other_flags))

    data = []
    label = []
    samples = []
    fq_files = get_arg_files()

    # try to load the features from cache.json
    json_file_name = "cache.json"
    if os.path.exists(json_file_name) and os.access(json_file_name, os.R_OK):
        json_file = open(json_file_name, "r")
        json_loaded = json.loads(json_file.read())
        print("\nfound feature cache (cache.json), loading it now...")
        # note: the cache is considered valid if the file count matches (a weak check)
        if options.no_cache_check or len(json_loaded["fq_files"]) == len(fq_files):
            data = json_loaded["data"]
            label = json_loaded["label"]
            samples = json_loaded["samples"]
            print("feature cache is valid; if you want to do feature extraction again, delete cache.json")
            return data, label, samples
        else:
            print("cache is invalid")

    # cannot load from cache.json, so compute the features
    print("\nextracting features...")
    number = 0
    for fq in fq_files:
        if not is_file_type(fq, cfdna_flags) and not is_file_type(fq, other_flags):
            continue
        number += 1
        print(str(number) + ": " + fq)

        extractor = FeatureExtractor(fq)
        extractor.extract()
        feature = extractor.feature()

        if feature is None:
            print("======== Warning: bad feature from:")
            print(fq)
            continue

        # drop samples polluted by adapter sequences (see util.has_adapter_sequenced)
        if has_adapter_sequenced(feature):
            continue

        if is_file_type(fq, cfdna_flags):
            data.append(feature)
            label.append(1)
        elif is_file_type(fq, other_flags):
            data.append(feature)
            label.append(0)
        else:
            continue
        samples.append(fq)

    if len(samples) <= 2:
        return data, label, samples

    # save the data, labels and samples to cache.json to speed up later runs
    try:
        json_file = open(json_file_name, "w")
    except Exception:
        return data, label, samples
    if os.access(json_file_name, os.W_OK):
        json_store = {}
        json_store["data"] = data
        json_store["label"] = label
        json_store["samples"] = samples
        json_store["fq_files"] = fq_files
        print("\nsaving to cache.json")
        json_str = json.dumps(json_store)
        json_file.write(json_str)
        json_file.close()

    return data, label, samples

def bootstrap_split(data, label, samples):
    training_set = {"data": [], "label": [], "samples": []}
    validation_set = {"data": [], "label": [], "samples": []}
    total_num = len(data)

    # make sure the training set contains both positive and negative samples
    while len(np.unique(training_set["label"])) <= 1:
        training_ids = np.random.choice(total_num, size=total_num, replace=True)
        training_set_percentage = len(np.unique(training_ids)) / float(total_num)
        print("bootstrap sampling: " + str(training_set_percentage) + " training set, " + str(1.0 - training_set_percentage) + " validation set")
        training_set["data"] = []
        training_set["label"] = []
        training_set["samples"] = []
        validation_set["data"] = []
        validation_set["label"] = []
        validation_set["samples"] = []
        for i in xrange(total_num):
            if i in training_ids:
                training_set["data"].append(data[i])
                training_set["label"].append(label[i])
                training_set["samples"].append(samples[i])
            else:
                validation_set["data"].append(data[i])
                validation_set["label"].append(label[i])
                validation_set["samples"].append(samples[i])

    return training_set, validation_set
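
# A background note on the validation scheme used below (standard statistics,
# not specific to this repo): each pass draws a bootstrap sample as the
# training set and keeps the out-of-bag samples for validation; the combined
# score 0.632 * score_boot + 0.368 * score_train computed in train() is
# Efron's ".632 bootstrap" estimator, which corrects the pessimistic bias of
# the plain out-of-bag score.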
def train(model, data, label, samples, options, benchmark = False):
    print("\ntraining and validating for " + str(options.passes) + " times...")
    total_score = 0
    scores = []
    wrong_files = []
    wrong_data = []
    for i in xrange(options.passes):
        print(str(i + 1) + " / " + str(options.passes))
        training_set, validation_set = bootstrap_split(data, label, samples)
        model.fit(np.array(training_set["data"]), np.array(training_set["label"]))
        # get the training score and the out-of-bag (bootstrap validation) score
        score_train = model.score(np.array(training_set["data"]), np.array(training_set["label"]))
        score_boot = model.score(np.array(validation_set["data"]), np.array(validation_set["label"]))
        score = 0.632 * score_boot + 0.368 * score_train
        total_score += score
        scores.append(score)

        # collect the wrongly predicted samples
        arr = np.array(validation_set["data"])
        for v in xrange(len(validation_set["data"])):
            result = model.predict(arr[v:v + 1])
            if result[0] != validation_set["label"][v]:
                if validation_set["samples"][v] not in wrong_files:
                    wrong_files.append(validation_set["samples"][v])
                    wrong_data.append(validation_set["data"][v])
    if not benchmark:
        print("scores of all " + str(options.passes) + " passes:")
        print(scores)
        print("\nscore mean: " + str(np.mean(scores)) + ", std: " + str(np.std(scores)))
        print("\n" + str(len(wrong_files)) + " files with at least 1 wrong prediction:")
        print(" ".join(wrong_files))

        print("\nplotting figures for files with wrong predictions...")
        plot_data_list(wrong_files, wrong_data, "train_fig")

        save_model(model, options)
    return sorted(scores, reverse=True)

def save_model(model, options):
    print("\nsaving model to: " + options.model_file)
    try:
        with open(options.model_file, "wb") as f:
            pickle.dump(model, f, True)
    except Exception:
        print("failed to write the model file: " + options.model_file)
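
# A portability note (an assumption worth stating, not enforced anywhere in
# this code): a pickled scikit-learn model is generally tied to the library
# version that produced it, so a model trained here may fail to load or to
# predict correctly under a different scikit-learn release; retrain with
# train.py after upgrading scikit-learn.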
def main():
    time1 = time.time()
    if sys.version_info.major > 2:
        print('python3 is not supported yet, please use python2')
        sys.exit(1)

    (options, args) = parseCommand()

    data, label, samples = preprocess(options)

    if len(data) == 0:
        print("not enough training data, usage:\n\tpython train.py <fastq_files>\n\twildcard (*) is supported\n")
        sys.exit(1)
    elif len(np.unique(label)) < 2:
        if np.unique(label)[0] == 0:
            print("no cfdna training data")
        else:
            print("no other-type (gdna/ffpe) training data")
        sys.exit(1)

    if options.algorithm.lower() == "svm":
        model = svm.LinearSVC(penalty='l2', dual=True, tol=1e-4, max_iter=1000)
        train(model, data, label, samples, options)
    elif options.algorithm.lower() == "knn":
        model = neighbors.KNeighborsClassifier(n_neighbors=8, weights='distance', leaf_size=100)
        train(model, data, label, samples, options)
    elif options.algorithm.lower() == "rf":
        model = RandomForestClassifier(n_estimators=20)
        train(model, data, label, samples, options)
    elif options.algorithm.lower() == "rbf":
        model = svm.SVC(kernel='rbf')
        train(model, data, label, samples, options)
    elif options.algorithm.lower() == "gnb":
        model = GaussianNB()
        train(model, data, label, samples, options)
    elif options.algorithm.lower() == "benchmark":
        print("\nstarting benchmark...")
        names = ["KNN", "Random Forest", "SVM Linear Kernel", "Gaussian Naive Bayes", "SVM RBF Kernel"]
        models = [neighbors.KNeighborsClassifier(n_neighbors=8, weights='distance', leaf_size=100),
                  RandomForestClassifier(n_estimators=20),
                  svm.LinearSVC(penalty='l2', dual=True, tol=1e-4, max_iter=1000),
                  GaussianNB(),
                  svm.SVC(kernel='rbf')]
        scores_arr = []
        for m in xrange(len(models)):
            print("\nbenchmark with: " + names[m])
            scores_arr.append(train(models[m], data, label, samples, options, True))
        for m in xrange(len(models)):
            print("\nbenchmark mean score with: " + names[m] + ", mean " + str(np.mean(scores_arr[m])) + ", std " + str(np.std(scores_arr[m])))
        print("\nplotting benchmark result...")
        plot_benchmark(scores_arr, names, "benchmark.png")
    else:
        print("algorithm " + options.algorithm + " is not supported, please use one of svm/knn/rbf/rf/gnb/benchmark")

    time2 = time.time()
    print('\nTime used: ' + str(time2 - time1) + ' seconds')

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------