├── PredictTxInfo.py ├── README.md ├── TxInfo.py ├── cassandra_rw.py ├── credit.csv ├── credit_test.json ├── flask_ml_api.py ├── kafka-json-model.py ├── kafka-json-producer.py ├── main.py └── ml_model.py /PredictTxInfo.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from cassandra.cqlengine import columns 3 | from cassandra.cqlengine.models import Model 4 | 5 | class PredictTxInfoModel(Model):  # one scored transaction: the received features plus P, the label predicted by the served pipeline 6 | tx_id = columns.UUID(primary_key=True, default=uuid.uuid4) 7 | Time = columns.Integer(index=True) 8 | V1 = columns.Float() 9 | V2 = columns.Float() 10 | V3 = columns.Float() 11 | V4 = columns.Float() 12 | V5 = columns.Float() 13 | V6 = columns.Float() 14 | V7 = columns.Float() 15 | V8 = columns.Float() 16 | V9 = columns.Float() 17 | V10 = columns.Float() 18 | V11 = columns.Float() 19 | V12 = columns.Float() 20 | V13 = columns.Float() 21 | V14 = columns.Float() 22 | V15 = columns.Float() 23 | V16 = columns.Float() 24 | V17 = columns.Float() 25 | V18 = columns.Float() 26 | V19 = columns.Float() 27 | V20 = columns.Float() 28 | V21 = columns.Float() 29 | V22 = columns.Float() 30 | V23 = columns.Float() 31 | V24 = columns.Float() 32 | V25 = columns.Float() 33 | V26 = columns.Float() 34 | V27 = columns.Float() 35 | V28 = columns.Float() 36 | Amount = columns.Float(required=False) 37 | P = columns.Integer(index=True) 38 | 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ai-project-fraud-detection 2 | -------------------------------------------------------------------------------- /TxInfo.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from cassandra.cqlengine import columns 3 | from cassandra.cqlengine.models import Model 4 | 5 | class TxInfoModel(Model):  # one labelled training transaction; C is the target label column used for training 6 | tx_id = columns.UUID(primary_key=True, default=uuid.uuid4) 7 | Time = 
columns.Integer(index=True) 8 | V1 = columns.Float() 9 | V2 = columns.Float() 10 | V3 = columns.Float() 11 | V4 = columns.Float() 12 | V5 = columns.Float() 13 | V6 = columns.Float() 14 | V7 = columns.Float() 15 | V8 = columns.Float() 16 | V9 = columns.Float() 17 | V10 = columns.Float() 18 | V11 = columns.Float() 19 | V12 = columns.Float() 20 | V13 = columns.Float() 21 | V14 = columns.Float() 22 | V15 = columns.Float() 23 | V16 = columns.Float() 24 | V17 = columns.Float() 25 | V18 = columns.Float() 26 | V19 = columns.Float() 27 | V20 = columns.Float() 28 | V21 = columns.Float() 29 | V22 = columns.Float() 30 | V23 = columns.Float() 31 | V24 = columns.Float() 32 | V25 = columns.Float() 33 | V26 = columns.Float() 34 | V27 = columns.Float() 35 | V28 = columns.Float() 36 | Amount = columns.Float(required=False) 37 | C = columns.Integer(index=True) 38 | 39 | -------------------------------------------------------------------------------- /cassandra_rw.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from cassandra.cqlengine import connection 3 | from datetime import datetime 4 | from cassandra.cqlengine.management import sync_table 5 | import csv 6 | from TxInfo import TxInfoModel 7 | import pandas as pd 8 | 9 | class CassandraReadWriteDb:  # wraps cqlengine connection setup, table sync, and row reads/writes 10 | 11 | def __init__(self, ip_addrs, keyspace): 12 | connection.setup( ip_addrs, keyspace, protocol_version=3) 13 | 14 | def sync_class_table(self, typeOfClass):  # must run before write_file_table / write_json_table, which insert into self.typeOfClass 15 | self.typeOfClass = typeOfClass 16 | sync_table(typeOfClass) 17 | 18 | #Write CSV to cassandra 19 | def write_file_table(self, credit_logs):  # credit_logs: path to a CSV whose header names must match the model's column names 20 | with open(credit_logs, newline='') as csv_file:  # newline='' is what the csv module requires for file objects 21 | csv_reader = csv.DictReader(csv_file, delimiter=',') 22 | for row in 
csv_reader: 23 | self.typeOfClass.create(**dict(row)) 24 | 25 | #Read Cassandra data to pandas 26 | def get_pandas_from_cassandra(self):  # always reads TxInfoModel, whichever class was passed to sync_class_table 27 | # Build one DataFrame from all rows at once: DataFrame.append no longer exists in pandas 2.0, 28 | # and appending row by row would copy the frame on every iteration (O(n^2)). 29 | records = [dict(q.items()) for q in TxInfoModel.objects()] 30 | 31 | 
# Column names come from each row's keys(); rows get a unique 0..n-1 index, and an empty table 32 | # yields an empty DataFrame instead of failing on an unbound loop variable. 33 | tx_info = pd.DataFrame.from_records(records) 34 | return tx_info 35 | 36 | def write_json_table(self, data):  # print, then insert one row built from a dict of column values 37 | print (data) 38 | self.typeOfClass.create(**dict(data)) 39 | 40 | 41 | 42 | if __name__ == '__main__': 43 | 44 | cwd = CassandraReadWriteDb(ip_addrs=['172.17.0.2'], keyspace="emp") 45 | cwd.sync_class_table(TxInfoModel) 46 | #cwd.write_file_table('credit.csv') 47 | print(cwd.get_pandas_from_cassandra()) 48 | -------------------------------------------------------------------------------- /credit.csv: -------------------------------------------------------------------------------- 1 | "Time","V1","V2","V3","V4","V5","V6","V7","V8","V9","V10","V11","V12","V13","V14","V15","V16","V17","V18","V19","V20","V21","V22","V23","V24","V25","V26","V27","V28","Amount","C" 2 | 0,-1.3598071336738,-0.0727811733098497,2.53634673796914,1.37815522427443,-0.338320769942518,0.462387777762292,0.239598554061257,0.0986979012610507,0.363786969611213,0.0907941719789316,-0.551599533260813,-0.617800855762348,-0.991389847235408,-0.311169353699879,1.46817697209427,-0.470400525259478,0.207971241929242,0.0257905801985591,0.403992960255733,0.251412098239705,-0.018306777944153,0.277837575558899,-0.110473910188767,0.0669280749146731,0.128539358273528,-0.189114843888824,0.133558376740387,-0.0210530534538215,149.62,"0" 3 | 0,1.19185711131486,0.26615071205963,0.16648011335321,0.448154078460911,0.0600176492822243,-0.0823608088155687,-0.0788029833323113,0.0851016549148104,-0.255425128109186,-0.166974414004614,1.61272666105479,1.06523531137287,0.48909501589608,-0.143772296441519,0.635558093258208,0.463917041022171,-0.114804663102346,-0.183361270123994,-0.145783041325259,-0.0690831352230203,-0.225775248033138,-0.638671952771851,0.101288021253234,-0.339846475529127,0.167170404418143,0.125894532368176,-0.00898309914322813,0.0147241691924927,2.69,"0" 4 |
1,-1.35835406159823,-1.34016307473609,1.77320934263119,0.379779593034328,-0.503198133318193,1.80049938079263,0.791460956450422,0.247675786588991,-1.51465432260583,0.207642865216696,0.624501459424895,0.066083685268831,0.717292731410831,-0.165945922763554,2.34586494901581,-2.89008319444231,1.10996937869599,-0.121359313195888,-2.26185709530414,0.524979725224404,0.247998153469754,0.771679401917229,0.909412262347719,-0.689280956490685,-0.327641833735251,-0.139096571514147,-0.0553527940384261,-0.0597518405929204,378.66,"0" 5 | 1,-0.966271711572087,-0.185226008082898,1.79299333957872,-0.863291275036453,-0.0103088796030823,1.24720316752486,0.23760893977178,0.377435874652262,-1.38702406270197,-0.0549519224713749,-0.226487263835401,0.178228225877303,0.507756869957169,-0.28792374549456,-0.631418117709045,-1.0596472454325,-0.684092786345479,1.96577500349538,-1.2326219700892,-0.208037781160366,-0.108300452035545,0.00527359678253453,-0.190320518742841,-1.17557533186321,0.647376034602038,-0.221928844458407,0.0627228487293033,0.0614576285006353,123.5,"0" 6 | 2,-1.15823309349523,0.877736754848451,1.548717846511,0.403033933955121,-0.407193377311653,0.0959214624684256,0.592940745385545,-0.270532677192282,0.817739308235294,0.753074431976354,-0.822842877946363,0.53819555014995,1.3458515932154,-1.11966983471731,0.175121130008994,-0.451449182813529,-0.237033239362776,-0.0381947870352842,0.803486924960175,0.408542360392758,-0.00943069713232919,0.79827849458971,-0.137458079619063,0.141266983824769,-0.206009587619756,0.502292224181569,0.219422229513348,0.215153147499206,69.99,"0" 7 | 
2,-0.425965884412454,0.960523044882985,1.14110934232219,-0.168252079760302,0.42098688077219,-0.0297275516639742,0.476200948720027,0.260314333074874,-0.56867137571251,-0.371407196834471,1.34126198001957,0.359893837038039,-0.358090652573631,-0.137133700217612,0.517616806555742,0.401725895589603,-0.0581328233640131,0.0686531494425432,-0.0331937877876282,0.0849676720682049,-0.208253514656728,-0.559824796253248,-0.0263976679795373,-0.371426583174346,-0.232793816737034,0.105914779097957,0.253844224739337,0.0810802569229443,3.67,"0" 8 | 4,1.22965763450793,0.141003507049326,0.0453707735899449,1.20261273673594,0.191880988597645,0.272708122899098,-0.00515900288250983,0.0812129398830894,0.464959994783886,-0.0992543211289237,-1.41690724314928,-0.153825826253651,-0.75106271556262,0.16737196252175,0.0501435942254188,-0.443586797916727,0.00282051247234708,-0.61198733994012,-0.0455750446637976,-0.21963255278686,-0.167716265815783,-0.270709726172363,-0.154103786809305,-0.780055415004671,0.75013693580659,-0.257236845917139,0.0345074297438413,0.00516776890624916,4.99,"0" 9 | 7,-0.644269442348146,1.41796354547385,1.0743803763556,-0.492199018495015,0.948934094764157,0.428118462833089,1.12063135838353,-3.80786423873589,0.615374730667027,1.24937617815176,-0.619467796121913,0.291474353088705,1.75796421396042,-1.32386521970526,0.686132504394383,-0.0761269994382006,-1.2221273453247,-0.358221569869078,0.324504731321494,-0.156741852488285,1.94346533978412,-1.01545470979971,0.057503529867291,-0.649709005559993,-0.415266566234811,-0.0516342969262494,-1.20692108094258,-1.08533918832377,40.8,"1" 10 | 
7,-0.89428608220282,0.286157196276544,-0.113192212729871,-0.271526130088604,2.6695986595986,3.72181806112751,0.370145127676916,0.851084443200905,-0.392047586798604,-0.410430432848439,-0.705116586646536,-0.110452261733098,-0.286253632470583,0.0743553603016731,-0.328783050303565,-0.210077268148783,-0.499767968800267,0.118764861004217,0.57032816746536,0.0527356691149697,-0.0734251001059225,-0.268091632235551,-0.204232669947878,1.0115918018785,0.373204680146282,-0.384157307702294,0.0117473564581996,0.14240432992147,93.2,"1" 11 | -------------------------------------------------------------------------------- /credit_test.json: -------------------------------------------------------------------------------- 1 | [{"Time":0,"V1":-1.3598071337,"V2":-0.0727811733,"V3":2.536346738,"V4":1.3781552243,"V5":-0.3383207699,"V6":0.4623877778,"V7":0.2395985541,"V8":0.0986979013,"V9":0.3637869696,"V10":0.090794172,"V11":-0.5515995333,"V12":-0.6178008558,"V13":-0.9913898472,"V14":-0.3111693537,"V15":1.4681769721,"V16":-0.4704005253,"V17":0.2079712419,"V18":0.0257905802,"V19":0.4039929603,"V20":0.2514120982,"V21":-0.0183067779,"V22":0.2778375756,"V23":-0.1104739102,"V24":0.0669280749,"V25":0.1285393583,"V26":-0.1891148439,"V27":0.1335583767,"V28":-0.0210530535,"Amount":149.62},{"Time":0,"V1":1.1918571113,"V2":0.2661507121,"V3":0.1664801134,"V4":0.4481540785,"V5":0.0600176493,"V6":-0.0823608088,"V7":-0.0788029833,"V8":0.0851016549,"V9":-0.2554251281,"V10":-0.166974414,"V11":1.6127266611,"V12":1.0652353114,"V13":0.4890950159,"V14":-0.1437722964,"V15":0.6355580933,"V16":0.463917041,"V17":-0.1148046631,"V18":-0.1833612701,"V19":-0.1457830413,"V20":-0.0690831352,"V21":-0.225775248,"V22":-0.6386719528,"V23":0.1012880213,"V24":-0.3398464755,"V25":0.1671704044,"V26":0.1258945324,"V27":-0.0089830991,"V28":0.0147241692,"Amount":2.69},{"Time":1,"V1":-1.3583540616,"V2":-1.3401630747,"V3":1.7732093426,"V4":0.379779593,"V5":-0.5031981333,"V6":1.8004993808,"V7":0.7914609565,"V8":0.2476757866,"V9":-1.5
146543226,"V10":0.2076428652,"V11":0.6245014594,"V12":0.0660836853,"V13":0.7172927314,"V14":-0.1659459228,"V15":2.345864949,"V16":-2.8900831944,"V17":1.1099693787,"V18":-0.1213593132,"V19":-2.2618570953,"V20":0.5249797252,"V21":0.2479981535,"V22":0.7716794019,"V23":0.9094122623,"V24":-0.6892809565,"V25":-0.3276418337,"V26":-0.1390965715,"V27":-0.055352794,"V28":-0.0597518406,"Amount":378.66},{"Time":1,"V1":-0.9662717116,"V2":-0.1852260081,"V3":1.7929933396,"V4":-0.863291275,"V5":-0.0103088796,"V6":1.2472031675,"V7":0.2376089398,"V8":0.3774358747,"V9":-1.3870240627,"V10":-0.0549519225,"V11":-0.2264872638,"V12":0.1782282259,"V13":0.50775687,"V14":-0.2879237455,"V15":-0.6314181177,"V16":-1.0596472454,"V17":-0.6840927863,"V18":1.9657750035,"V19":-1.2326219701,"V20":-0.2080377812,"V21":-0.108300452,"V22":0.0052735968,"V23":-0.1903205187,"V24":-1.1755753319,"V25":0.6473760346,"V26":-0.2219288445,"V27":0.0627228487,"V28":0.0614576285,"Amount":123.5},{"Time":2,"V1":-1.1582330935,"V2":0.8777367548,"V3":1.5487178465,"V4":0.403033934,"V5":-0.4071933773,"V6":0.0959214625,"V7":0.5929407454,"V8":-0.2705326772,"V9":0.8177393082,"V10":0.753074432,"V11":-0.8228428779,"V12":0.5381955501,"V13":1.3458515932,"V14":-1.1196698347,"V15":0.17512113,"V16":-0.4514491828,"V17":-0.2370332394,"V18":-0.038194787,"V19":0.803486925,"V20":0.4085423604,"V21":-0.0094306971,"V22":0.7982784946,"V23":-0.1374580796,"V24":0.1412669838,"V25":-0.2060095876,"V26":0.5022922242,"V27":0.2194222295,"V28":0.2151531475,"Amount":69.99},{"Time":2,"V1":-0.4259658844,"V2":0.9605230449,"V3":1.1411093423,"V4":-0.1682520798,"V5":0.4209868808,"V6":-0.0297275517,"V7":0.4762009487,"V8":0.2603143331,"V9":-0.5686713757,"V10":-0.3714071968,"V11":1.34126198,"V12":0.359893837,"V13":-0.3580906526,"V14":-0.1371337002,"V15":0.5176168066,"V16":0.4017258956,"V17":-0.0581328234,"V18":0.0686531494,"V19":-0.0331937878,"V20":0.0849676721,"V21":-0.2082535147,"V22":-0.5598247963,"V23":-0.026397668,"V24":-0.3714265832,"V25":-0.2327938167,"V2
6":0.1059147791,"V27":0.2538442247,"V28":0.0810802569,"Amount":3.67},{"Time":4,"V1":1.2296576345,"V2":0.141003507,"V3":0.0453707736,"V4":1.2026127367,"V5":0.1918809886,"V6":0.2727081229,"V7":-0.0051590029,"V8":0.0812129399,"V9":0.4649599948,"V10":-0.0992543211,"V11":-1.4169072431,"V12":-0.1538258263,"V13":-0.7510627156,"V14":0.1673719625,"V15":0.0501435942,"V16":-0.4435867979,"V17":0.0028205125,"V18":-0.6119873399,"V19":-0.0455750447,"V20":-0.2196325528,"V21":-0.1677162658,"V22":-0.2707097262,"V23":-0.1541037868,"V24":-0.780055415,"V25":0.7501369358,"V26":-0.2572368459,"V27":0.0345074297,"V28":0.0051677689,"Amount":4.99},{"Time":7,"V1":-0.6442694423,"V2":1.4179635455,"V3":1.0743803764,"V4":-0.4921990185,"V5":0.9489340948,"V6":0.4281184628,"V7":1.1206313584,"V8":-3.8078642387,"V9":0.6153747307,"V10":1.2493761782,"V11":-0.6194677961,"V12":0.2914743531,"V13":1.757964214,"V14":-1.3238652197,"V15":0.6861325044,"V16":-0.0761269994,"V17":-1.2221273453,"V18":-0.3582215699,"V19":0.3245047313,"V20":-0.1567418525,"V21":1.9434653398,"V22":-1.0154547098,"V23":0.0575035299,"V24":-0.6497090056,"V25":-0.4152665662,"V26":-0.0516342969,"V27":-1.2069210809,"V28":-1.0853391883,"Amount":40.8},{"Time":7,"V1":-0.8942860822,"V2":0.2861571963,"V3":-0.1131922127,"V4":-0.2715261301,"V5":2.6695986596,"V6":3.7218180611,"V7":0.3701451277,"V8":0.8510844432,"V9":-0.3920475868,"V10":-0.4104304328,"V11":-0.7051165866,"V12":-0.1104522617,"V13":-0.2862536325,"V14":0.0743553603,"V15":-0.3287830503,"V16":-0.2100772681,"V17":-0.4997679688,"V18":0.118764861,"V19":0.5703281675,"V20":0.0527356691,"V21":-0.0734251001,"V22":-0.2680916322,"V23":-0.2042326699,"V24":1.0115918019,"V25":0.3732046801,"V26":-0.3841573077,"V27":0.0117473565,"V28":0.1424043299,"Amount":93.2}] -------------------------------------------------------------------------------- /flask_ml_api.py: -------------------------------------------------------------------------------- 1 | #!flask/bin/python 2 | from flask import Flask, jsonify, 
request 3 | import joblib 4 | import pandas as pd 5 | from cassandra_rw import CassandraReadWriteDb 6 | from PredictTxInfo import PredictTxInfoModel 7 | 8 | app = Flask(__name__) 9 | cwd = CassandraReadWriteDb(ip_addrs=['172.17.0.2'], keyspace="emp") 10 | cwd.sync_class_table(PredictTxInfoModel) 11 | model = joblib.load('model3.pipeline')  # unpickle the model once at startup instead of on every request 12 | @app.route('/predict/tx', methods=['POST']) 13 | def create_task():  # score one JSON transaction, store it with prediction P, and echo it back with P as a string 14 | tx_data = request.json 15 | df = pd.DataFrame.from_records([tx_data]) 16 | df = df.drop(['Time'],axis=1) 17 | # 'model' is the module-level pipeline loaded once at startup 18 | tx_data['P'] = model.best_estimator_.predict(df)[0] 19 | cwd.write_json_table(tx_data) 20 | tx_data['P'] = str(tx_data['P']) 21 | return jsonify(tx_data), 201 22 | 23 | if __name__ == '__main__': 24 | app.run(debug=True) 25 | -------------------------------------------------------------------------------- /kafka-json-model.py: -------------------------------------------------------------------------------- 1 | import threading, logging, time 2 | import multiprocessing 3 | import json 4 | from kafka import KafkaConsumer, KafkaProducer 5 | import joblib 6 | import pandas as pd 7 | from cassandra_rw import CassandraReadWriteDb 8 | from PredictTxInfo import PredictTxInfoModel 9 | 10 | class Consumer():  # reads transactions from the 'credit-card-tx' topic, predicts P for each, and stores them in Cassandra 11 | def __init__(self): 12 | self.model = joblib.load('model3.pipeline') 13 | self.cwd = CassandraReadWriteDb(ip_addrs=['172.17.0.2'], keyspace="emp") 14 | self.cwd.sync_class_table(PredictTxInfoModel) 15 | 16 | 17 | def run(self): 18 | consumer = KafkaConsumer(bootstrap_servers='localhost:9092', 19 | auto_offset_reset='earliest', 20 | consumer_timeout_ms=1000, value_deserializer=lambda m: 
json.loads(m.decode('ascii'))) 21 | consumer.subscribe(['credit-card-tx']) 22 | try:  # the polling loop only ends via an exception (e.g. KeyboardInterrupt), so close() must sit in finally to run at all 23 | while True:  # consumer_timeout_ms=1000 ends each for-loop after 1 s without messages; this loop resumes polling 24 | for message in consumer: 25 | df = pd.DataFrame.from_records([message.value]) 26 | df = df.drop(['Time'],axis=1) 27 | outcome = self.model.best_estimator_.predict(df)[0] 28 | message.value['P'] = outcome 29 | self.cwd.write_json_table(message.value) 30 | finally: 31 | 
consumer.close() 32 | 33 | 34 | def main(): 35 | tasks = [ 36 | Consumer() 37 | ] 38 | 39 | for t in tasks: 40 | t.run() 41 | 42 | 43 | 44 | if __name__ == "__main__": 45 | logging.basicConfig( 46 | format='%(asctime)s.%(msecs)s:%(name)s:%(thread)d:%(levelname)s:%(process)d:%(message)s', 47 | level=logging.INFO 48 | ) 49 | main() 50 | -------------------------------------------------------------------------------- /kafka-json-producer.py: -------------------------------------------------------------------------------- 1 | import threading, logging, time 2 | import multiprocessing 3 | import json 4 | 5 | from kafka import KafkaProducer 6 | 7 | 8 | class Producer(threading.Thread): 9 | def __init__(self): 10 | threading.Thread.__init__(self) 11 | self.stop_event = threading.Event() 12 | 13 | def stop(self):  # NOTE(review): run() never checks stop_event, so stop() does not cut sending short; join() in main() still waits for every message 14 | self.stop_event.set() 15 | 16 | def run(self):  # send each record of credit_test.json to 'credit-card-tx', one every 5 seconds 17 | producer = KafkaProducer(bootstrap_servers='localhost:9092', value_serializer=lambda m: json.dumps(m).encode('ascii')) 18 | 19 | with open('credit_test.json') as json_file: 20 | data = json.load(json_file) 21 | for p in data: 22 | print (p) 23 | producer.send('credit-card-tx', p) 24 | time.sleep(5) 25 | 26 | producer.close() 27 | 28 | 29 | def main(): 30 | tasks = [ 31 | Producer(), 32 | ] 33 | 34 | for t in tasks: 35 | t.start() 36 | 37 | time.sleep(10) 38 | 39 | for task in tasks: 40 | task.stop() 41 | 42 | for task in tasks: 43 | task.join() 44 | 45 | 46 | if __name__ == "__main__": 47 | logging.basicConfig( 48 | format='%(asctime)s.%(msecs)s:%(name)s:%(thread)d:%(levelname)s:%(process)d:%(message)s', 49 | level=logging.INFO 50 | ) 51 | main() 52 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from cassandra_rw import 
CassandraReadWriteDb 2 | from TxInfo import TxInfoModel 3 | from ml_model import BuildMlPipeline 4 | from sklearn.model_selection import train_test_split 5 | 6 |
if __name__ == '__main__': 7 | 8 | cass_rw = CassandraReadWriteDb(ip_addrs=['172.17.0.2'], keyspace="emp") 9 | 10 | #Load data in cassandra from csv files 11 | cass_rw.sync_class_table(TxInfoModel) 12 | cass_rw.write_file_table('creditcard.csv') 13 | 14 | #Load cassandra data into pandas 15 | credit_data = cass_rw.get_pandas_from_cassandra() 16 | 17 | print ('Data loaded into dataframe') 18 | 19 | #Create models 20 | ml_pipeline = BuildMlPipeline() 21 | ml_pipeline.set_estimators('sgdClassifier','randomForestClassifier') 22 | ml_pipeline.set_scalers('standardscaler') 23 | ml_pipeline.set_samplers('smote','smoteenn') 24 | ml_pipeline.create_pipelines() 25 | 26 | #Hyperparameter Configuration 27 | params_dict = {} 28 | params_dict['smote'] = {'smote__k_neighbors':[5,10,15]} 29 | params_dict['smoteenn'] = {'smoteenn__sampling_strategy':['auto','all','not majority']} 30 | params_dict['randomforestclassifier'] = {'randomforestclassifier__n_estimators':[8,12]} 31 | params_dict['svc'] = {'svc__kernel':['linear','rbf','poly'],'svc__C':[.1,1,10]} 32 | ml_pipeline.set_hyperparameters(params_dict) 33 | 34 | #credit_data = credit_data.sample(10000) 35 | 36 | X = credit_data.drop(['tx_id','Time','C'],axis=1) 37 | y = credit_data.C 38 | trainX, testX, trainY, testY = train_test_split(X,y,stratify=y)  # C is imbalanced (hence SMOTE): stratify so both splits keep the same class ratio 39 | 40 | print ('Model Training') 41 | 42 | #model training 43 | ml_pipeline.fit(trainX,trainY) 44 | 45 | #Calculating model performance 46 | ml_pipeline.score(testX,testY) 47 | 48 | 49 | -------------------------------------------------------------------------------- /ml_model.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import GridSearchCV 2 | from sklearn.metrics import confusion_matrix 3 | 4 | from sklearn.ensemble import RandomForestClassifier 5 | from sklearn.linear_model import SGDClassifier 6 | from sklearn.svm import SVC 7 | 8 | from sklearn.preprocessing import 
StandardScaler, MinMaxScaler 9 | from joblib import
dump, load 10 | 11 | from imblearn.pipeline import make_pipeline 12 | from imblearn.over_sampling import SMOTE 13 | from imblearn.combine import SMOTEENN 14 | 15 | 16 | class BuildMlPipeline:  # builds a scaler -> sampler -> estimator pipeline for every chosen combination and grid-searches each one 17 | 18 | def __init__(self): 19 | pass 20 | 21 | def set_estimators(self, *args):  # args: keys of estimator_db; an unknown name raises KeyError 22 | estimator_db = { 23 | 'randomForestClassifier': RandomForestClassifier(), 24 | 'svc': SVC(), 25 | 'sgdClassifier': SGDClassifier(), 26 | } 27 | self.estimators = list(map( lambda algo: estimator_db[algo],args)) 28 | 29 | def set_scalers(self, *args):  # args: keys of scaler_db 30 | scaler_db = { 31 | 'standardscaler':StandardScaler(), 32 | 'minmaxscaler':MinMaxScaler(), 33 | } 34 | self.scalers = list(map( lambda scaler: scaler_db[scaler],args)) 35 | 36 | def set_samplers(self, *args):  # args: keys of sampler_db 37 | sampler_db = { 38 | 'smote':SMOTE(), 39 | 'smoteenn':SMOTEENN(), 40 | } 41 | self.samplers = list(map( lambda sampler: sampler_db[sampler],args)) 42 | 43 | def set_hyperparameters(self, params):  # params: {pipeline step name: param grid with 'step__param' keys} 44 | self.hyperparameters = params 45 | 46 | 47 | def create_pipelines(self):  # one pipeline per (estimator, sampler, scaler) combination 48 | self.model_pipelines = [] 49 | for estimator in self.estimators: 50 | for sampler in self.samplers: 51 | for scaler in self.scalers: 52 | pipeline = make_pipeline(scaler, sampler, estimator) 53 | self.model_pipelines.append(pipeline) 54 | 55 | 56 | def fit(self, trainX, trainY, scoring=None):  # grid-search every pipeline; scoring is forwarded to GridSearchCV (None keeps its default) - for rare fraud consider e.g. 'f1' or 'average_precision' 57 | self.gs_pipelines = [] 58 | for idx,pipeline in enumerate(self.model_pipelines): 59 | elems = list(map(lambda x:x[0] ,pipeline.steps)) 60 | param_grid = {} 61 | for elem in elems: 62 | if elem in 
self.hyperparameters: 63 | param_grid.update(self.hyperparameters[elem]) 64 | 65 | gs = GridSearchCV(pipeline, param_grid= param_grid, n_jobs=-1, cv=5, scoring=scoring) 66 | gs.fit(trainX, trainY) 67 | #dump(gs, 'model'+str(idx)+'.pipeline')  # idx is an int, so str() is needed; yields names like model3.pipeline, which the API and Kafka consumer load 68 | self.gs_pipelines.append(gs) 69 | 70 | 71 | def score(self, testX, testY):  # print each best pipeline and its confusion matrix on the test data 72 | for idx,model in enumerate(self.gs_pipelines): 73 | y_pred = model.best_estimator_.predict(testX) 74 | print (model.best_estimator_) 75 | print
(idx,confusion_matrix(y_true=testY,y_pred=y_pred)) 76 | 77 | 78 | import pandas as pd 79 | from sklearn.model_selection import train_test_split 80 | if __name__ == '__main__': 81 | 82 | ml_pipeline = BuildMlPipeline() 83 | ml_pipeline.set_estimators('randomForestClassifier') 84 | ml_pipeline.set_scalers('standardscaler') 85 | ml_pipeline.set_samplers('smote','smoteenn') 86 | ml_pipeline.create_pipelines() 87 | 88 | print (ml_pipeline.model_pipelines) 89 | 90 | params_dict = {} 91 | params_dict['smote'] = {'smote__k_neighbors':[5,10,15]} 92 | params_dict['smoteenn'] = {'smoteenn__sampling_strategy':['auto','all','not majority']} 93 | params_dict['randomforestclassifier'] = {'randomforestclassifier__n_estimators':[8,12]} 94 | params_dict['svc'] = {'svc__kernel':['linear','rbf','poly'],'svc__C':[.1,1,10]} 95 | 96 | ml_pipeline.set_hyperparameters(params_dict) 97 | 98 | credit_data = pd.read_csv('creditcard.csv').sample(20000) 99 | X = credit_data.drop(['Time','C'],axis=1) 100 | y = credit_data.C 101 | trainX, testX, trainY, testY = train_test_split(X,y,stratify=y)  # keep the imbalanced C ratio equal in both splits 102 | ml_pipeline.fit(trainX,trainY)  # fit on the training split only so the test rows scored below stay unseen 103 | ml_pipeline.score(testX,testY) 104 | 105 | --------------------------------------------------------------------------------