├── Help.pdf ├── Image Style Transfer Using Convolutional Neural Network.py ├── README.md └── logo.py /Help.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShaharAssenheim/Image-Style-Transfer-Using-Convolutional-Neural-Network/b1a0ea197248e77cdb3d9acf8f586d5fb0666a73/Help.pdf -------------------------------------------------------------------------------- /Image Style Transfer Using Convolutional Neural Network.py: -------------------------------------------------------------------------------- 1 | """Image Style Transfer Using Convolutional Neural Network 2 | code Written in python, Ui made with PyQt5""" 3 | 4 | from PyQt5 import QtCore, QtGui, QtWidgets 5 | from PyQt5.QtCore import QThread, pyqtSignal 6 | import logo 7 | import threading 8 | 9 | # global variables created to control the UI and code parameters. 10 | global content_path 11 | global style_path 12 | global outputImage 13 | global pixmap 14 | global exitflag 15 | exitflag=0 16 | global flag1 17 | flag1=0 18 | global flag2 19 | flag2=0 20 | global flag3 21 | flag3=0 22 | global count 23 | count=0 24 | global iter 25 | iter = 0 26 | 27 | """Ui_MainWindow is the main class of the UI, 28 | all UI parameters and code functions defined here.""" 29 | class Ui_MainWindow(object): 30 | def setupUi(self, MainWindow): 31 | MainWindow.setObjectName("MainWindow") 32 | MainWindow.setFixedSize(946,600) 33 | MainWindow.setStyleSheet("font: 75 22pt \"MS Shell Dlg 2\";") 34 | self.centralwidget = QtWidgets.QWidget(MainWindow) 35 | self.centralwidget.setObjectName("centralwidget") 36 | self.label = QtWidgets.QLabel(self.centralwidget) 37 | self.label.setGeometry(QtCore.QRect(0, 0, 946, 800)) 38 | self.label.setText("") 39 | self.label.setPixmap(QtGui.QPixmap(":/logo/grey-background-v1.jpg")) 40 | self.label.setScaledContents(True) 41 | self.label.setObjectName("label") 42 | self.mainlog = QtWidgets.QLabel(self.centralwidget) 43 | self.mainlog.setGeometry(QtCore.QRect(290, 80, 411, 251)) 44 | self.mainlog.setText("") 45 | self.mainlog.setPixmap(QtGui.QPixmap(":/logo/logo.png")) 46 | self.mainlog.setScaledContents(True) 47 | self.mainlog.setObjectName("mainlog") 48 | self.About = QtWidgets.QLabel(self.centralwidget) 49 | self.About.setGeometry(QtCore.QRect(20, 0, 711, 501)) 50 | self.About.setText("") 51 | self.About.setPixmap(QtGui.QPixmap(":/logo/AOUT.png")) 52 | self.About.setScaledContents(True) 53 | self.About.setObjectName("About") 54 | self.About.hide() 55 | self.smalllogo = QtWidgets.QLabel(self.centralwidget) 56 | self.smalllogo.setGeometry(QtCore.QRect(440, 480, 91, 51)) 57 | self.smalllogo.setText("") 58 | self.smalllogo.setPixmap(QtGui.QPixmap(":/logo/logo.png")) 59 | self.smalllogo.setScaledContents(True) 60 | self.smalllogo.setObjectName("smalllogo") 61 | self.smalllogo.hide() 62 | self.contentbutton = QtWidgets.QPushButton(self.centralwidget) 63 | self.contentbutton.setGeometry(QtCore.QRect(60, 40, 151, 41)) 64 | self.contentbutton.setStyleSheet("font: 75 12pt \"News706 BT\";") 65 | self.contentbutton.setObjectName("contentbutton") 66 | self.contentbutton.hide() 67 | self.stylebutton = QtWidgets.QPushButton(self.centralwidget) 68 | self.stylebutton.setGeometry(QtCore.QRect(400, 40, 151, 41)) 69 | self.stylebutton.setStyleSheet("font: 75 12pt \"News706 BT\";") 70 | self.stylebutton.setObjectName("stylebutton") 71 | self.stylebutton.hide() 72 | self.generatebutton = QtWidgets.QPushButton(self.centralwidget) 73 | self.generatebutton.setGeometry(QtCore.QRect(400, 400, 151, 41)) 74 | self.generatebutton.setStyleSheet("font: 75 12pt \"News706 BT\";") 75 | self.generatebutton.setObjectName("generatebutton") 76 | self.generatebutton.hide() 77 | self.progressBar = QtWidgets.QProgressBar(self.centralwidget) 78 | self.progressBar.setGeometry(QtCore.QRect(280, 400, 491, 31)) 79 | self.progressBar.setProperty("value", 24) 80 | self.progressBar.setObjectName("progressBar") 81 | self.progressBar.hide() 82 | self.contentframe = QtWidgets.QLabel(self.centralwidget) 83 | self.contentframe.setGeometry(QtCore.QRect(10, 90, 251, 191)) 84 | self.contentframe.setFrameShape(QtWidgets.QFrame.NoFrame) 85 | self.contentframe.setText("") 86 | self.contentframe.setPixmap(QtGui.QPixmap(":/logo/image.png")) 87 | self.contentframe.setScaledContents(True) 88 | self.contentframe.setObjectName("contentframe") 89 | self.contentframe.hide() 90 | self.styleframe = QtWidgets.QLabel(self.centralwidget) 91 | self.styleframe.setGeometry(QtCore.QRect(350, 90, 251, 191)) 92 | self.styleframe.setFrameShape(QtWidgets.QFrame.NoFrame) 93 | self.styleframe.setText("") 94 | self.styleframe.setPixmap(QtGui.QPixmap(":/logo/image.png")) 95 | self.styleframe.setScaledContents(True) 96 | self.styleframe.setObjectName("styleframe") 97 | self.styleframe.hide() 98 | self.outputframe = QtWidgets.QLabel(self.centralwidget) 99 | self.outputframe.setGeometry(QtCore.QRect(680, 90, 251, 191)) 100 | self.outputframe.setFrameShape(QtWidgets.QFrame.NoFrame) 101 | self.outputframe.setText("") 102 | self.outputframe.setPixmap(QtGui.QPixmap(":/logo/qustion.png")) 103 | self.outputframe.setScaledContents(True) 104 | self.outputframe.setObjectName("outputframe") 105 | self.outputframe.hide() 106 | self.savebutton = QtWidgets.QPushButton(self.centralwidget) 107 | self.savebutton.setGeometry(QtCore.QRect(730, 40, 151, 41)) 108 | self.savebutton.setStyleSheet("font: 75 12pt \"News706 BT\";") 109 | self.savebutton.setObjectName("savebutton") 110 | self.savebutton.hide() 111 | self.comboBox = QtWidgets.QComboBox(self.centralwidget) 112 | self.comboBox.setGeometry(QtCore.QRect(280, 330, 121, 31)) 113 | self.comboBox.setStyleSheet("font: 75 14pt \"MS Shell Dlg 2\";") 114 | self.comboBox.setObjectName("comboBox") 115 | self.comboBox.addItem("") 116 | self.comboBox.addItem("") 117 | self.comboBox.addItem("") 118 | self.comboBox.hide() 119 | self.qualty = QtWidgets.QLabel(self.centralwidget) 120 | self.qualty.setGeometry(QtCore.QRect(170, 320, 111, 41)) 121 | self.qualty.setObjectName("qualty") 122 | self.qualty.hide() 123 | self.when = QtWidgets.QLabel(self.centralwidget) 124 | self.when.setGeometry(QtCore.QRect(195, 340, 571, 121)) 125 | self.when.setText("") 126 | self.when.setPixmap(QtGui.QPixmap(":/logo/when.png")) 127 | self.when.setScaledContents(True) 128 | self.when.setObjectName("when") 129 | self.resbox = QtWidgets.QComboBox(self.centralwidget) 130 | self.resbox.setGeometry(QtCore.QRect(660, 330, 121, 31)) 131 | self.resbox.setStyleSheet("font: 75 14pt \"MS Shell Dlg 2\";") 132 | self.resbox.setObjectName("resbox") 133 | self.resbox.addItem("") 134 | self.resbox.addItem("") 135 | self.resbox.addItem("") 136 | self.resbox.addItem("") 137 | self.resbox.hide() 138 | self.res = QtWidgets.QLabel(self.centralwidget) 139 | self.res.setGeometry(QtCore.QRect(510, 320, 141, 41)) 140 | self.res.setObjectName("res") 141 | self.res.hide() 142 | self.warninglabel = QtWidgets.QLabel(self.centralwidget) 143 | self.warninglabel.setGeometry(QtCore.QRect(260, 390, 421, 61)) 144 | self.warninglabel.setText("") 145 | self.warninglabel.setPixmap(QtGui.QPixmap(":/logo/warning.png")) 146 | self.warninglabel.setScaledContents(True) 147 | self.warninglabel.setObjectName("warninglabel") 148 | self.warninglabel.hide() 149 | self.equalabel = QtWidgets.QLabel(self.centralwidget) 150 | self.equalabel.setGeometry(QtCore.QRect(610, 150, 61, 71)) 151 | self.equalabel.setText("") 152 | self.equalabel.setPixmap(QtGui.QPixmap(":/logo/equal.png")) 153 | self.equalabel.setScaledContents(True) 154 | self.equalabel.setObjectName("equalabel") 155 | self.equalabel.hide() 156 | self.pluslabel = QtWidgets.QLabel(self.centralwidget) 157 | self.pluslabel.setGeometry(QtCore.QRect(260, 140, 91, 81)) 158 | self.pluslabel.setText("") 159 | self.pluslabel.setPixmap(QtGui.QPixmap(":/logo/plus-big-512.png")) 160 | self.pluslabel.setScaledContents(True) 161 | self.pluslabel.setObjectName("pluslabel") 162 | self.pluslabel.hide() 163 | self.mainlog.raise_() 164 | self.About.raise_() 165 | self.when.raise_() 166 | self.progressBar.raise_() 167 | self.smalllogo.raise_() 168 | self.contentbutton.raise_() 169 | self.stylebutton.raise_() 170 | self.generatebutton.raise_() 171 | self.contentframe.raise_() 172 | self.styleframe.raise_() 173 | self.outputframe.raise_() 174 | self.savebutton.raise_() 175 | self.comboBox.raise_() 176 | self.qualty.raise_() 177 | self.resbox.raise_() 178 | self.res.raise_() 179 | self.warninglabel.raise_() 180 | self.equalabel.raise_() 181 | self.pluslabel.raise_() 182 | MainWindow.setCentralWidget(self.centralwidget) 183 | self.menubar = QtWidgets.QMenuBar(MainWindow) 184 | self.menubar.setGeometry(QtCore.QRect(0, 0, 946, 41)) 185 | self.menubar.setObjectName("menubar") 186 | self.menuHome = QtWidgets.QMenu(self.menubar) 187 | self.menuHome.setObjectName("menuHome") 188 | self.menuCreate_New = QtWidgets.QMenu(self.menubar) 189 | self.menuCreate_New.setObjectName("menuCreate_New") 190 | self.menuAbout = QtWidgets.QMenu(self.menubar) 191 | self.menuAbout.setObjectName("menuAbout") 192 | self.menuExit = QtWidgets.QMenu(self.menubar) 193 | self.menuExit.setObjectName("menuExit") 194 | MainWindow.setMenuBar(self.menubar) 195 | self.actionCreate_New = QtWidgets.QAction(MainWindow) 196 | self.actionCreate_New.setObjectName("actionCreate_New") 197 | self.actionHome = QtWidgets.QAction(MainWindow) 198 | self.actionHome.setObjectName("actionHome") 199 | self.actionAbout = QtWidgets.QAction(MainWindow) 200 | self.actionAbout.setObjectName("actionAbout") 201 | self.actionExit = QtWidgets.QAction(MainWindow) 202 | self.actionExit.setObjectName("actionExit") 203 | self.actionExit2 = QtWidgets.QAction(MainWindow) 204 | self.actionExit2.setObjectName("actionExit2") 205 | self.menuHome.addAction(self.actionHome) 206 | self.menuHome.addAction(self.actionExit2) 207 | self.menuCreate_New.addSeparator() 208 | self.menuCreate_New.addAction(self.actionCreate_New) 209 | self.menuAbout.addAction(self.actionAbout) 210 | self.menuExit.addAction(self.actionExit) 211 | self.menubar.addAction(self.menuHome.menuAction()) 212 | self.menubar.addAction(self.menuCreate_New.menuAction()) 213 | self.menubar.addAction(self.menuAbout.menuAction()) 214 | self.menubar.addAction(self.menuExit.menuAction()) 215 | self.retranslateUi(MainWindow) 216 | self.actionCreate_New.triggered.connect(self.mainlog.hide) 217 | self.actionHome.triggered.connect(self.progressBar.hide) 218 | self.actionHome.triggered.connect(self.contentbutton.hide) 219 | self.actionHome.triggered.connect(self.generatebutton.hide) 220 | self.actionHome.triggered.connect(self.stylebutton.hide) 221 | self.actionHome.triggered.connect(self.smalllogo.hide) 222 | self.actionHome.triggered.connect(self.About.hide) 223 | self.actionHome.triggered.connect(self.qualty.hide) 224 | self.actionHome.triggered.connect(self.savebutton.hide) 225 | self.actionHome.triggered.connect(self.comboBox.hide) 226 | self.actionHome.triggered.connect(self.warninglabel.hide) 227 | self.actionHome.triggered.connect(self.pluslabel.hide) 228 | self.actionHome.triggered.connect(self.equalabel.hide) 229 | self.actionHome.triggered.connect(self.outputframe.hide) 230 | self.actionHome.triggered.connect(self.res.hide) 231 | self.actionHome.triggered.connect(self.resbox.hide) 232 | self.actionHome.triggered.connect(self.mainlog.show) 233 | self.actionHome.triggered.connect(self.when.show) 234 | self.actionCreate_New.triggered.connect(self.createNewScreen) 235 | self.actionAbout.triggered.connect(self.About.show) 236 | self.actionAbout.triggered.connect(self.smalllogo.show) 237 | self.actionAbout.triggered.connect(self.generatebutton.hide) 238 | self.actionAbout.triggered.connect(self.warninglabel.hide) 239 | self.actionAbout.triggered.connect(self.progressBar.hide) 240 | self.generatebutton.clicked.connect(self.generatebutton.hide) 241 | self.actionAbout.triggered.connect(self.contentbutton.hide) 242 | self.actionAbout.triggered.connect(self.stylebutton.hide) 243 | self.actionAbout.triggered.connect(self.pluslabel.hide) 244 | self.actionAbout.triggered.connect(self.equalabel.hide) 245 | self.actionAbout.triggered.connect(self.outputframe.hide) 246 | self.actionAbout.triggered.connect(self.when.hide) 247 | self.actionAbout.triggered.connect(self.qualty.hide) 248 | self.actionAbout.triggered.connect(self.savebutton.hide) 249 | self.actionAbout.triggered.connect(self.comboBox.hide) 250 | self.actionAbout.triggered.connect(self.res.hide) 251 | self.actionAbout.triggered.connect(self.resbox.hide) 252 | self.actionAbout.triggered.connect(self.mainlog.hide) 253 | self.actionExit2.triggered.connect(self.exit) 254 | self.actionExit.triggered.connect(self.openhelp) 255 | self.actionAbout.triggered.connect(self.contentframe.hide) 256 | self.actionAbout.triggered.connect(self.styleframe.hide) 257 | self.actionHome.triggered.connect(self.contentframe.hide) 258 | self.actionHome.triggered.connect(self.styleframe.hide) 259 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 260 | self.contentbutton.clicked.connect(self.setContentImage) 261 | self.stylebutton.clicked.connect(self.setStyleImage) 262 | self.generatebutton.clicked.connect(self.lunch_thread) 263 | self.savebutton.clicked.connect(self.saveimage) 264 | 265 | def retranslateUi(self, MainWindow): 266 | _translate = QtCore.QCoreApplication.translate 267 | MainWindow.setWindowTitle(_translate("artme", "artme")) 268 | self.contentbutton.setText(_translate("MainWindow", "Content Image")) 269 | self.stylebutton.setText(_translate("MainWindow", "Style Image")) 270 | self.generatebutton.setText(_translate("MainWindow", "Generate")) 271 | self.savebutton.setText(_translate("MainWindow", "Save Image")) 272 | self.comboBox.setItemText(0, _translate("MainWindow", "Low")) 273 | self.comboBox.setItemText(1, _translate("MainWindow", "Medium")) 274 | self.comboBox.setItemText(2, _translate("MainWindow", "High")) 275 | self.qualty.setText(_translate("MainWindow", "Quality: ")) 276 | self.resbox.setItemText(0, _translate("MainWindow", "256 Px")) 277 | self.resbox.setItemText(1, _translate("MainWindow", "512 Px")) 278 | self.resbox.setItemText(2, _translate("MainWindow", "1024 Px")) 279 | self.resbox.setItemText(3, _translate("MainWindow", "2048 Px")) 280 | self.res.setText(_translate("MainWindow", "Resolution:")) 281 | self.menuHome.setTitle(_translate("MainWindow", "Home")) 282 | self.menuCreate_New.setTitle(_translate("MainWindow", "Create New")) 283 | self.menuAbout.setTitle(_translate("MainWindow", "About")) 284 | self.menuExit.setTitle(_translate("MainWindow", "Help")) 285 | self.actionCreate_New.setText(_translate("MainWindow", "Create New")) 286 | self.actionHome.setText(_translate("MainWindow", "Home")) 287 | self.actionAbout.setText(_translate("MainWindow", "About")) 288 | self.actionExit.setText(_translate("MainWindow", "Help")) 289 | self.actionExit2.setText(_translate("MainWindow", "Exit")) 290 | 291 | # openhelp function open the help file. 292 | def openhelp(self): 293 | import os 294 | filename = 'Help.pdf' 295 | try: 296 | os.startfile(filename) 297 | except: 298 | return 299 | 300 | # createNewScreen function control of what shows in the create new screen. 301 | def createNewScreen(self): 302 | global flag3 303 | if(flag3==1): 304 | self.savebutton.show() 305 | else: 306 | self.savebutton.hide() 307 | self.smalllogo.show() 308 | self.contentframe.show() 309 | self.styleframe.show() 310 | self.contentbutton.show() 311 | self.warninglabel.hide() 312 | self.generatebutton.show() 313 | self.qualty.show() 314 | self.comboBox.show() 315 | self.res.show() 316 | self.resbox.show() 317 | self.pluslabel.show() 318 | self.equalabel.show() 319 | self.outputframe.show() 320 | self.mainlog.hide() 321 | self.when.hide() 322 | self.About.hide() 323 | self.progressBar.hide() 324 | self.stylebutton.show() 325 | self.contentframe.show() 326 | self.styleframe.show() 327 | 328 | """onCountChanged function control on updating the progrssBar.""" 329 | def onCountChanged(self, value): 330 | self.progressBar.setValue(value) 331 | 332 | """setContentImage function control on choosing the content image.""" 333 | def setContentImage(self): 334 | fileName, _ = QtWidgets.QFileDialog.getOpenFileNames(None, "Select Image", "", 335 | "Image Files (*.png *.jpg *.jpeg *.bmp)") 336 | if fileName: 337 | global content_path 338 | content_path = fileName[0] 339 | pixmap = QtGui.QPixmap(fileName[0]) 340 | pixmap = pixmap.scaled(290, 290, QtCore.Qt.KeepAspectRatio) 341 | self.contentframe.setPixmap(pixmap) 342 | self.contentframe.setAlignment(QtCore.Qt.AlignCenter) 343 | global flag1 344 | flag1 =1 345 | global flag2 346 | if (flag1==1 and flag2==1): 347 | self.outputframe.show() 348 | self.warninglabel.hide() 349 | self.generatebutton.show() 350 | self.pluslabel.show() 351 | self.equalabel.show() 352 | 353 | """setStyleImage function control on choosing the style image.""" 354 | def setStyleImage(self): 355 | fileName, _ = QtWidgets.QFileDialog.getOpenFileNames(None, "Select Image", "", 356 | "Image Files (*.png *.jpg *.jpeg *.bmp)") 357 | if fileName: 358 | global style_path 359 | style_path = fileName[0] 360 | pixmap = QtGui.QPixmap(fileName[0]) 361 | pixmap = pixmap.scaled(290, 290, QtCore.Qt.KeepAspectRatio) 362 | self.styleframe.setPixmap(pixmap) 363 | self.styleframe.setAlignment(QtCore.Qt.AlignCenter) 364 | global flag2 365 | flag2 = 1 366 | global flag1 367 | if (flag2==1 and flag1==1): 368 | self.outputframe.show() 369 | self.warninglabel.hide() 370 | self.generatebutton.show() 371 | self.pluslabel.show() 372 | self.equalabel.show() 373 | 374 | """Generate function is start when the Generate button pushed. it start the main algorithm.""" 375 | def Generate(self): 376 | global outputImage 377 | global exitflag 378 | exitflag=1 379 | global flag1 380 | global flag2 381 | if (flag1 == 0 or flag2 == 0): 382 | self.warninglabel.show() 383 | return 384 | self.actionHome.setEnabled(False) 385 | self.actionCreate_New.setEnabled(False) 386 | self.actionAbout.setEnabled(False) 387 | self.outputframe.setPixmap(QtGui.QPixmap(":/logo/qustion.png")) 388 | self.savebutton.hide() 389 | self.progressBar.setValue(0) 390 | self.progressBar.show() 391 | # iter control the number of iteration the algorithm run, the user choose it. 392 | global iter 393 | iter=0 394 | if self.comboBox.currentText() == 'Low': 395 | iter=100 396 | elif self.comboBox.currentText() == 'Medium': 397 | iter=500 398 | else: 399 | iter=1000 400 | 401 | # resulotion control the output image resulotion, the user choose it. 402 | resolution = 0 403 | if self.resbox.currentText() == '256 Px': 404 | resolution = 256 405 | elif self.resbox.currentText() == '512 Px': 406 | resolution = 512 407 | elif self.resbox.currentText() == '1024 Px': 408 | resolution = 1024 409 | else: 410 | resolution = 2048 411 | # outputImage get the result from the MainFunc. 412 | outputImage = self.MainFunc(content_path, style_path, iter, resolution) 413 | pixmap = QtGui.QPixmap(outputImage.toqpixmap()) 414 | pixmap = pixmap.scaled(290, 290, QtCore.Qt.KeepAspectRatio) 415 | self.outputframe.setPixmap(pixmap) 416 | self.outputframe.setAlignment(QtCore.Qt.AlignCenter) 417 | self.outputframe.show() 418 | self.savebutton.show() 419 | global flag3 420 | flag3 = 1 421 | self.actionHome.setEnabled(True) 422 | self.actionCreate_New.setEnabled(True) 423 | self.actionAbout.setEnabled(True) 424 | 425 | """lunch_thread control the start of the second thread that running the MainFunc.""" 426 | def lunch_thread(self): 427 | t = threading.Thread(target=self.Generate) 428 | t.start() 429 | 430 | """saveimage function control the saving of the output image.""" 431 | def saveimage(self): 432 | global outputImage 433 | fileName, _ = QtWidgets.QFileDialog.getSaveFileName(None, "Select Image", "", 434 | "Image Files (*.jpg *.png *.jpeg *.bmp)") 435 | if(fileName): 436 | outputImage.save(fileName) 437 | 438 | """exit function control on exit the application.""" 439 | def exit(self): 440 | if(exitflag == 1): 441 | self.exit() 442 | else: 443 | exit(1) 444 | 445 | """MainFunc is the main function that running the main algorithm""" 446 | def MainFunc(self, content_path, style_path, iter, resolution): 447 | import numpy as np 448 | from PIL import Image 449 | import tensorflow as tf 450 | import tensorflow.contrib.eager as tfe 451 | from tensorflow.python.keras.preprocessing import image as kp_image 452 | from tensorflow.python.keras import models 453 | 454 | # Eager execution is a flexible machine learning platform for research and experimentation. 455 | # Since we're using eager our model is callable just like any other function. 456 | tf.enable_eager_execution() 457 | print("Eager execution: {}".format(tf.executing_eagerly())) 458 | 459 | # define calc to the external thread. 460 | self.calc = External() 461 | self.calc.countChanged.connect(self.progressBar.setValue) 462 | 463 | # Content layer for the feature maps 464 | content_layers = ['block5_conv2'] 465 | 466 | # Style layer for the feature maps. 467 | style_layers = ['block1_conv1', 468 | 'block2_conv1', 469 | 'block3_conv1', 470 | 'block4_conv1', 471 | 'block5_conv1' 472 | ] 473 | 474 | num_content_layers = len(content_layers) 475 | num_style_layers = len(style_layers) 476 | 477 | # load_img function get the path of the image, 478 | # resize it and broadcast the image array such that it has a batch dimension. 479 | def load_img(path_to_img): 480 | max_dim = resolution 481 | img = Image.open(path_to_img) 482 | long = max(img.size) 483 | scale = max_dim / long 484 | img = img.resize((round(img.size[0] * scale), round(img.size[1] * scale)), Image.ANTIALIAS) 485 | img = kp_image.img_to_array(img) 486 | img = np.expand_dims(img, axis=0) 487 | return img 488 | 489 | # load_and_process_img is charge on load the image into the vgg19 network. 490 | def load_and_process_img(path_to_img): 491 | img = load_img(path_to_img) 492 | img = tf.keras.applications.vgg19.preprocess_input(img) 493 | return img 494 | 495 | def deprocess_img(processed_img): 496 | x = processed_img.copy() 497 | if len(x.shape) == 4: 498 | x = np.squeeze(x, 0) 499 | assert len(x.shape) == 3, ("Input to deprocess image must be an image of " 500 | "dimension [1, height, width, channel] or [height, width, channel]") 501 | if len(x.shape) != 3: 502 | raise ValueError("Invalid input to deprocessing image") 503 | 504 | x[:, :, 0] += 103.939 505 | x[:, :, 1] += 116.779 506 | x[:, :, 2] += 123.68 507 | x = x[:, :, ::-1] 508 | 509 | x = np.clip(x, 0, 255).astype('uint8') 510 | return x 511 | 512 | # get_model function load the VGG19 model and access the intermediate layers. 513 | # Returns: a Keras model that takes image inputs and outputs the style and content intermediate layers. 514 | def get_model(): 515 | import ssl 516 | ssl._create_default_https_context = ssl._create_unverified_context 517 | # We load pretrained VGG Network, trained on imagenet data 518 | vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet') 519 | vgg.trainable = False 520 | # Get output layers corresponding to style and content layers 521 | style_outputs = [vgg.get_layer(name).output for name in style_layers] 522 | content_outputs = [vgg.get_layer(name).output for name in content_layers] 523 | model_outputs = style_outputs + content_outputs 524 | # Build model 525 | return models.Model(vgg.input, model_outputs) 526 | 527 | # get_content_loss function calculate the content loss that is the 528 | # Mean Squared Error between the two feature representations matrices. 529 | def get_content_loss(base_content, target): 530 | return tf.reduce_mean(tf.square(base_content - target)) 531 | 532 | # Calculate the gram matrix for the style representation. 533 | def gram_matrix(input_tensor): 534 | # Make the image channels 535 | channels = int(input_tensor.shape[-1]) 536 | a = tf.reshape(input_tensor, [-1, channels]) 537 | n = tf.shape(a)[0] 538 | gram = tf.matmul(a, a, transpose_a=True) 539 | return gram / tf.cast(n, tf.float32) 540 | 541 | # get the style loss by calculate the Mean Squared Error between the two gram matrices. 542 | # We scale the loss at a given layer by the size of the feature map and the number of filters 543 | def get_style_loss(base_style, gram_target): 544 | height, width, channels = base_style.get_shape().as_list() 545 | gram_style = gram_matrix(base_style) 546 | return tf.reduce_mean(tf.square(gram_style - gram_target)) 547 | 548 | """This function will simply load and preprocess both the content and style 549 | images from their path. Then it will feed them through the network to obtain 550 | the outputs of the intermediate layers. 551 | Returns the style and the content features representation.""" 552 | def get_feature_representations(model, content_path, style_path): 553 | # Load our images into the VGG19 Network 554 | content_image = load_and_process_img(content_path) 555 | style_image = load_and_process_img(style_path) 556 | 557 | # compute content and style features 558 | style_outputs = model(style_image) 559 | content_outputs = model(content_image) 560 | 561 | # Get the style and content feature representations from our model 562 | style_features = [style_layer[0] for style_layer in style_outputs[:num_style_layers]] 563 | content_features = [content_layer[0] for content_layer in content_outputs[num_style_layers:]] 564 | return style_features, content_features 565 | 566 | """This function compute the content, style and total loss. 567 | we use model that will give us access to the intermediate layers.""" 568 | def compute_loss(model, loss_weights, init_image, gram_style_features, content_features): 569 | style_weight, content_weight = loss_weights 570 | 571 | # Feed our init image through our model. This will give us the content and 572 | # style representations at our desired layers. 573 | model_outputs = model(init_image) 574 | 575 | style_output_features = model_outputs[:num_style_layers] 576 | content_output_features = model_outputs[num_style_layers:] 577 | 578 | style_score = 0 579 | content_score = 0 580 | 581 | # calculate the style losses from all layers 582 | # equally weight each contribution of each loss layer 583 | weight_per_style_layer = 1.0 / float(num_style_layers) 584 | for target_style, comb_style in zip(gram_style_features, style_output_features): 585 | style_score += weight_per_style_layer * get_style_loss(comb_style[0], target_style) 586 | 587 | # calculate content losses from all layers 588 | weight_per_content_layer = 1.0 / float(num_content_layers) 589 | for target_content, comb_content in zip(content_features, content_output_features): 590 | content_score += weight_per_content_layer * get_content_loss(comb_content[0], target_content) 591 | 592 | style_score *= style_weight 593 | content_score *= content_weight 594 | 595 | # Get total loss 596 | loss = style_score + content_score 597 | return loss, style_score, content_score 598 | 599 | # Compute gradients according to input image 600 | def compute_grads(cfg): 601 | with tf.GradientTape() as tape: 602 | all_loss = compute_loss(**cfg) 603 | total_loss = all_loss[0] 604 | return tape.gradient(total_loss, cfg['init_image']), all_loss 605 | 606 | """The main method of the code, running the main loop for generating the image.""" 607 | def run_style_transfer(content_path, 608 | style_path, 609 | num_iterations=1000, 610 | content_weight=1e3, 611 | style_weight=1e-2): 612 | # We don't train any layers of our model, so we set their trainable to false. 613 | model = get_model() 614 | for layer in model.layers: 615 | layer.trainable = False 616 | 617 | # Get the style and content feature representations (from our specified intermediate layers) 618 | style_features, content_features = get_feature_representations(model, content_path, style_path) 619 | gram_style_features = [gram_matrix(style_feature) for style_feature in style_features] 620 | 621 | # Set initial image 622 | init_image = load_and_process_img(content_path) 623 | init_image = tfe.Variable(init_image, dtype=tf.float32) 624 | # We use Adam Optimizer 625 | opt = tf.train.AdamOptimizer(learning_rate=5, beta1=0.99, epsilon=1e-1) 626 | 627 | # Store our best result 628 | best_loss, best_img = float('inf'), None 629 | 630 | # Create config 631 | loss_weights = (style_weight, content_weight) 632 | cfg = { 633 | 'model': model, 634 | 'loss_weights': loss_weights, 635 | 'init_image': init_image, 636 | 'gram_style_features': gram_style_features, 637 | 'content_features': content_features 638 | } 639 | 640 | norm_means = np.array([103.939, 116.779, 123.68]) 641 | min_vals = -norm_means 642 | max_vals = 255 - norm_means 643 | 644 | # Main loop 645 | for i in range(num_iterations): 646 | global count 647 | count=i 648 | self.calc.start() 649 | print(i) 650 | grads, all_loss = compute_grads(cfg) 651 | loss, style_score, content_score = all_loss 652 | opt.apply_gradients([(grads, init_image)]) 653 | clipped = tf.clip_by_value(init_image, min_vals, max_vals) 654 | init_image.assign(clipped) 655 | 656 | if loss < best_loss: 657 | # Update best loss and best image from total loss. 658 | best_loss = loss 659 | best_img = deprocess_img(init_image.numpy()) 660 | 661 | return best_img, best_loss 662 | 663 | best, best_loss = run_style_transfer(content_path, style_path, num_iterations=iter) 664 | im = Image.fromarray(best) 665 | return im 666 | 667 | """External class control the thread running the ProgressBar.""" 668 | class External(QThread): 669 | countChanged = pyqtSignal(int) 670 | 671 | def run(self): 672 | global count 673 | global iter 674 | ii =((count + 1) / iter) * 100 675 | self.countChanged.emit(ii) 676 | 677 | if __name__ == "__main__": 678 | import sys 679 | app = QtWidgets.QApplication(sys.argv) 680 | MainWindow = QtWidgets.QMainWindow() 681 | ui = Ui_MainWindow() 682 | ui.setupUi(MainWindow) 683 | MainWindow.show() 684 | sys.exit(app.exec_()) 685 | 686 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Style Transfer Using Convolutional Neural Network 2 | **implementation of style transfer by using CNN with Tensorflow.** 3 | 4 | The system extract content and style from an image and combined them together in order to get an artistic image by using neural network, code written in python/PyQt5 and worked on pre trained network with tensorflow. 5 | 6 | This is a collage project that based on Leon A. Gatys paper, you can find our full project paper in the following link: 7 | 8 | [Image Style Transfer Using CNN](https://drive.google.com/file/d/17Ll4F1XUl1VXOouRPJZ2c2GksUtDZ9wa/view?usp=sharing) 9 | 10 | For using the application you can or downlowd [artme.exe](https://drive.google.com/file/d/1m13DuCYS6ZbAJFIxCq40FcbEC0IImCvC/view?usp=sharing) and run it on any machine, or run the python code on python3 environment. 11 | 12 | **Example:** 13 |
14 |
15 |
22 |
23 |
27 |
28 |
32 |
33 |
37 |
38 |
47 |
48 |
52 |
53 |