├── .gitattributes ├── LICENSE ├── README.md ├── images └── catalogue.csv ├── lionimage.jpg ├── models ├── 1.pth └── old_models.csv ├── requirements.txt └── serapis.py /.gitattributes: -------------------------------------------------------------------------------- 1 | models/*.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Mozilla Public License Version 2.0 2 | ================================== 3 | 4 | 1. Definitions 5 | -------------- 6 | 7 | 1.1. "Contributor" 8 | means each individual or legal entity that creates, contributes to 9 | the creation of, or owns Covered Software. 10 | 11 | 1.2. "Contributor Version" 12 | means the combination of the Contributions of others (if any) used 13 | by a Contributor and that particular Contributor's Contribution. 14 | 15 | 1.3. "Contribution" 16 | means Covered Software of a particular Contributor. 17 | 18 | 1.4. "Covered Software" 19 | means Source Code Form to which the initial Contributor has attached 20 | the notice in Exhibit A, the Executable Form of such Source Code 21 | Form, and Modifications of such Source Code Form, in each case 22 | including portions thereof. 23 | 24 | 1.5. "Incompatible With Secondary Licenses" 25 | means 26 | 27 | (a) that the initial Contributor has attached the notice described 28 | in Exhibit B to the Covered Software; or 29 | 30 | (b) that the Covered Software was made available under the terms of 31 | version 1.1 or earlier of the License, but not also under the 32 | terms of a Secondary License. 33 | 34 | 1.6. "Executable Form" 35 | means any form of the work other than Source Code Form. 36 | 37 | 1.7. "Larger Work" 38 | means a work that combines Covered Software with other material, in 39 | a separate file or files, that is not Covered Software. 40 | 41 | 1.8. "License" 42 | means this document. 43 | 44 | 1.9. 
"Licensable" 45 | means having the right to grant, to the maximum extent possible, 46 | whether at the time of the initial grant or subsequently, any and 47 | all of the rights conveyed by this License. 48 | 49 | 1.10. "Modifications" 50 | means any of the following: 51 | 52 | (a) any file in Source Code Form that results from an addition to, 53 | deletion from, or modification of the contents of Covered 54 | Software; or 55 | 56 | (b) any new file in Source Code Form that contains any Covered 57 | Software. 58 | 59 | 1.11. "Patent Claims" of a Contributor 60 | means any patent claim(s), including without limitation, method, 61 | process, and apparatus claims, in any patent Licensable by such 62 | Contributor that would be infringed, but for the grant of the 63 | License, by the making, using, selling, offering for sale, having 64 | made, import, or transfer of either its Contributions or its 65 | Contributor Version. 66 | 67 | 1.12. "Secondary License" 68 | means either the GNU General Public License, Version 2.0, the GNU 69 | Lesser General Public License, Version 2.1, the GNU Affero General 70 | Public License, Version 3.0, or any later versions of those 71 | licenses. 72 | 73 | 1.13. "Source Code Form" 74 | means the form of the work preferred for making modifications. 75 | 76 | 1.14. "You" (or "Your") 77 | means an individual or a legal entity exercising rights under this 78 | License. For legal entities, "You" includes any entity that 79 | controls, is controlled by, or is under common control with You. For 80 | purposes of this definition, "control" means (a) the power, direct 81 | or indirect, to cause the direction or management of such entity, 82 | whether by contract or otherwise, or (b) ownership of more than 83 | fifty percent (50%) of the outstanding shares or beneficial 84 | ownership of such entity. 85 | 86 | 2. License Grants and Conditions 87 | -------------------------------- 88 | 89 | 2.1. 
Grants 90 | 91 | Each Contributor hereby grants You a world-wide, royalty-free, 92 | non-exclusive license: 93 | 94 | (a) under intellectual property rights (other than patent or trademark) 95 | Licensable by such Contributor to use, reproduce, make available, 96 | modify, display, perform, distribute, and otherwise exploit its 97 | Contributions, either on an unmodified basis, with Modifications, or 98 | as part of a Larger Work; and 99 | 100 | (b) under Patent Claims of such Contributor to make, use, sell, offer 101 | for sale, have made, import, and otherwise transfer either its 102 | Contributions or its Contributor Version. 103 | 104 | 2.2. Effective Date 105 | 106 | The licenses granted in Section 2.1 with respect to any Contribution 107 | become effective for each Contribution on the date the Contributor first 108 | distributes such Contribution. 109 | 110 | 2.3. Limitations on Grant Scope 111 | 112 | The licenses granted in this Section 2 are the only rights granted under 113 | this License. No additional rights or licenses will be implied from the 114 | distribution or licensing of Covered Software under this License. 115 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 116 | Contributor: 117 | 118 | (a) for any code that a Contributor has removed from Covered Software; 119 | or 120 | 121 | (b) for infringements caused by: (i) Your and any other third party's 122 | modifications of Covered Software, or (ii) the combination of its 123 | Contributions with other software (except as part of its Contributor 124 | Version); or 125 | 126 | (c) under Patent Claims infringed by Covered Software in the absence of 127 | its Contributions. 128 | 129 | This License does not grant any rights in the trademarks, service marks, 130 | or logos of any Contributor (except as may be necessary to comply with 131 | the notice requirements in Section 3.4). 132 | 133 | 2.4. 
Subsequent Licenses 134 | 135 | No Contributor makes additional grants as a result of Your choice to 136 | distribute the Covered Software under a subsequent version of this 137 | License (see Section 10.2) or under the terms of a Secondary License (if 138 | permitted under the terms of Section 3.3). 139 | 140 | 2.5. Representation 141 | 142 | Each Contributor represents that the Contributor believes its 143 | Contributions are its original creation(s) or it has sufficient rights 144 | to grant the rights to its Contributions conveyed by this License. 145 | 146 | 2.6. Fair Use 147 | 148 | This License is not intended to limit any rights You have under 149 | applicable copyright doctrines of fair use, fair dealing, or other 150 | equivalents. 151 | 152 | 2.7. Conditions 153 | 154 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted 155 | in Section 2.1. 156 | 157 | 3. Responsibilities 158 | ------------------- 159 | 160 | 3.1. Distribution of Source Form 161 | 162 | All distribution of Covered Software in Source Code Form, including any 163 | Modifications that You create or to which You contribute, must be under 164 | the terms of this License. You must inform recipients that the Source 165 | Code Form of the Covered Software is governed by the terms of this 166 | License, and how they can obtain a copy of this License. You may not 167 | attempt to alter or restrict the recipients' rights in the Source Code 168 | Form. 169 | 170 | 3.2. 
Distribution of Executable Form 171 | 172 | If You distribute Covered Software in Executable Form then: 173 | 174 | (a) such Covered Software must also be made available in Source Code 175 | Form, as described in Section 3.1, and You must inform recipients of 176 | the Executable Form how they can obtain a copy of such Source Code 177 | Form by reasonable means in a timely manner, at a charge no more 178 | than the cost of distribution to the recipient; and 179 | 180 | (b) You may distribute such Executable Form under the terms of this 181 | License, or sublicense it under different terms, provided that the 182 | license for the Executable Form does not attempt to limit or alter 183 | the recipients' rights in the Source Code Form under this License. 184 | 185 | 3.3. Distribution of a Larger Work 186 | 187 | You may create and distribute a Larger Work under terms of Your choice, 188 | provided that You also comply with the requirements of this License for 189 | the Covered Software. If the Larger Work is a combination of Covered 190 | Software with a work governed by one or more Secondary Licenses, and the 191 | Covered Software is not Incompatible With Secondary Licenses, this 192 | License permits You to additionally distribute such Covered Software 193 | under the terms of such Secondary License(s), so that the recipient of 194 | the Larger Work may, at their option, further distribute the Covered 195 | Software under the terms of either this License or such Secondary 196 | License(s). 197 | 198 | 3.4. Notices 199 | 200 | You may not remove or alter the substance of any license notices 201 | (including copyright notices, patent notices, disclaimers of warranty, 202 | or limitations of liability) contained within the Source Code Form of 203 | the Covered Software, except that You may alter any license notices to 204 | the extent required to remedy known factual inaccuracies. 205 | 206 | 3.5. 
Application of Additional Terms 207 | 208 | You may choose to offer, and to charge a fee for, warranty, support, 209 | indemnity or liability obligations to one or more recipients of Covered 210 | Software. However, You may do so only on Your own behalf, and not on 211 | behalf of any Contributor. You must make it absolutely clear that any 212 | such warranty, support, indemnity, or liability obligation is offered by 213 | You alone, and You hereby agree to indemnify every Contributor for any 214 | liability incurred by such Contributor as a result of warranty, support, 215 | indemnity or liability terms You offer. You may include additional 216 | disclaimers of warranty and limitations of liability specific to any 217 | jurisdiction. 218 | 219 | 4. Inability to Comply Due to Statute or Regulation 220 | --------------------------------------------------- 221 | 222 | If it is impossible for You to comply with any of the terms of this 223 | License with respect to some or all of the Covered Software due to 224 | statute, judicial order, or regulation then You must: (a) comply with 225 | the terms of this License to the maximum extent possible; and (b) 226 | describe the limitations and the code they affect. Such description must 227 | be placed in a text file included with all distributions of the Covered 228 | Software under this License. Except to the extent prohibited by statute 229 | or regulation, such description must be sufficiently detailed for a 230 | recipient of ordinary skill to be able to understand it. 231 | 232 | 5. Termination 233 | -------------- 234 | 235 | 5.1. The rights granted under this License will terminate automatically 236 | if You fail to comply with any of its terms. 
However, if You become 237 | compliant, then the rights granted under this License from a particular 238 | Contributor are reinstated (a) provisionally, unless and until such 239 | Contributor explicitly and finally terminates Your grants, and (b) on an 240 | ongoing basis, if such Contributor fails to notify You of the 241 | non-compliance by some reasonable means prior to 60 days after You have 242 | come back into compliance. Moreover, Your grants from a particular 243 | Contributor are reinstated on an ongoing basis if such Contributor 244 | notifies You of the non-compliance by some reasonable means, this is the 245 | first time You have received notice of non-compliance with this License 246 | from such Contributor, and You become compliant prior to 30 days after 247 | Your receipt of the notice. 248 | 249 | 5.2. If You initiate litigation against any entity by asserting a patent 250 | infringement claim (excluding declaratory judgment actions, 251 | counter-claims, and cross-claims) alleging that a Contributor Version 252 | directly or indirectly infringes any patent, then the rights granted to 253 | You by any and all Contributors for the Covered Software under Section 254 | 2.1 of this License shall terminate. 255 | 256 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all 257 | end user license agreements (excluding distributors and resellers) which 258 | have been validly granted by You or Your distributors under this License 259 | prior to termination shall survive termination. 260 | 261 | ************************************************************************ 262 | * * 263 | * 6. 
Disclaimer of Warranty * 264 | * ------------------------- * 265 | * * 266 | * Covered Software is provided under this License on an "as is" * 267 | * basis, without warranty of any kind, either expressed, implied, or * 268 | * statutory, including, without limitation, warranties that the * 269 | * Covered Software is free of defects, merchantable, fit for a * 270 | * particular purpose or non-infringing. The entire risk as to the * 271 | * quality and performance of the Covered Software is with You. * 272 | * Should any Covered Software prove defective in any respect, You * 273 | * (not any Contributor) assume the cost of any necessary servicing, * 274 | * repair, or correction. This disclaimer of warranty constitutes an * 275 | * essential part of this License. No use of any Covered Software is * 276 | * authorized under this License except under this disclaimer. * 277 | * * 278 | ************************************************************************ 279 | 280 | ************************************************************************ 281 | * * 282 | * 7. Limitation of Liability * 283 | * -------------------------- * 284 | * * 285 | * Under no circumstances and under no legal theory, whether tort * 286 | * (including negligence), contract, or otherwise, shall any * 287 | * Contributor, or anyone who distributes Covered Software as * 288 | * permitted above, be liable to You for any direct, indirect, * 289 | * special, incidental, or consequential damages of any character * 290 | * including, without limitation, damages for lost profits, loss of * 291 | * goodwill, work stoppage, computer failure or malfunction, or any * 292 | * and all other commercial damages or losses, even if such party * 293 | * shall have been informed of the possibility of such damages. 
This * 294 | * limitation of liability shall not apply to liability for death or * 295 | * personal injury resulting from such party's negligence to the * 296 | * extent applicable law prohibits such limitation. Some * 297 | * jurisdictions do not allow the exclusion or limitation of * 298 | * incidental or consequential damages, so this exclusion and * 299 | * limitation may not apply to You. * 300 | * * 301 | ************************************************************************ 302 | 303 | 8. Litigation 304 | ------------- 305 | 306 | Any litigation relating to this License may be brought only in the 307 | courts of a jurisdiction where the defendant maintains its principal 308 | place of business and such litigation shall be governed by laws of that 309 | jurisdiction, without reference to its conflict-of-law provisions. 310 | Nothing in this Section shall prevent a party's ability to bring 311 | cross-claims or counter-claims. 312 | 313 | 9. Miscellaneous 314 | ---------------- 315 | 316 | This License represents the complete agreement concerning the subject 317 | matter hereof. If any provision of this License is held to be 318 | unenforceable, such provision shall be reformed only to the extent 319 | necessary to make it enforceable. Any law or regulation which provides 320 | that the language of a contract shall be construed against the drafter 321 | shall not be used to construe this License against a Contributor. 322 | 323 | 10. Versions of the License 324 | --------------------------- 325 | 326 | 10.1. New Versions 327 | 328 | Mozilla Foundation is the license steward. Except as provided in Section 329 | 10.3, no one other than the license steward has the right to modify or 330 | publish new versions of this License. Each version will be given a 331 | distinguishing version number. 332 | 333 | 10.2. 
Effect of New Versions 334 | 335 | You may distribute the Covered Software under the terms of the version 336 | of the License under which You originally received the Covered Software, 337 | or under the terms of any subsequent version published by the license 338 | steward. 339 | 340 | 10.3. Modified Versions 341 | 342 | If you create software not governed by this License, and you want to 343 | create a new license for such software, you may create and use a 344 | modified version of this License if you rename the license and remove 345 | any references to the name of the license steward (except to note that 346 | such modified license differs from this License). 347 | 348 | 10.4. Distributing Source Code Form that is Incompatible With Secondary 349 | Licenses 350 | 351 | If You choose to distribute Source Code Form that is Incompatible With 352 | Secondary Licenses under the terms of this version of the License, the 353 | notice described in Exhibit B of this License must be attached. 354 | 355 | Exhibit A - Source Code Form License Notice 356 | ------------------------------------------- 357 | 358 | This Source Code Form is subject to the terms of the Mozilla Public 359 | License, v. 2.0. If a copy of the MPL was not distributed with this 360 | file, You can obtain one at http://mozilla.org/MPL/2.0/. 361 | 362 | If it is not possible or desirable to put the notice in a particular 363 | file, then You may include the notice in a location (such as a LICENSE 364 | file in a relevant directory) where a recipient would be likely to look 365 | for such a notice. 366 | 367 | You may add additional accurate notices of copyright ownership. 368 | 369 | Exhibit B - "Incompatible With Secondary Licenses" Notice 370 | --------------------------------------------------------- 371 | 372 | This Source Code Form is "Incompatible With Secondary Licenses", as 373 | defined by the Mozilla Public License, v. 2.0. 
374 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Serapis AI Image Classifier

2 | 3 |

4 | The Staff of Serapis 5 |

6 | 7 |

8 | Serapis AI Image Classifier is a program that allows you to automatically create image datasets using SerpApi's Google Images Scraper API, finetune a ResNet50 model, and classify images using the trained model. 9 |
10 |

11 | 12 | 13 | --- 14 | 15 |

Installation

16 | 17 | You can install the dependencies listed in `requirements.txt` using the following command: 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | --- 23 | 24 |

Usage

25 | 26 | You can use Serapis AI Image Classifier in one of the following three modes: 27 | 28 | --- 29 | 30 |

Create a dataset and train a new model from scratch

31 | 32 | To create a dataset and train a new model from scratch, you will need to provide a list of labels and the path to an image to classify once the model is trained. 33 | A [SerpApi API Key](https://serpapi.com/manage-api-key) is required in this mode so that the program can automatically scrape the images for your dataset using [SerpApi's Google Images Scraper API](https://serpapi.com/images-results). 34 | 35 | You can [sign up for SerpApi to claim free credits](https://serpapi.com/users/sign_up). 36 | 37 | ```bash 38 | python serapis.py --train --labels eagle, bull, lion, man --image-path lionimage.jpg --api-key 39 | ``` 40 | 41 | --- 42 | 43 |
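The labels in the command above are separated by a comma and a space, so the shell hands them to the script as separate tokens with trailing commas. `serapis.py` strips those commas before turning each label into a search query. The snippet below is a minimal sketch of that clean-up; the helper name `normalize_labels` is hypothetical and does not exist in the script.

```python
def normalize_labels(raw_labels):
    """Strip commas and surrounding whitespace from CLI tokens, dropping empties."""
    cleaned = [label.replace(",", "").strip() for label in raw_labels]
    return [label for label in cleaned if label != ""]


# The shell splits `--labels eagle, bull, lion, man` into these tokens.
tokens = ["eagle,", "bull,", "lion,", "man"]
labels = normalize_labels(tokens)

# One search target per label, in the same shape the script builds.
targets = [{"q": label} for label in labels]
print(targets)
```

Because commas are removed entirely, a label cannot itself contain a comma when passed this way.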

Use old scraped images and train a new model

44 | 45 | To use old scraped images and train a new model, you will need to provide a list of labels and specify that you want to use old images with the `--use-old-images` flag. 46 | You can also put your own images in the `images/` folder and list them in `images/catalogue.csv` to train models manually on your own dataset. 47 | 48 | ```bash 49 | python serapis.py --train --labels eagle, bull, lion, man --use-old-images --image-path lionimage.jpg 50 | ``` 51 | 52 | --- 53 | 54 |
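If you add your own images, each one needs a row in `images/catalogue.csv`, which uses the header `label,image_path`. The snippet below is a rough sketch of appending rows with Python's standard `csv` module; the helper `add_to_catalogue` is hypothetical, and the demo writes to a temporary file rather than the real catalogue.

```python
import csv
import os
import tempfile


def add_to_catalogue(catalogue_path, rows):
    """Append (label, image_path) rows, writing the header if the file is new or empty."""
    needs_header = (
        not os.path.exists(catalogue_path) or os.path.getsize(catalogue_path) == 0
    )
    with open(catalogue_path, "a", newline="") as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(["label", "image_path"])
        writer.writerows(rows)


# Demo against a temporary file instead of images/catalogue.csv.
demo_dir = tempfile.mkdtemp()
demo_catalogue = os.path.join(demo_dir, "catalogue.csv")
add_to_catalogue(demo_catalogue, [("lion", "images/0.png"), ("eagle", "images/1.png")])

with open(demo_catalogue, newline="") as f:
    catalogue_rows = list(csv.reader(f))
print(catalogue_rows)
```

The script keeps only catalogue rows whose label exactly matches one of the `--labels` values, so the spelling and case in the CSV must match what you pass on the command line.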

Use a previously trained model

55 | 56 | To use a previously trained model, you will need to provide the path to the trained model and an image to classify. 57 | 58 | ```bash 59 | python serapis.py --model-path models/1.pth --image-path lionimage.jpg 60 | ``` 61 | 62 | --- 63 | 64 |
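A saved model is only usable together with its label list, which is recorded in `models/old_models.csv` as a single string joined with `--`. The position of each label in that string is its class index. The snippet below is a small sketch of reading that mapping back with the standard library; the helper `labels_for_model` is hypothetical, and the CSV text mirrors the file shipped in the repository.

```python
import csv
import io

# Contents mirror models/old_models.csv as shipped in the repository.
registry_text = "model_path,target_labels\nmodels/1.pth,Eagle--Bull--Lion--Human\n"


def labels_for_model(registry_file, model_path):
    """Return the ordered label list recorded for model_path, or None if absent."""
    for row in csv.DictReader(registry_file):
        if row["model_path"] == model_path:
            return row["target_labels"].split("--")
    return None


labels = labels_for_model(io.StringIO(registry_text), "models/1.pth")
idx_to_class = dict(enumerate(labels))  # class index -> label name
print(idx_to_class)
```

If a model path is missing from the CSV, the labels cannot be recovered from the `.pth` file alone, since it stores only the weights.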

Dialogue Mode

65 | 66 | You can also navigate the program by not providing any arguments and using the dialogue mode: 67 | 68 | ```bash 69 | python serapis.py 70 | ``` 71 | 72 | --- 73 | 74 |

The Output

75 |

76 | Classified Image of a Lion 77 |

78 |
79 | The output will give you the answer:
80 |
81 | ```bash
82 | The image contains Lion
83 | ```
84 |
85 | ---
86 |
87 | Optional Arguments:
88 |
89 | -h, --help Help to Navigate
90 | --train Whether to train a new model
91 | --model-path MODEL_PATH Pretrained Model path you want to use
92 | --dialogue Whether to use dialogue to navigate through the program
93 | --use-old-images Whether to use old images you have downloaded to train a new model
94 | --api-key API_KEY SerpApi API Key
95 | --limit LIMIT Number of images you want to scrape at most for each label
96 | --labels LABELS [LABELS ...] Labels you want to use to train a new model
97 | --image-path IMAGE_PATH Path to the image you want to classify
98 |
--------------------------------------------------------------------------------
/images/catalogue.csv:
--------------------------------------------------------------------------------
1 | label,image_path
2 |
--------------------------------------------------------------------------------
/lionimage.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serpapi/serapis-ai-image-classifier/0c9ca68fa7070e98d2994b856bc915c59a06830e/lionimage.jpg
--------------------------------------------------------------------------------
/models/1.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:70a8d1b09e6d56878ffb05d373b12ee3c69bb30f8a3af59a4f8cb1e4daeba444
3 | size 94404289
4 |
--------------------------------------------------------------------------------
/models/old_models.csv:
--------------------------------------------------------------------------------
1 | model_path,target_labels
2 | models/1.pth,Eagle--Bull--Lion--Human
3 |
--------------------------------------------------------------------------------
/requirements.txt:
-------------------------------------------------------------------------------- 1 | pandas==1.4.4 2 | Pillow==9.4.0 3 | requests==2.28.1 4 | torch==1.13.0 5 | torchvision==0.13.1a0 6 | transformers==4.24.0 7 | -------------------------------------------------------------------------------- /serapis.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore", category=UserWarning, module='jax') 3 | warnings.filterwarnings("ignore", module='PIL') 4 | 5 | from transformers import ResNetForImageClassification 6 | from torch.utils.data import DataLoader, Dataset 7 | from torchvision import transforms 8 | from transformers import logging 9 | import torch.nn.functional as F 10 | from PIL import Image 11 | import pandas as pd 12 | import urllib.parse 13 | import argparse 14 | import requests 15 | import asyncio 16 | import torch 17 | import sys 18 | import os 19 | import io 20 | 21 | async def call_serpapi(target, api_key, ijn=0): 22 | q = urllib.parse.quote(target["q"]) 23 | url = "https://serpapi.com/search.json?api_key={}&tbm=isch&q={}&ijn={}".format(api_key, q, ijn) 24 | if "chips" in target: 25 | chips = None 26 | 27 | # First call to retrieve chips from the search 28 | response = requests.get(url) 29 | results = response.json() 30 | if "error" in results: 31 | print(results["error"]) 32 | sys.exit(1) 33 | 34 | if "suggested_searches" in results: 35 | for suggestion in results["suggested_searches"]: 36 | if "chips" in suggestion and suggestion["name"] == target["chips"]: 37 | chips = urllib.parse.quote(suggestion["chips"]) 38 | 39 | if chips is not None: 40 | # Second call to make a chips search 41 | url = "https://serpapi.com/search.json?api_key={}&tbm=isch&q={}&chips={}&ijn={}".format(api_key, q, chips, ijn) 42 | response = requests.get(url) 43 | else: 44 | # Call without chips 45 | response = requests.get(url) 46 | return response.json() 47 | 48 | async def search(targets, api_key): 49
| images_results = [] 50 | for target in targets: 51 | searches = [] 52 | # First page results 53 | searches.append(call_serpapi(target, api_key, ijn=0)) 54 | if "page" in target and target["page"] > 1: 55 | # If more than one page is requested 56 | [searches.append(call_serpapi(target, api_key, ijn=i)) for i in range(target["page"]) if i != 0] 57 | images = await asyncio.gather(*searches) 58 | images_results.append([target["q"], images]) 59 | return images_results 60 | 61 | async def get_single_image_url(label, url, df): 62 | try: 63 | response = requests.get(url, timeout=2) 64 | f = io.BytesIO(response.content) 65 | 66 | last_item = len(os.listdir("images")) 67 | image_path = "images/{}.png".format(last_item) 68 | 69 | im = Image.open(f) 70 | im.save(image_path, format='PNG', quality=95) 71 | 72 | df.loc[0] = [label, image_path] 73 | print("Downloaded {}".format(url)) 74 | return df 75 | except: 76 | return df 77 | 78 | async def get_images(images_results, limit = None): 79 | all_dfs = [] 80 | for images in images_results: 81 | calls = [] 82 | if len(images) == 2: 83 | for page_result in images[1]: 84 | if "images_results" in page_result: 85 | for i in range(len(page_result["images_results"])): 86 | result = page_result["images_results"][i] 87 | if limit != None and i == limit: 88 | break 89 | 90 | if "original" in result: 91 | df = pd.DataFrame(columns=["label", "image_path"]) 92 | url = result["original"] 93 | calls.append(get_single_image_url(images[0], url, df)) 94 | print("Added Coroutine: {} - {}".format(images[0], url, df)) 95 | dfs = await asyncio.gather(*calls) 96 | all_dfs = all_dfs + dfs 97 | df = pd.concat(all_dfs) 98 | print("---") 99 | return df 100 | 101 | 102 | class CustomDataset(Dataset): 103 | def __init__(self, df): 104 | self.df = df 105 | 106 | def __len__(self): 107 | return len(self.df) 108 | 109 | def preprocess_image(self, image_path): 110 | # Load image using PIL 111 | image = Image.open(image_path) 112 | # Convert image to RGB if it is 
not already in that format 113 | if image.mode != 'RGB': 114 | image = image.convert('RGB') 115 | # Resize image to a fixed size 116 | image = image.resize((224, 224)) 117 | # Convert image to a tensor 118 | image = transforms.ToTensor()(image) 119 | # Normalize image 120 | image = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) 121 | return image 122 | 123 | def __getitem__(self, index): 124 | row = self.df.iloc[index] 125 | label = row['label'] 126 | image_path = row['image_path'] 127 | # Load and pre-process image data here 128 | image_data = self.preprocess_image(image_path) 129 | return image_data, label 130 | 131 | def train_and_save_model(df, target_labels): 132 | print("Training a new model. If you get a warning about shapes below, you can ignore it.") 133 | dataset = CustomDataset(df) 134 | dataloader = DataLoader(dataset, batch_size=32, shuffle=True) 135 | class_to_idx = {class_name: idx for idx, class_name in enumerate(target_labels)} 136 | idx_to_class = {idx: class_name for idx, class_name in enumerate(target_labels)} 137 | model = ResNetForImageClassification.from_pretrained( 138 | "microsoft/resnet-50", 139 | num_labels=len(target_labels), 140 | id2label=idx_to_class, 141 | label2id=class_to_idx, 142 | ignore_mismatched_sizes=True 143 | ) 144 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 145 | total_examples = len(dataset) 146 | num_epochs = 2 147 | processed_examples = 0 148 | 149 | print("---") 150 | for epoch in range(num_epochs): 151 | for data in dataloader: 152 | images, labels = data 153 | model.train() 154 | model.zero_grad() 155 | 156 | logits = model(images).logits 157 | labels = [class_to_idx[label] for label in labels if label in class_to_idx] 158 | labels = torch.tensor(labels) 159 | 160 | loss = F.cross_entropy(logits, labels) 161 | loss.backward() 162 | optimizer.step() 163 | 164 | model.eval() 165 | logits = model(images, labels = labels).logits 166 | _, predicted_labels = 
logits.max(dim=1) 167 | accuracy = (predicted_labels == labels).float().mean() 168 | 169 | batch_size = images.size(0) 170 | processed_examples += batch_size 171 | progress = (processed_examples / (total_examples * num_epochs)) * 100 172 | print(f'Accuracy: {accuracy} | Progress: {progress:.2f}%') 173 | 174 | last_item = len(os.listdir("models")) 175 | model_path = "models/{}.pth".format(last_item) 176 | torch.save(model.state_dict(), model_path) 177 | df = pd.DataFrame(columns=["model_path", "target_labels"]) 178 | df.loc[0] = [model_path, "--".join(target_labels)] 179 | old_df = pd.read_csv("models/old_models.csv") 180 | new_df = pd.concat([df, old_df], ignore_index=True) 181 | new_df.to_csv("models/old_models.csv", index=False) 182 | 183 | return model 184 | 185 | def train_a_new_model(targets=None, use_catalogue=True, api_key=None, limit=None): 186 | target_labels = [dictionary["q"] for dictionary in targets] 187 | if use_catalogue: 188 | # Call the Catalogue CSV and get only the keys you targeted. 189 | df = pd.read_csv("images/catalogue.csv") 190 | all_labels = list(set(df["label"])) 191 | for label in all_labels: 192 | if label not in target_labels: 193 | df = df.drop(df[df["label"] == label].index) 194 | else: 195 | # Save the new images to Catalogue CSV but use only the ones you targeted. 
196 | images_results = asyncio.run(search(targets, api_key)) 197 | df = asyncio.run(get_images(images_results, limit = limit)) 198 | old_df = pd.read_csv("images/catalogue.csv") 199 | new_df = pd.concat([df, old_df], ignore_index=True) 200 | new_df.to_csv("images/catalogue.csv", index=False) 201 | 202 | model = train_and_save_model(df, target_labels) 203 | return model, target_labels 204 | 205 | def use_old_model(model_path=None): 206 | model_df = pd.read_csv("models/old_models.csv") 207 | target_labels = model_df.loc[model_df[model_df["model_path"] == model_path].index]["target_labels"].iloc[0] 208 | target_labels = target_labels.split("--") 209 | class_to_idx = {class_name: idx for idx, class_name in enumerate(target_labels)} 210 | idx_to_class = {idx: class_name for idx, class_name in enumerate(target_labels)} 211 | model = ResNetForImageClassification.from_pretrained( 212 | "microsoft/resnet-50", 213 | num_labels=len(target_labels), 214 | id2label=idx_to_class, 215 | label2id=class_to_idx, 216 | ignore_mismatched_sizes=True 217 | ) 218 | state_dict = torch.load(model_path) 219 | model.load_state_dict(state_dict) 220 | return model, target_labels 221 | 222 | def predict_image(image_path=None, model=None, target_labels=None): 223 | df = pd.DataFrame(columns=["label", "image_path"]) 224 | dataloader = CustomDataset(df) 225 | image = dataloader.preprocess_image(image_path) 226 | model.eval() 227 | logits = model(image.unsqueeze(0)).logits 228 | _, predicted_labels = logits.max(dim=1) 229 | predicted_class_names = [target_labels[int(label)] for label in predicted_labels] 230 | return predicted_class_names[0] 231 | 232 | def questions(): 233 | api_key = "" 234 | model_path = "" 235 | train_new_model = False 236 | use_catalogue = True 237 | limit = None 238 | targets = [] 239 | while True: 240 | print("What do you want to do?") 241 | print("1. Train a new model.") 242 | print("2. 
Use an old model.") 243 | choice = input("Enter your choice: ") 244 | if choice == "1" or choice == "2": 245 | break 246 | else: 247 | print("Please enter a valid choice.") 248 | print("---") 249 | 250 | if choice == "1": 251 | train_new_model = True 252 | while True: 253 | print("Would you like to use the old images you have stored, or scrape fresh images?") 254 | print("1. Use old images.") 255 | print("2. Scrape new images.") 256 | second_choice = input("Enter your choice: ") 257 | if second_choice == "1" or second_choice == "2": 258 | break 259 | else: 260 | print("Please enter a valid choice.") 261 | print("---") 262 | 263 | if second_choice == "1": 264 | use_catalogue = True 265 | old_df = pd.read_csv("images/catalogue.csv") 266 | all_labels = list(set(old_df["label"])) 267 | while True: 268 | print("Here are the labels of images you have stored.") 269 | print("{}".format(", ".join(all_labels))) 270 | print("Which labels do you want to use?") 271 | desired_labels = input("Enter the labels (case sensitive) separated by a comma: ") 272 | desired_labels = desired_labels.split(",") 273 | desired_labels = [label.strip() for label in desired_labels if label.strip() != ""] 274 | if len(desired_labels) < 2: 275 | print("Please enter at least two labels.") 276 | elif not all(elem in all_labels for elem in desired_labels): 277 | print("Please enter valid labels that are already in the labels of images you have stored.") 278 | else: 279 | break 280 | print("---") 281 | elif second_choice == "2": 282 | use_catalogue = False 283 | while True: 284 | print("Which labels do you want to use?") 285 | desired_labels = input("Enter the labels (case sensitive) separated by a comma: ") 286 | desired_labels = desired_labels.split(",") 287 | desired_labels = [label.strip() for label in desired_labels if label.strip() != ""] 288 | if len(desired_labels) < 2: 289 | print("Please enter at least two labels.") 290 | else: 291 | break 292 | 
            print("---")

            while True:
                print("How many images do you want to scrape at most for each label?")
                limit = input("Enter the limit (Enter nothing to pass): ")
                if limit == "":
                    # No limit requested: return None instead of an empty string.
                    limit = None
                    break
                elif limit.isdigit():
                    limit = int(limit)
                    break
                else:
                    print("Please enter a valid integer.")
            print("---")

            while True:
                api_key = input("Enter your SerpApi API key: ")
                if api_key == "":
                    print("Please enter a valid API key.")
                else:
                    break
            print("---")
        targets = [{"q": label} for label in desired_labels]
    elif choice == "2":
        while True:
            print("Which model do you want to use?")
            model_df = pd.read_csv("models/old_models.csv")
            print(model_df)
            model_path = input("Enter the model path: ")
            if model_path not in list(model_df["model_path"]):
                print("Please enter a valid model path.")
            elif not os.path.isfile(model_path):
                print("The model is listed in the CSV, but there is no file at that path. Please enter a valid model path.")
            else:
                break
        print("---")

    while True:
        image_path = input("Enter the image path you want to predict: ")
        if not os.path.isfile(image_path):
            print("Please enter a valid image path.")
        else:
            break

    print("---")
    return targets, train_new_model, use_catalogue, api_key, model_path, limit, image_path

logging.set_verbosity_error()
parser = argparse.ArgumentParser()

# Mode
parser.add_argument('--train', action='store_true', help='Whether to train a new model')
parser.add_argument('--model-path', type=str, help='Pretrained Model path you want to use')
# Note: with action='store_false', args.dialogue defaults to True and passing --dialogue sets it to False.
parser.add_argument('--dialogue', action='store_false', help='Whether to use dialogue to navigate through the program')

# Training
parser.add_argument('--use-old-images', action='store_true', help='Whether to use old images you have downloaded to train a new model')
parser.add_argument('--api-key', type=str, help='SerpApi API Key')
parser.add_argument('--limit', type=int, help='Number of images you want to scrape at most for each label')
parser.add_argument('--labels', type=str, nargs='+', help='Labels you want to use to train a new model')

# Prediction
parser.add_argument('--image-path', type=str, help='Path to the image you want to classify')

args = parser.parse_args()

# New Training with Image Scraping
# Excluding --use-old-images here keeps the catalogue branch further down reachable.
if args.train and args.labels and args.image_path and not args.use_old_images and not args.model_path:
    limit = args.limit if args.limit else None

    if not args.api_key:
        print("You need to enter your SerpApi API key to scrape new images.")
        sys.exit(1)

    labels = [label.replace(",", "").strip() for label in args.labels if label.replace(",", "").strip() != ""]
    targets = [{"q": label} for label in labels]

    train_new_model = True
    use_catalogue = False
    api_key = args.api_key
    image_path = args.image_path

    model, target_labels = train_a_new_model(targets, use_catalogue, api_key, limit)
    print("---")
    print("The image contains {}".format(predict_image(image_path, model, target_labels)))

# New Training with Old Scraped Images
elif args.train and args.labels and args.use_old_images and args.image_path and not args.model_path:
    labels = [label.replace(",", "").strip() for label in args.labels if label.replace(",", "").strip() != ""]
    targets = [{"q": label} for label in labels]
    train_new_model = True
    use_catalogue = True
    api_key = None
    limit = None
    image_path = args.image_path

    # Every requested label must already exist in the catalogue.
    old_df = pd.read_csv("images/catalogue.csv")
    all_labels = list(set(old_df["label"]))
    if len(labels) < 2:
        print("Please enter at least two labels.")
        sys.exit(1)
    elif not all(elem in all_labels for elem in labels):
        print("Please enter labels that are in the catalogue.")
        sys.exit(1)

    model, target_labels = train_a_new_model(targets, use_catalogue, api_key, limit)
    print("---")
    print("The image contains {}".format(predict_image(image_path, model, target_labels)))

# Old Model Prediction
elif args.model_path and args.image_path and not args.train:
    model_path = args.model_path
    image_path = args.image_path

    if not os.path.isfile(model_path):
        print("Please enter a valid model path.")
        sys.exit(1)

    model, target_labels = use_old_model(model_path)
    print("---")
    print("The image contains {}".format(predict_image(image_path, model, target_labels)))

# Actions from Dialogue
elif args.dialogue and not args.model_path and not args.image_path and not args.train:
    targets, train_new_model, use_catalogue, api_key, model_path, limit, image_path = questions()

    if train_new_model:
        model, target_labels = train_a_new_model(targets, use_catalogue, api_key, limit)
    else:
        model, target_labels = use_old_model(model_path)

    print("---")
    print("The image contains {}".format(predict_image(image_path, model, target_labels)))
else:
    print("Please enter the correct arguments.")
    sys.exit(1)

# Tips for Advanced Usage
# Below is an example of how to target specific images.
#
# targets = [
#     {
#         "q": "Elephant",
#         "page": 10,
#         "chips": "male"
#     }
# ]
#
# `q` is the query sent to SerpApi's Google Images
# Scraper API.
#
# `page` is the number of pages you want to scrape.
# Each page has 100 images. Not all of them are
# usable for training, but there will still be
# enough images to fine-tune ResNet50.
#
# `chips` narrows the query. Chips are the
# refinement labels shown at the top of the results
# page, just below the search bar. For example, to
# target male elephants only, use the chips value
# from the example above. The script makes two
# calls to build a chips search.
#
# You can tweak the code to pass manual targets
# into the program.
--------------------------------------------------------------------------------
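The tips above describe target dictionaries with `q`, `page`, and `chips` keys. As a minimal sketch of how such a list might be built from plain labels, the snippet below uses a hypothetical helper, `build_targets`, which is not part of serapis.py. It assumes that `page` applies to every label and that `chips` is supplied as a per-label mapping; both are illustrative choices, not something the script defines.

```python
from typing import Dict, List, Optional, Union

Target = Dict[str, Union[str, int]]


def build_targets(
    labels: List[str],
    page: Optional[int] = None,
    chips: Optional[Dict[str, str]] = None,
) -> List[Target]:
    """Hypothetical helper: turn labels into target dicts with optional keys."""
    chips = chips or {}
    targets: List[Target] = []
    for label in labels:
        # Every target needs at least the query itself.
        target: Target = {"q": label}
        # Only add the optional keys when a value was actually given.
        if page is not None:
            target["page"] = page
        if label in chips:
            target["chips"] = chips[label]
        targets.append(target)
    return targets


if __name__ == "__main__":
    # Scrape 10 pages per label and restrict "Elephant" to the "male" chip.
    print(build_targets(["Elephant", "Lion"], page=10, chips={"Elephant": "male"}))
```

A list built this way has the same shape as the manual `targets` example in the tips, so it could be passed wherever the script currently builds `[{"q": label} for label in labels]`, provided the scraping code reads the optional keys.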