├── .gitattributes
├── LICENSE
├── README.md
├── images
│   └── catalogue.csv
├── lionimage.jpg
├── models
│   ├── 1.pth
│   └── old_models.csv
├── requirements.txt
└── serapis.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | models/*.pth filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Mozilla Public License Version 2.0
2 | ==================================
3 |
4 | 1. Definitions
5 | --------------
6 |
7 | 1.1. "Contributor"
8 | means each individual or legal entity that creates, contributes to
9 | the creation of, or owns Covered Software.
10 |
11 | 1.2. "Contributor Version"
12 | means the combination of the Contributions of others (if any) used
13 | by a Contributor and that particular Contributor's Contribution.
14 |
15 | 1.3. "Contribution"
16 | means Covered Software of a particular Contributor.
17 |
18 | 1.4. "Covered Software"
19 | means Source Code Form to which the initial Contributor has attached
20 | the notice in Exhibit A, the Executable Form of such Source Code
21 | Form, and Modifications of such Source Code Form, in each case
22 | including portions thereof.
23 |
24 | 1.5. "Incompatible With Secondary Licenses"
25 | means
26 |
27 | (a) that the initial Contributor has attached the notice described
28 | in Exhibit B to the Covered Software; or
29 |
30 | (b) that the Covered Software was made available under the terms of
31 | version 1.1 or earlier of the License, but not also under the
32 | terms of a Secondary License.
33 |
34 | 1.6. "Executable Form"
35 | means any form of the work other than Source Code Form.
36 |
37 | 1.7. "Larger Work"
38 | means a work that combines Covered Software with other material, in
39 | a separate file or files, that is not Covered Software.
40 |
41 | 1.8. "License"
42 | means this document.
43 |
44 | 1.9. "Licensable"
45 | means having the right to grant, to the maximum extent possible,
46 | whether at the time of the initial grant or subsequently, any and
47 | all of the rights conveyed by this License.
48 |
49 | 1.10. "Modifications"
50 | means any of the following:
51 |
52 | (a) any file in Source Code Form that results from an addition to,
53 | deletion from, or modification of the contents of Covered
54 | Software; or
55 |
56 | (b) any new file in Source Code Form that contains any Covered
57 | Software.
58 |
59 | 1.11. "Patent Claims" of a Contributor
60 | means any patent claim(s), including without limitation, method,
61 | process, and apparatus claims, in any patent Licensable by such
62 | Contributor that would be infringed, but for the grant of the
63 | License, by the making, using, selling, offering for sale, having
64 | made, import, or transfer of either its Contributions or its
65 | Contributor Version.
66 |
67 | 1.12. "Secondary License"
68 | means either the GNU General Public License, Version 2.0, the GNU
69 | Lesser General Public License, Version 2.1, the GNU Affero General
70 | Public License, Version 3.0, or any later versions of those
71 | licenses.
72 |
73 | 1.13. "Source Code Form"
74 | means the form of the work preferred for making modifications.
75 |
76 | 1.14. "You" (or "Your")
77 | means an individual or a legal entity exercising rights under this
78 | License. For legal entities, "You" includes any entity that
79 | controls, is controlled by, or is under common control with You. For
80 | purposes of this definition, "control" means (a) the power, direct
81 | or indirect, to cause the direction or management of such entity,
82 | whether by contract or otherwise, or (b) ownership of more than
83 | fifty percent (50%) of the outstanding shares or beneficial
84 | ownership of such entity.
85 |
86 | 2. License Grants and Conditions
87 | --------------------------------
88 |
89 | 2.1. Grants
90 |
91 | Each Contributor hereby grants You a world-wide, royalty-free,
92 | non-exclusive license:
93 |
94 | (a) under intellectual property rights (other than patent or trademark)
95 | Licensable by such Contributor to use, reproduce, make available,
96 | modify, display, perform, distribute, and otherwise exploit its
97 | Contributions, either on an unmodified basis, with Modifications, or
98 | as part of a Larger Work; and
99 |
100 | (b) under Patent Claims of such Contributor to make, use, sell, offer
101 | for sale, have made, import, and otherwise transfer either its
102 | Contributions or its Contributor Version.
103 |
104 | 2.2. Effective Date
105 |
106 | The licenses granted in Section 2.1 with respect to any Contribution
107 | become effective for each Contribution on the date the Contributor first
108 | distributes such Contribution.
109 |
110 | 2.3. Limitations on Grant Scope
111 |
112 | The licenses granted in this Section 2 are the only rights granted under
113 | this License. No additional rights or licenses will be implied from the
114 | distribution or licensing of Covered Software under this License.
115 | Notwithstanding Section 2.1(b) above, no patent license is granted by a
116 | Contributor:
117 |
118 | (a) for any code that a Contributor has removed from Covered Software;
119 | or
120 |
121 | (b) for infringements caused by: (i) Your and any other third party's
122 | modifications of Covered Software, or (ii) the combination of its
123 | Contributions with other software (except as part of its Contributor
124 | Version); or
125 |
126 | (c) under Patent Claims infringed by Covered Software in the absence of
127 | its Contributions.
128 |
129 | This License does not grant any rights in the trademarks, service marks,
130 | or logos of any Contributor (except as may be necessary to comply with
131 | the notice requirements in Section 3.4).
132 |
133 | 2.4. Subsequent Licenses
134 |
135 | No Contributor makes additional grants as a result of Your choice to
136 | distribute the Covered Software under a subsequent version of this
137 | License (see Section 10.2) or under the terms of a Secondary License (if
138 | permitted under the terms of Section 3.3).
139 |
140 | 2.5. Representation
141 |
142 | Each Contributor represents that the Contributor believes its
143 | Contributions are its original creation(s) or it has sufficient rights
144 | to grant the rights to its Contributions conveyed by this License.
145 |
146 | 2.6. Fair Use
147 |
148 | This License is not intended to limit any rights You have under
149 | applicable copyright doctrines of fair use, fair dealing, or other
150 | equivalents.
151 |
152 | 2.7. Conditions
153 |
154 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
155 | in Section 2.1.
156 |
157 | 3. Responsibilities
158 | -------------------
159 |
160 | 3.1. Distribution of Source Form
161 |
162 | All distribution of Covered Software in Source Code Form, including any
163 | Modifications that You create or to which You contribute, must be under
164 | the terms of this License. You must inform recipients that the Source
165 | Code Form of the Covered Software is governed by the terms of this
166 | License, and how they can obtain a copy of this License. You may not
167 | attempt to alter or restrict the recipients' rights in the Source Code
168 | Form.
169 |
170 | 3.2. Distribution of Executable Form
171 |
172 | If You distribute Covered Software in Executable Form then:
173 |
174 | (a) such Covered Software must also be made available in Source Code
175 | Form, as described in Section 3.1, and You must inform recipients of
176 | the Executable Form how they can obtain a copy of such Source Code
177 | Form by reasonable means in a timely manner, at a charge no more
178 | than the cost of distribution to the recipient; and
179 |
180 | (b) You may distribute such Executable Form under the terms of this
181 | License, or sublicense it under different terms, provided that the
182 | license for the Executable Form does not attempt to limit or alter
183 | the recipients' rights in the Source Code Form under this License.
184 |
185 | 3.3. Distribution of a Larger Work
186 |
187 | You may create and distribute a Larger Work under terms of Your choice,
188 | provided that You also comply with the requirements of this License for
189 | the Covered Software. If the Larger Work is a combination of Covered
190 | Software with a work governed by one or more Secondary Licenses, and the
191 | Covered Software is not Incompatible With Secondary Licenses, this
192 | License permits You to additionally distribute such Covered Software
193 | under the terms of such Secondary License(s), so that the recipient of
194 | the Larger Work may, at their option, further distribute the Covered
195 | Software under the terms of either this License or such Secondary
196 | License(s).
197 |
198 | 3.4. Notices
199 |
200 | You may not remove or alter the substance of any license notices
201 | (including copyright notices, patent notices, disclaimers of warranty,
202 | or limitations of liability) contained within the Source Code Form of
203 | the Covered Software, except that You may alter any license notices to
204 | the extent required to remedy known factual inaccuracies.
205 |
206 | 3.5. Application of Additional Terms
207 |
208 | You may choose to offer, and to charge a fee for, warranty, support,
209 | indemnity or liability obligations to one or more recipients of Covered
210 | Software. However, You may do so only on Your own behalf, and not on
211 | behalf of any Contributor. You must make it absolutely clear that any
212 | such warranty, support, indemnity, or liability obligation is offered by
213 | You alone, and You hereby agree to indemnify every Contributor for any
214 | liability incurred by such Contributor as a result of warranty, support,
215 | indemnity or liability terms You offer. You may include additional
216 | disclaimers of warranty and limitations of liability specific to any
217 | jurisdiction.
218 |
219 | 4. Inability to Comply Due to Statute or Regulation
220 | ---------------------------------------------------
221 |
222 | If it is impossible for You to comply with any of the terms of this
223 | License with respect to some or all of the Covered Software due to
224 | statute, judicial order, or regulation then You must: (a) comply with
225 | the terms of this License to the maximum extent possible; and (b)
226 | describe the limitations and the code they affect. Such description must
227 | be placed in a text file included with all distributions of the Covered
228 | Software under this License. Except to the extent prohibited by statute
229 | or regulation, such description must be sufficiently detailed for a
230 | recipient of ordinary skill to be able to understand it.
231 |
232 | 5. Termination
233 | --------------
234 |
235 | 5.1. The rights granted under this License will terminate automatically
236 | if You fail to comply with any of its terms. However, if You become
237 | compliant, then the rights granted under this License from a particular
238 | Contributor are reinstated (a) provisionally, unless and until such
239 | Contributor explicitly and finally terminates Your grants, and (b) on an
240 | ongoing basis, if such Contributor fails to notify You of the
241 | non-compliance by some reasonable means prior to 60 days after You have
242 | come back into compliance. Moreover, Your grants from a particular
243 | Contributor are reinstated on an ongoing basis if such Contributor
244 | notifies You of the non-compliance by some reasonable means, this is the
245 | first time You have received notice of non-compliance with this License
246 | from such Contributor, and You become compliant prior to 30 days after
247 | Your receipt of the notice.
248 |
249 | 5.2. If You initiate litigation against any entity by asserting a patent
250 | infringement claim (excluding declaratory judgment actions,
251 | counter-claims, and cross-claims) alleging that a Contributor Version
252 | directly or indirectly infringes any patent, then the rights granted to
253 | You by any and all Contributors for the Covered Software under Section
254 | 2.1 of this License shall terminate.
255 |
256 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all
257 | end user license agreements (excluding distributors and resellers) which
258 | have been validly granted by You or Your distributors under this License
259 | prior to termination shall survive termination.
260 |
261 | ************************************************************************
262 | * *
263 | * 6. Disclaimer of Warranty *
264 | * ------------------------- *
265 | * *
266 | * Covered Software is provided under this License on an "as is" *
267 | * basis, without warranty of any kind, either expressed, implied, or *
268 | * statutory, including, without limitation, warranties that the *
269 | * Covered Software is free of defects, merchantable, fit for a *
270 | * particular purpose or non-infringing. The entire risk as to the *
271 | * quality and performance of the Covered Software is with You. *
272 | * Should any Covered Software prove defective in any respect, You *
273 | * (not any Contributor) assume the cost of any necessary servicing, *
274 | * repair, or correction. This disclaimer of warranty constitutes an *
275 | * essential part of this License. No use of any Covered Software is *
276 | * authorized under this License except under this disclaimer. *
277 | * *
278 | ************************************************************************
279 |
280 | ************************************************************************
281 | * *
282 | * 7. Limitation of Liability *
283 | * -------------------------- *
284 | * *
285 | * Under no circumstances and under no legal theory, whether tort *
286 | * (including negligence), contract, or otherwise, shall any *
287 | * Contributor, or anyone who distributes Covered Software as *
288 | * permitted above, be liable to You for any direct, indirect, *
289 | * special, incidental, or consequential damages of any character *
290 | * including, without limitation, damages for lost profits, loss of *
291 | * goodwill, work stoppage, computer failure or malfunction, or any *
292 | * and all other commercial damages or losses, even if such party *
293 | * shall have been informed of the possibility of such damages. This *
294 | * limitation of liability shall not apply to liability for death or *
295 | * personal injury resulting from such party's negligence to the *
296 | * extent applicable law prohibits such limitation. Some *
297 | * jurisdictions do not allow the exclusion or limitation of *
298 | * incidental or consequential damages, so this exclusion and *
299 | * limitation may not apply to You. *
300 | * *
301 | ************************************************************************
302 |
303 | 8. Litigation
304 | -------------
305 |
306 | Any litigation relating to this License may be brought only in the
307 | courts of a jurisdiction where the defendant maintains its principal
308 | place of business and such litigation shall be governed by laws of that
309 | jurisdiction, without reference to its conflict-of-law provisions.
310 | Nothing in this Section shall prevent a party's ability to bring
311 | cross-claims or counter-claims.
312 |
313 | 9. Miscellaneous
314 | ----------------
315 |
316 | This License represents the complete agreement concerning the subject
317 | matter hereof. If any provision of this License is held to be
318 | unenforceable, such provision shall be reformed only to the extent
319 | necessary to make it enforceable. Any law or regulation which provides
320 | that the language of a contract shall be construed against the drafter
321 | shall not be used to construe this License against a Contributor.
322 |
323 | 10. Versions of the License
324 | ---------------------------
325 |
326 | 10.1. New Versions
327 |
328 | Mozilla Foundation is the license steward. Except as provided in Section
329 | 10.3, no one other than the license steward has the right to modify or
330 | publish new versions of this License. Each version will be given a
331 | distinguishing version number.
332 |
333 | 10.2. Effect of New Versions
334 |
335 | You may distribute the Covered Software under the terms of the version
336 | of the License under which You originally received the Covered Software,
337 | or under the terms of any subsequent version published by the license
338 | steward.
339 |
340 | 10.3. Modified Versions
341 |
342 | If you create software not governed by this License, and you want to
343 | create a new license for such software, you may create and use a
344 | modified version of this License if you rename the license and remove
345 | any references to the name of the license steward (except to note that
346 | such modified license differs from this License).
347 |
348 | 10.4. Distributing Source Code Form that is Incompatible With Secondary
349 | Licenses
350 |
351 | If You choose to distribute Source Code Form that is Incompatible With
352 | Secondary Licenses under the terms of this version of the License, the
353 | notice described in Exhibit B of this License must be attached.
354 |
355 | Exhibit A - Source Code Form License Notice
356 | -------------------------------------------
357 |
358 | This Source Code Form is subject to the terms of the Mozilla Public
359 | License, v. 2.0. If a copy of the MPL was not distributed with this
360 | file, You can obtain one at http://mozilla.org/MPL/2.0/.
361 |
362 | If it is not possible or desirable to put the notice in a particular
363 | file, then You may include the notice in a location (such as a LICENSE
364 | file in a relevant directory) where a recipient would be likely to look
365 | for such a notice.
366 |
367 | You may add additional accurate notices of copyright ownership.
368 |
369 | Exhibit B - "Incompatible With Secondary Licenses" Notice
370 | ---------------------------------------------------------
371 |
372 | This Source Code Form is "Incompatible With Secondary Licenses", as
373 | defined by the Mozilla Public License, v. 2.0.
374 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Serapis AI Image Classifier
2 |
3 |
4 |
5 |
6 |
7 |
8 | Serapis AI Image Classifier is a program that allows you to automatically create image datasets using SerpApi's Google Images Scraper API, finetune a ResNet50 model, and classify images using the trained model.
9 |
10 |
11 |
12 |
13 | ---
14 |
15 | Installation
16 |
17 | You can install the dependencies listed in `requirements.txt` using the following command:
18 | ```bash
19 | pip install -r requirements.txt
20 | ```
21 |
22 | ---
23 |
24 | Usage
25 |
26 | You can use Serapis AI Image Classifier in one of the following three modes:
27 |
28 | ---
29 |
30 | Create a dataset and train a new model from scratch
31 |
32 | To create a dataset and train a new model from scratch, you will need to provide a list of labels and an image to use as a reference for the scraping process.
33 | A [SerpApi API Key](https://serpapi.com/manage-api-key) is required in this mode so that the program can automatically scrape the images for your dataset using [SerpApi's Google Images Scraper API](https://serpapi.com/images-results).
34 |
35 | You can [register to SerpApi to claim free credits](https://serpapi.com/users/sign_up).
36 |
37 | ```bash
38 | python serapis.py --train --labels eagle, bull, lion, man --image-path lionimage.jpg --api-key
39 | ```
40 |
41 | ---
42 |
43 | Use old scraped images and train a new model
44 |
45 | To use old scraped images and train a new model, you will need to provide a list of labels and pass the `--use-old-images` flag.
46 | You can also put your own images in the `images/` folder and add an entry for each of them to `images/catalogue.csv` to train models on your own dataset.
47 |
48 | ```bash
49 | python serapis.py --train --labels eagle, bull, lion, man --use-old-images --image-path lionimage.jpg
50 | ```
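Adding your own images means appending one `label,image_path` row per image to `images/catalogue.csv`, matching the header of that file. Below is a small sketch of that step, operating on the CSV text so it can be tried without touching the real catalogue. The function name `append_to_catalogue` and the example paths are assumptions made for illustration.

```python
import csv
import io


def append_to_catalogue(catalogue_text, rows):
    """Return catalogue CSV text with extra (label, image_path) rows appended."""
    out = io.StringIO()
    out.write(catalogue_text)
    # Make sure the existing content ends with a newline before appending rows.
    if catalogue_text and not catalogue_text.endswith("\n"):
        out.write("\n")
    writer = csv.writer(out, lineterminator="\n")
    writer.writerows(rows)
    return out.getvalue()


if __name__ == "__main__":
    existing = "label,image_path\n"
    # Hypothetical image files placed in the images/ folder by hand.
    new_rows = [("lion", "images/my_lion.png"), ("eagle", "images/my_eagle.png")]
    print(append_to_catalogue(existing, new_rows))
```

The labels written to the catalogue need to match the values passed to `--labels` exactly, because the training code filters catalogue rows by label.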
51 |
52 | ---
53 |
54 | Use a previously trained model
55 |
56 | To use a previously trained model, you will need to provide the path to the trained model and an image to classify.
57 |
58 | ```bash
59 | python serapis.py --model-path models/1.pth --image-path lionimage.jpg
60 | ```
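A saved model is only usable together with the labels it was trained on, which `models/old_models.csv` stores as a `--`-separated string next to each model path. The sketch below shows that lookup using only the standard library. The function name `labels_for_model` is hypothetical; `serapis.py` performs the equivalent lookup with pandas.

```python
import csv
import io


def labels_for_model(registry_csv_text, model_path):
    """Return the ordered label list recorded for model_path."""
    reader = csv.DictReader(io.StringIO(registry_csv_text))
    for row in reader:
        if row["model_path"] == model_path:
            # Labels are joined with "--" when the model is saved.
            return row["target_labels"].split("--")
    raise KeyError("No entry for {} in the model registry".format(model_path))


if __name__ == "__main__":
    registry = "model_path,target_labels\nmodels/1.pth,Eagle--Bull--Lion--Human\n"
    print(labels_for_model(registry, "models/1.pth"))
```

The order of the labels matters: the index of each label is the class index the classifier head was trained with.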
61 |
62 | ---
63 |
64 | Dialogue Mode
65 |
66 | You can also navigate the program by not providing any arguments and using the dialogue mode:
67 |
68 | ```bash
69 | python serapis.py
70 | ```
71 |
72 | ---
73 |
74 | The Output
75 |
76 |
77 |
78 |
79 | The output will give you the answer:
80 |
81 | ```bash
82 | The image contains Lion
83 | ```
84 |
85 | ---
86 |
87 | Optional Arguments:
88 |
89 | -h, --help Help to Navigate
90 | --train Whether to train a new model
91 | --model-path MODEL_PATH Pretrained Model path you want to use
92 | --dialogue Whether to use dialogue to navigate through the program
93 | --use-old-images Whether to use old images you have downloaded to train a new model
94 | --api-key API_KEY SerpApi API Key
95 | --limit LIMIT Number of images you want to scrape at most for each label
96 | --labels LABELS [LABELS ...] Labels you want to use to train a new model
97 | --image-path IMAGE_PATH Path to the image you want to classify
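
The arguments listed above map naturally onto an `argparse` parser. The following is a sketch of one way to declare them, not the parser actually used in `serapis.py`. The `clean_labels` helper is an assumption added here to show how labels typed with trailing commas, as in the usage examples, could be normalized.

```python
import argparse


def build_parser():
    """Declare the command-line options listed in the README (sketch)."""
    parser = argparse.ArgumentParser(description="Serapis AI Image Classifier")
    parser.add_argument("--train", action="store_true", help="Whether to train a new model")
    parser.add_argument("--model-path", help="Pretrained Model path you want to use")
    parser.add_argument("--dialogue", action="store_true", help="Whether to use dialogue mode")
    parser.add_argument("--use-old-images", action="store_true", help="Train on already downloaded images")
    parser.add_argument("--api-key", help="SerpApi API Key")
    parser.add_argument("--limit", type=int, help="Maximum number of images to scrape per label")
    parser.add_argument("--labels", nargs="+", help="Labels to train a new model on")
    parser.add_argument("--image-path", help="Path to the image you want to classify")
    return parser


def clean_labels(raw_labels):
    """Split on commas and drop empty pieces, so 'eagle,' becomes 'eagle'."""
    return [part.strip() for item in raw_labels for part in item.split(",") if part.strip()]


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--train", "--labels", "eagle,", "bull,", "lion,", "man", "--image-path", "lionimage.jpg"]
    )
    print(clean_labels(args.labels), args.image_path)
```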
98 |
--------------------------------------------------------------------------------
/images/catalogue.csv:
--------------------------------------------------------------------------------
1 | label,image_path
2 |
--------------------------------------------------------------------------------
/lionimage.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/serpapi/serapis-ai-image-classifier/0c9ca68fa7070e98d2994b856bc915c59a06830e/lionimage.jpg
--------------------------------------------------------------------------------
/models/1.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:70a8d1b09e6d56878ffb05d373b12ee3c69bb30f8a3af59a4f8cb1e4daeba444
3 | size 94404289
4 |
--------------------------------------------------------------------------------
/models/old_models.csv:
--------------------------------------------------------------------------------
1 | model_path,target_labels
2 | models/1.pth,Eagle--Bull--Lion--Human
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.4.4
2 | Pillow==9.4.0
3 | requests==2.28.1
4 | torch==1.13.0
5 | torchvision==0.14.0
6 | transformers==4.24.0
7 |
--------------------------------------------------------------------------------
/serapis.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore", category=UserWarning, module='jax')
3 | warnings.filterwarnings("ignore", module='PIL')
4 |
5 | from transformers import ResNetForImageClassification
6 | from torch.utils.data import DataLoader, Dataset
7 | from torchvision import transforms
8 | from transformers import logging
9 | import torch.nn.functional as F
10 | from PIL import Image
11 | import pandas as pd
12 | import urllib.parse
13 | import argparse
14 | import requests
15 | import asyncio
16 | import torch
17 | import sys
18 | import os
19 | import io
20 |
21 | async def call_serpapi(target, api_key, ijn=0):
22 |     q = urllib.parse.quote(target["q"])
23 |     url = "https://serpapi.com/search.json?api_key={}&tbm=isch&q={}&ijn={}".format(api_key, q, ijn)
24 |     if "chips" in target:
25 |         chips = None
26 |
27 |         # First call to retrieve chips from the search
28 |         response = requests.get(url)
29 |         results = response.json()
30 |         if "error" in results:
31 |             print(results["error"])
32 |             sys.exit(1)
33 |
34 |         if "suggested_searches" in results:
35 |             for suggestion in results["suggested_searches"]:
36 |                 if "chips" in suggestion and suggestion["name"] == target["chips"]:
37 |                     chips = urllib.parse.quote(suggestion["chips"])
38 |
39 |         if chips is not None:
40 |             # Second call to make a chips search
41 |             url = "https://serpapi.com/search.json?api_key={}&tbm=isch&q={}&chips={}&ijn={}".format(api_key, q, chips, ijn)
42 |             response = requests.get(url)
43 |     else:
44 |         # Call without chips
45 |         response = requests.get(url)
46 |     return response.json()
47 |
48 | async def search(targets, api_key):
49 |     images_results = []
50 |     for target in targets:
51 |         searches = []
52 |         # First page results
53 |         searches.append(call_serpapi(target, api_key, ijn=0))
54 |         if "page" in target and target["page"] > 1:
55 |             # If more than one page is requested
56 |             searches.extend(call_serpapi(target, api_key, ijn=i) for i in range(1, target["page"]))
57 |         images = await asyncio.gather(*searches)
58 |         images_results.append([target["q"], images])
59 |     return images_results
60 |
61 | async def get_single_image_url(label, url, df):
62 |     try:
63 |         response = requests.get(url, timeout=2)
64 |         f = io.BytesIO(response.content)
65 |
66 |         last_item = len(os.listdir("images"))
67 |         image_path = "images/{}.png".format(last_item)
68 |
69 |         im = Image.open(f)
70 |         im.save(image_path, format='PNG', quality=95)
71 |
72 |         df.loc[0] = [label, image_path]
73 |         print("Downloaded {}".format(url))
74 |         return df
75 |     except Exception:  # Skip images that fail to download or decode
76 |         return df
77 |
78 | async def get_images(images_results, limit=None):
79 |     all_dfs = []
80 |     for images in images_results:
81 |         calls = []
82 |         if len(images) == 2:
83 |             for page_result in images[1]:
84 |                 if "images_results" in page_result:
85 |                     for i in range(len(page_result["images_results"])):
86 |                         result = page_result["images_results"][i]
87 |                         if limit is not None and i == limit:
88 |                             break
89 |
90 |                         if "original" in result:
91 |                             df = pd.DataFrame(columns=["label", "image_path"])
92 |                             url = result["original"]
93 |                             calls.append(get_single_image_url(images[0], url, df))
94 |                             print("Added Coroutine: {} - {}".format(images[0], url))
95 |         dfs = await asyncio.gather(*calls)
96 |         all_dfs = all_dfs + dfs
97 |     df = pd.concat(all_dfs)
98 |     print("---")
99 |     return df
100 |
101 |
102 | class CustomDataset(Dataset):
103 |     def __init__(self, df):
104 |         self.df = df
105 |
106 |     def __len__(self):
107 |         return len(self.df)
108 |
109 |     def preprocess_image(self, image_path):
110 |         # Load image using PIL
111 |         image = Image.open(image_path)
112 |         # Convert image to RGB if it is not already in that format
113 |         if image.mode != 'RGB':
114 |             image = image.convert('RGB')
115 |         # Resize image to a fixed size
116 |         image = image.resize((224, 224))
117 |         # Convert image to a tensor
118 |         image = transforms.ToTensor()(image)
119 |         # Normalize image
120 |         image = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
121 |         return image
122 |
123 |     def __getitem__(self, index):
124 |         row = self.df.iloc[index]
125 |         label = row['label']
126 |         image_path = row['image_path']
127 |         # Load and pre-process image data here
128 |         image_data = self.preprocess_image(image_path)
129 |         return image_data, label
130 |
131 | def train_and_save_model(df, target_labels):
132 |     print("Training a new model. If you get a warning about shapes below, you can ignore it.")
133 |     dataset = CustomDataset(df)
134 |     dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
135 |     class_to_idx = {class_name: idx for idx, class_name in enumerate(target_labels)}
136 |     idx_to_class = {idx: class_name for idx, class_name in enumerate(target_labels)}
137 |     model = ResNetForImageClassification.from_pretrained(
138 |         "microsoft/resnet-50",
139 |         num_labels=len(target_labels),
140 |         id2label=idx_to_class,
141 |         label2id=class_to_idx,
142 |         ignore_mismatched_sizes=True
143 |     )
144 |     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
145 |     total_examples = len(dataset)
146 |     num_epochs = 2
147 |     processed_examples = 0
148 |
149 |     print("---")
150 |     for epoch in range(num_epochs):
151 |         for data in dataloader:
152 |             images, labels = data
153 |             model.train()
154 |             model.zero_grad()
155 |
156 |             logits = model(images).logits
157 |             labels = [class_to_idx[label] for label in labels if label in class_to_idx]
158 |             labels = torch.tensor(labels)
159 |
160 |             loss = F.cross_entropy(logits, labels)
161 |             loss.backward()
162 |             optimizer.step()
163 |
164 |             model.eval()
165 |             logits = model(images, labels = labels).logits
166 |             _, predicted_labels = logits.max(dim=1)
167 |             accuracy = (predicted_labels == labels).float().mean()
168 |
169 |             batch_size = images.size(0)
170 |             processed_examples += batch_size
171 |             progress = (processed_examples / (total_examples * num_epochs)) * 100
172 |             print(f'Accuracy: {accuracy} | Progress: {progress:.2f}%')
173 |
174 |     last_item = len(os.listdir("models"))
175 |     model_path = "models/{}.pth".format(last_item)
176 |     torch.save(model.state_dict(), model_path)
177 |     df = pd.DataFrame(columns=["model_path", "target_labels"])
178 |     df.loc[0] = [model_path, "--".join(target_labels)]
179 |     old_df = pd.read_csv("models/old_models.csv")
180 |     new_df = pd.concat([df, old_df], ignore_index=True)
181 |     new_df.to_csv("models/old_models.csv", index=False)
182 |
183 |     return model
184 |
185 | def train_a_new_model(targets=None, use_catalogue=True, api_key=None, limit=None):
186 |     target_labels = [dictionary["q"] for dictionary in targets]
187 |     if use_catalogue:
188 |         # Call the Catalogue CSV and get only the keys you targeted.
189 |         df = pd.read_csv("images/catalogue.csv")
190 |         all_labels = list(set(df["label"]))
191 |         for label in all_labels:
192 |             if label not in target_labels:
193 |                 df = df.drop(df[df["label"] == label].index)
194 |     else:
195 |         # Save the new images to Catalogue CSV but use only the ones you targeted.
196 |         images_results = asyncio.run(search(targets, api_key))
197 |         df = asyncio.run(get_images(images_results, limit = limit))
198 |         old_df = pd.read_csv("images/catalogue.csv")
199 |         new_df = pd.concat([df, old_df], ignore_index=True)
200 |         new_df.to_csv("images/catalogue.csv", index=False)
201 |
202 |     model = train_and_save_model(df, target_labels)
203 |     return model, target_labels
204 |
205 | def use_old_model(model_path=None):
206 |     model_df = pd.read_csv("models/old_models.csv")
207 |     target_labels = model_df.loc[model_df[model_df["model_path"] == model_path].index]["target_labels"].iloc[0]
208 |     target_labels = target_labels.split("--")
209 |     class_to_idx = {class_name: idx for idx, class_name in enumerate(target_labels)}
210 |     idx_to_class = {idx: class_name for idx, class_name in enumerate(target_labels)}
211 |     model = ResNetForImageClassification.from_pretrained(
212 |         "microsoft/resnet-50",
213 |         num_labels=len(target_labels),
214 |         id2label=idx_to_class,
215 |         label2id=class_to_idx,
216 |         ignore_mismatched_sizes=True
217 |     )
218 |     state_dict = torch.load(model_path)
219 |     model.load_state_dict(state_dict)
220 |     return model, target_labels
221 |
222 | def predict_image(image_path=None, model=None, target_labels=None):
223 |     df = pd.DataFrame(columns=["label", "image_path"])
224 |     dataloader = CustomDataset(df)
225 |     image = dataloader.preprocess_image(image_path)
226 |     model.eval()
227 |     logits = model(image.unsqueeze(0)).logits
228 |     _, predicted_labels = logits.max(dim=1)
229 |     predicted_class_names = [target_labels[int(label)] for label in predicted_labels]
230 |     return predicted_class_names[0]
231 |
232 | def questions():
233 | api_key = ""
234 | model_path = ""
235 | train_new_model = False
236 | use_catalogue = True
237 | limit = None
238 | targets = []
239 | while True:
240 | print("What do you want to do?")
241 | print("1. Train a new model.")
242 | print("2. Use an old model.")
243 | choice = input("Enter your choice: ")
244 | if choice == "1" or choice == "2":
245 | break
246 | else:
247 | print("Please enter a valid choice.")
248 | print("---")
249 |
250 | if choice == "1":
251 | train_new_model = True
252 | while True:
253 | print("Would you like to use the old images you have stored, or scrape fresh images?")
254 | print("1. Use old images.")
255 | print("2. Scrape new images.")
256 | second_choice = input("Enter your choice: ")
257 | if second_choice == "1" or second_choice == "2":
258 | break
259 | else:
260 | print("Please enter a valid choice.")
261 | print("---")
262 |
263 | if second_choice == "1":
264 | use_catalogue = True
265 | old_df = pd.read_csv("images/catalogue.csv")
266 | all_labels = list(set(old_df["label"]))
267 | while True:
268 | print("Here are the labels of images you have stored.")
269 | print("{}".format(", ".join(all_labels)))
270 | print("Which labels do you want to use?")
271 | desired_labels = input("Enter the labels (case sensitive) separated by a comma: ")
272 | desired_labels = desired_labels.split(",")
273 | desired_labels = [label.strip() for label in desired_labels if label.strip() != ""]
274 | if len(desired_labels) < 2:
275 | print("Please enter at least two labels.")
276 | elif not all(elem in all_labels for elem in desired_labels):
277 | print("Please enter valid labels that are already in the labels of images you have stored.")
278 | elif all(elem in all_labels for elem in desired_labels):
279 | break
280 | print("---")
281 | elif second_choice == "2":
282 | use_catalogue = False
283 | while True:
284 | print("Which labels do you want to use?")
285 | desired_labels = input("Enter the labels (case sensitive) separated by a comma: ")
286 | desired_labels = desired_labels.split(",")
287 | desired_labels = [label.strip() for label in desired_labels if label.strip() != ""]
288 | if len(desired_labels) < 2:
289 | print("Please enter at least two labels.")
290 | else:
291 | break
292 | print("---")
293 |
294 | while True:
295 | print("How many images do you want to scrape at most for each label?")
296 | limit = input("Enter the limit (Enter nothing to pass): ")
297 | if limit == "":
298 | break
299 | elif limit.isdigit():
300 | limit = int(limit)
301 | break
302 | else:
303 | print("Please enter a valid integer.")
304 | print("---")
305 |
306 | while True:
307 | api_key = input("Enter your SerpApi API key: ")
308 | if api_key == "":
309 | print("Please enter a valid API key.")
310 | else:
311 | break
312 | print("---")
313 | targets = [{"q": label} for label in desired_labels]
314 | elif choice == "2":
315 | while True:
316 | print("Which model do you want to use?")
317 | model_df = pd.read_csv("models/old_models.csv")
318 | print(model_df)
319 | model_path = input("Enter the model path: ")
320 | if model_path not in list(model_df["model_path"]):
321 | print("Please enter a valid model path.")
322 | elif not os.path.isfile(model_path):
323 | print("The model exists in CSV, but there is no model at the path. Please enter a valid model path.")
324 | else:
325 | break
326 | print("---")
327 |
328 | while True:
329 | image_path = input("Enter the image path you want to predict: ")
330 | if not os.path.isfile(image_path):
331 | print("Please enter a valid image path.")
332 | else:
333 | break
334 |
335 | print("---")
336 | return targets, train_new_model, use_catalogue, api_key, model_path, limit, image_path
337 |
338 | logging.set_verbosity_error()
339 | parser = argparse.ArgumentParser()
340 |
341 | # Mode
342 | parser.add_argument('--train', action='store_true', help='Whether to train a new model')
343 | parser.add_argument('--model-path', type=str, help='Pretrained Model path you want to use')
344 | parser.add_argument('--dialogue', action='store_false', help='Disable the interactive dialogue (the dialogue runs by default when no other arguments are given)')
345 |
346 | # Training
347 | parser.add_argument('--use-old-images', action='store_true', help='Whether to use old images you have downloaded to train a new model')
348 | parser.add_argument('--api-key', type=str, help='SerpApi API Key')
349 | parser.add_argument('--limit', type=int, help='Number of images you want to scrape at most for each label')
350 | parser.add_argument('--labels', type=str, nargs='+', help='Labels you want to use to train a new model')
351 |
352 | # Prediction
353 | parser.add_argument('--image-path', type=str, help='Path to the image you want to classify')
354 |
355 | args = parser.parse_args()
356 |
357 | # New Training with Image Scraping
358 | if args.train and args.labels and args.image_path and not args.use_old_images and not args.model_path:
359 | if args.limit:
360 | limit = args.limit
361 | else:
362 | limit = None
363 |
364 | if not args.api_key:
365 | print("You need to enter your SerpApi API key to scrape new images.")
366 | sys.exit(1)
367 |
368 | labels = [label.replace(",","").strip() for label in args.labels if label.replace(",","").strip() != ""]
369 | targets = [{"q": label} for label in labels]
370 |
371 | train_new_model = True
372 | use_catalogue = False
373 | api_key = args.api_key
374 | image_path = args.image_path
375 |
376 | model, target_labels = train_a_new_model(targets, use_catalogue, api_key, limit)
377 | print("---")
378 | print("The image contains {}".format(predict_image(image_path, model, target_labels)))
379 |
380 | # New Training with Old Scraped Images
381 | elif args.train and args.labels and args.use_old_images and args.image_path and not args.model_path:
382 | labels = [label.replace(",","").strip() for label in args.labels if label.replace(",","").strip() != ""]
383 | targets = [{"q": label} for label in labels]
384 | train_new_model = True
385 | use_catalogue = True
386 | api_key = None
387 | limit = None
388 | image_path = args.image_path
389 |
390 | old_df = pd.read_csv("images/catalogue.csv")
391 | all_labels = list(set(old_df["label"]))
392 | if len(labels) < 2:
393 | print("Please enter at least two labels.")
394 | sys.exit(1)
395 | elif not all(elem in all_labels for elem in labels):
396 | print("Please enter labels that are in the catalogue.")
397 | sys.exit(1)
398 |
399 | model, target_labels = train_a_new_model(targets, use_catalogue, api_key, limit)
400 | print("---")
401 | print("The image contains {}".format(predict_image(image_path, model, target_labels)))
402 |
403 | # Old Model Prediction
404 | elif args.model_path and args.image_path and not args.train:
405 | model_path = args.model_path
406 | image_path = args.image_path
407 |
408 | if not os.path.isfile(model_path):
409 | print("Please enter a valid model path.")
410 | sys.exit(1)
411 |
412 | model, target_labels = use_old_model(model_path)
413 | print("---")
414 | print("The image contains {}".format(predict_image(image_path, model, target_labels)))
415 |
416 | # Actions from Dialogue
417 | elif args.dialogue and not args.model_path and not args.image_path and not args.train:
418 | targets, train_new_model, use_catalogue, api_key, model_path, limit, image_path = questions()
419 |
420 | if train_new_model:
421 | model, target_labels = train_a_new_model(targets, use_catalogue, api_key, limit)
422 | else:
423 | model, target_labels = use_old_model(model_path)
424 |
425 | print("---")
426 | print("The image contains {}".format(predict_image(image_path, model, target_labels)))
427 | else:
428 | print("Please enter the correct arguments.")
429 | sys.exit(1)
430 |
431 | # Tips for Advanced Usage
432 | # Below is an example of how to target specific images.
433 | #
434 | #targets = [
435 | # {
436 | # "q": "Elephant",
437 | # "page": 10,
438 | # "chips": "male"
439 | # }
440 | #]
441 | # `q` stands for the query you want to make to
442 | # SerpApi's Google Images Scraper API.
443 | #
444 | # `page` is the number of result pages you want
445 | # to scrape. Each page has 100 images. Not every
446 | # image is usable for training, but this yields
447 | # enough images to fine-tune ResNet50.
448 | #
449 | # `chips` is a refinement to add to the query.
450 | # Chips are the suggested labels shown at the top
451 | # of the results page, just below the search bar.
452 | # For example, to target male elephants only, use
453 | # the `chips` value shown above. The script makes
454 | # two calls to SerpApi in order to build a chips
455 | # search.
456 | #
457 | # You can tweak the code to insert manual targets
458 | # into the program.
459 |
--------------------------------------------------------------------------------
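The advanced-usage tips at the end of serapis.py describe `targets` entries with the keys `q`, `page`, and `chips`. The following is a minimal sketch of how such a list could be built from plain labels. The helper name `build_targets` is hypothetical and is not part of serapis.py; only the three dictionary keys come from the script's own comments.

```python
def build_targets(labels, page=None, chips=None):
    """Build a list of target dicts from plain label strings.

    Hypothetical helper, not part of serapis.py. The keys `q`, `page`
    and `chips` are the ones named in the script's usage tips.
    """
    targets = []
    for label in labels:
        # Same cleanup the script applies to --labels values.
        label = label.replace(",", "").strip()
        if not label:
            continue
        target = {"q": label}
        # Only include the optional keys when a value was given.
        if page is not None:
            target["page"] = page
        if chips is not None:
            target["chips"] = chips
        targets.append(target)
    return targets


# Reproduces the example from the usage tips.
targets = build_targets(["Elephant"], page=10, chips="male")
```

A list built this way could presumably be passed as the first argument of `train_a_new_model`, in place of the `targets` the script derives from `--labels`. Note that this sketch applies the same `page` and `chips` to every label.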