├── src
    ├── __init__.py
    ├── image_generator
    │   ├── __init__.py
    │   ├── image_generator.py
    │   ├── google_image_scraper.py
    │   └── image_face_processor.py
    ├── list_generator
    │   ├── __init__.py
    │   ├── categories
    │   │   ├── __init__.py
    │   │   ├── nba.py
    │   │   ├── billionaires.py
    │   │   ├── us_artists.py
    │   │   ├── us_politicians.py
    │   │   └── kpop.py
    │   ├── list_generator.py
    │   └── fetch_names.py
    ├── utils.py
    └── main.py
├── tests
    ├── __init__.py
    ├── list_generator
    │   └── categories
    │       ├── test_nba.py
    │       ├── test_kpop.py
    │       ├── test_us_artists.py
    │       ├── test_billionaires.py
    │       └── test_us_politicians.py
    └── image_generator
        ├── test_google_image_scraper.py
        └── test_image_face_processor.py
├── requirements.txt
├── LICENSE
├── README.md
└── .gitignore

/src/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/image_generator/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/list_generator/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/list_generator/categories/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
opencv-python
dlib
numpy
Pillow
requests
pandas
openpyxl
beautifulsoup4
lxml
pytest
tqdm
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import os
import pandas as pd


def df_to_excel(df, output_path):
    # Create the parent directory if needed; dirname may be empty for a bare filename
    parent_dir = os.path.dirname(output_path)
    if parent_dir and not os.path.exists(parent_dir):
        os.makedirs(parent_dir)

    df.to_excel(output_path, index=False)


def save_image_to_file(image, image_path):
    with open(image_path, "wb") as f:
        f.write(image)
--------------------------------------------------------------------------------
/src/list_generator/list_generator.py:
--------------------------------------------------------------------------------
import pandas as pd

from src.utils import df_to_excel
from src.list_generator.fetch_names import fetch_names


def list_generator(category, number_of_persons, lists_dir):
    names = fetch_names(category, number_of_persons)
    df = pd.DataFrame({"name": names})

    list_path = f"{lists_dir}/{category}.xlsx"
    df_to_excel(df, list_path)

    print(f"List successfully generated at: {list_path}")
--------------------------------------------------------------------------------
/tests/list_generator/categories/test_nba.py:
--------------------------------------------------------------------------------
import pytest
from src.list_generator.categories.nba import (
    fetch_names_nba,
)


def test_fetch_names_nba():
    number_of_persons = 10

    full_names = fetch_names_nba(number_of_persons)

    assert (
        len(full_names) == number_of_persons
    ), f"Expected {number_of_persons}, but got {len(full_names)}"
    assert all(
        isinstance(name, str) for name in full_names
    ), "All names should be strings"
--------------------------------------------------------------------------------
/tests/list_generator/categories/test_kpop.py:
--------------------------------------------------------------------------------
import pytest
from src.list_generator.categories.kpop import (
    fetch_names_kpop,
)


def test_fetch_names_kpop():
    number_of_persons = 10

    full_names = fetch_names_kpop(number_of_persons)

    assert (
        len(full_names) == number_of_persons
    ), f"Expected {number_of_persons}, but got {len(full_names)}"
    assert all(
        isinstance(name, str) for name in full_names
    ), "All names should be strings"
--------------------------------------------------------------------------------
/tests/list_generator/categories/test_us_artists.py:
--------------------------------------------------------------------------------
import pytest
from src.list_generator.categories.us_artists import (
    fetch_names_us_artists,
)


def test_fetch_names_us_artists():
    number_of_persons = 10

    full_names = fetch_names_us_artists(number_of_persons)

    assert (
        len(full_names) == number_of_persons
    ), f"Expected {number_of_persons}, but got {len(full_names)}"
    assert all(
        isinstance(name, str) for name in full_names
    ), "All names should be strings"
--------------------------------------------------------------------------------
/tests/list_generator/categories/test_billionaires.py:
--------------------------------------------------------------------------------
import pytest
from src.list_generator.categories.billionaires import (
    fetch_names_billionaires,
)


def test_fetch_names_billionaires():
    number_of_persons = 10

    full_names = fetch_names_billionaires(number_of_persons)

    assert (
        len(full_names) == number_of_persons
    ), f"Expected {number_of_persons}, but got {len(full_names)}"
    assert all(
        isinstance(name, str) for name in full_names
    ), "All names should be strings"
--------------------------------------------------------------------------------
/tests/list_generator/categories/test_us_politicians.py:
--------------------------------------------------------------------------------
import pytest
from src.list_generator.categories.us_politicians import (
    fetch_names_us_politicians,
)


def test_fetch_names_us_politicians():
    number_of_persons = 10

    full_names = fetch_names_us_politicians(number_of_persons)

    assert (
        len(full_names) == number_of_persons
    ), f"Expected {number_of_persons}, but got {len(full_names)}"
    assert all(
        isinstance(name, str) for name in full_names
    ), "All names should be strings"
--------------------------------------------------------------------------------
/tests/image_generator/test_google_image_scraper.py:
--------------------------------------------------------------------------------
import os
import pytest

from src.image_generator.google_image_scraper import google_image_scraper

TEST_OUTPUT_DIR = "test_outputs"


@pytest.fixture
def test_output_dir():
    os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)
    return TEST_OUTPUT_DIR


def test_google_image_scraper(test_output_dir):
    search_query = "face images"
    output_path = os.path.join(test_output_dir, search_query)
    google_image_scraper(search_query, output_path)

    downloaded_files = os.listdir(output_path)
    assert len(downloaded_files) > 0
--------------------------------------------------------------------------------
/src/image_generator/image_generator.py:
--------------------------------------------------------------------------------
import os
import tempfile
import pandas as pd

from src.image_generator.google_image_scraper import google_image_scraper
from src.image_generator.image_face_processor import ImageFaceProcessor


def image_generator(category, number_of_images, lists_dir, images_dir):
    list_path = os.path.join(lists_dir, f"{category}.xlsx")
    names = list(pd.read_excel(list_path)["name"])

    # Build the face processor once, and honor the requested image count
    # instead of silently falling back to the processor's default
    processor = ImageFaceProcessor(max_images=number_of_images)

    for name in names:
        with tempfile.TemporaryDirectory() as image_tmp_dir:
            google_image_scraper(name, image_tmp_dir)
            processor.process_images(image_tmp_dir, f"{images_dir}/{category}/{name}")
--------------------------------------------------------------------------------
/src/list_generator/categories/nba.py:
--------------------------------------------------------------------------------
import requests


def fetch_data_nba(number_of_persons):
    url = "https://stats.nba.com/stats/leagueLeaders?LeagueID=00&PerMode=PerGame&Scope=S&Season=2024-25&SeasonType=Regular%20Season&StatCategory=PTS"
    headers = {
        "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Mobile Safari/537.36"
    }

    # Pass the headers: stats.nba.com rejects requests without a browser User-Agent
    response = requests.get(url, headers=headers)
    data = response.json()
    return data["resultSet"]["rowSet"][:number_of_persons]


def extract_full_names(person_list):
    return [person[2] for person in person_list]


def fetch_names_nba(number_of_persons):
    nba_list = fetch_data_nba(number_of_persons)
    full_names = extract_full_names(nba_list)
    return full_names
--------------------------------------------------------------------------------
/src/list_generator/categories/billionaires.py:
--------------------------------------------------------------------------------
import requests


def fetch_data_billionaires(number_of_persons):
    url = "https://www.forbes.com/forbesapi/person/billionaires/2024/position/true.json?fields=uri,finalWorth,age,country,source,rank,category,personName,industries,organization,gender,firstName,lastName,squareImage,bios,status,countryOfCitizenship"
    response = requests.get(url)
    data = response.json()
    return data["personList"]["personsLists"][:number_of_persons]


def extract_full_names(person_list):
    return [
        f"{person['firstName']} {person['lastName']}".replace(" & family", "")
        for person in person_list
    ]


def fetch_names_billionaires(number_of_persons):
    billionaire_list = fetch_data_billionaires(number_of_persons)
    full_names = extract_full_names(billionaire_list)
    return full_names
--------------------------------------------------------------------------------
/src/list_generator/fetch_names.py:
--------------------------------------------------------------------------------
from src.list_generator.categories.billionaires import fetch_names_billionaires
from src.list_generator.categories.kpop import fetch_names_kpop
from src.list_generator.categories.nba import fetch_names_nba
from src.list_generator.categories.us_artists import fetch_names_us_artists
from src.list_generator.categories.us_politicians import fetch_names_us_politicians


def fetch_names(category, number_of_persons):
    if category == "billionaires":
        return fetch_names_billionaires(number_of_persons)
    if category == "kpop":
        return fetch_names_kpop(number_of_persons)
    if category == "nba":
        return fetch_names_nba(number_of_persons)
    if category == "us_artists":
        return fetch_names_us_artists(number_of_persons)
    if category == "us_politicians":
        return fetch_names_us_politicians(number_of_persons)
    # Fail loudly instead of silently returning None for an unknown category
    raise ValueError(f"Unsupported category: {category}")
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 leochen66

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/tests/image_generator/test_image_face_processor.py:
--------------------------------------------------------------------------------
import os
import pytest
from datetime import datetime

from src.image_generator.image_face_processor import ImageFaceProcessor


def test_image_face_processor():
    # Create processor instance
    processor = ImageFaceProcessor(max_images=10)

    # Setup directories
    test_input_dir = "test_outputs"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    test_output_base = f"test_outputs/processed_outputs_{timestamp}"

    # Get all sub-directories
    person_dirs = [
        d
        for d in os.listdir(test_input_dir)
        if os.path.isdir(os.path.join(test_input_dir, d))
    ]

    for person_dir in person_dirs:
        input_dir = os.path.join(test_input_dir, person_dir)
        output_dir = os.path.join(test_output_base, person_dir)

        # Process images
        processor.process_images(input_dir, output_dir)

        # Basic verification
        assert os.path.exists(
            output_dir
        ), f"Output directory for {person_dir} was not created"
        processed_images = [
            f
            for f in os.listdir(output_dir)
            if f.endswith((".jpg", ".jpeg", ".png", ".bmp"))
        ]
        assert len(processed_images) > 0, f"No images were processed for {person_dir}"
--------------------------------------------------------------------------------
/src/list_generator/categories/us_artists.py:
--------------------------------------------------------------------------------
import requests
from bs4 import BeautifulSoup


def fetch_data_us_artists(number_of_persons):
    url = "https://www.billboard.com/charts/artist-100"
    headers = {
        "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Mobile Safari/537.36"
    }

    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()
        html = response.text
        soup = BeautifulSoup(html, "lxml")

        person_sections = soup.select(".o-chart-results-list-row-container")[
            :number_of_persons
        ]
        person_names = [
            person_section.select("#title-of-a-story")[0]
            .text.lstrip()
            .replace("\n", "")
            .replace("\t", "")
            for person_section in person_sections
        ]

        return person_names

    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Error fetching data from the URL: {e}")
    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}")


def fetch_names_us_artists(number_of_persons):
    full_names = fetch_data_us_artists(number_of_persons)
    return full_names
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
import argparse
import os

from src.list_generator.list_generator import list_generator
from src.image_generator.image_generator import image_generator


def main():
    current_dir = os.path.dirname(os.path.abspath(__file__))
    lists_dir = os.path.join(os.path.dirname(current_dir), "lists")
    images_dir = os.path.join(os.path.dirname(current_dir), "datasets")

    parser = argparse.ArgumentParser(
        description="Command line tool for generating lists or images."
    )
    subparsers = parser.add_subparsers(
        dest="command", required=True, help="Subcommands: 'list' or 'image'"
    )

    # Subcommand for 'list'
    list_parser = subparsers.add_parser("list", help="Generate a list")
    list_parser.add_argument(
        "-c", "--category", type=str, required=True, help="Specify the category"
    )
    list_parser.add_argument(
        "-n",
        "--number",
        type=int,
        default=100,
        help="Specify the number of persons (default: 100)",
    )

    # Subcommand for 'image'
    image_parser = subparsers.add_parser("image", help="Generate images")
    image_parser.add_argument(
        "-c", "--category", type=str, required=True, help="Specify the category"
    )
    image_parser.add_argument(
        "-n",
        "--number",
        type=int,
        default=50,
        help="Specify the number of images per person (default: 50)",
    )

    args = parser.parse_args()

    if args.command == "list":
        list_generator(args.category, args.number, lists_dir)
    elif args.command == "image":
        image_generator(args.category, args.number, lists_dir, images_dir)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/src/list_generator/categories/us_politicians.py:
--------------------------------------------------------------------------------
import json
import requests


def fetch_data_us_politicians(number_of_persons):
    try:
        offset = 0
        limit = 20
        person_names = []
        base_url = "https://today.yougov.com/_pubapis/v5/us/search/entity/?group=07df945a-adf0-11e9-9161-317b338eee4b&sort_by=popularity&"
        headers = {
            "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Mobile Safari/537.36"
        }

        while len(person_names) < number_of_persons:
            remaining = number_of_persons - len(person_names)
            current_limit = min(limit, remaining)

            url = f"{base_url}limit={current_limit}&offset={offset}"

            # Pass the headers so the API sees a browser User-Agent
            response = requests.get(url, headers=headers)
            if response.status_code == 200:
                data = response.json()
                data_names = [person["name"] for person in data["data"]]
                if not data_names:
                    # The API has no more results; stop instead of looping forever
                    break
                person_names.extend(data_names)
            else:
                print(f"Failed to fetch data: {response.status_code}")
                break

            offset += current_limit

        return person_names

    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Error fetching data from the URL: {e}")
    except json.JSONDecodeError as e:
        raise ValueError(f"Error decoding JSON data: {e}")
    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}")


def fetch_names_us_politicians(number_of_persons):
    full_names = fetch_data_us_politicians(number_of_persons)
    return full_names
--------------------------------------------------------------------------------
/src/list_generator/categories/kpop.py:
--------------------------------------------------------------------------------
import re
import json
import requests
from bs4 import BeautifulSoup


def fetch_data_kpop(number_of_persons):
    url = "https://www.billboard.com/lists/k-pop-artist-100-list-2024-ranked"
    headers = {
        "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Mobile Safari/537.36"
    }

    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()
        html = response.text
        soup = BeautifulSoup(html, "lxml")

        script_tag = soup.select_one("#pmc-lists-front-js-extra")
        if not script_tag:
            raise ValueError("Script tag with required data not found.")

        script_content = script_tag.text
        match = re.search(r"var pmcGalleryExports = (.*?)};", script_content, re.DOTALL)
        if not match:
            raise ValueError(
                "Required data (pmcGalleryExports) not found in the script tag."
            )

        person_data = match.group(1) + "}"  # Add the missing closing brace
        person_json = json.loads(person_data)
        person_json = person_json.get("gallery", [])

        return person_json[:number_of_persons]

    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"Error fetching data from the URL: {e}")
    except json.JSONDecodeError as e:
        raise ValueError(f"Error decoding JSON data: {e}")
    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}")


def extract_full_names(person_list):
    return [person.get("title") for person in person_list]


def fetch_names_kpop(number_of_persons):
    kpop_list = fetch_data_kpop(number_of_persons)
    full_names = extract_full_names(kpop_list)
    return full_names
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Celebrity Face Dataset

**Celebrity Face Dataset** is an open-source web scraper designed to collect and generate datasets of celebrity faces from various categories. This tool automates the process of gathering names from publicly available sources and retrieving corresponding images from Google Image Search, followed by facial detection and cropping.

## How It Works

The scraper consists of two main components:

### 1. List Generator
This module generates a list of celebrities based on different categories. Currently, the following categories are supported:

| Category | Source URL |
|---------------|------------|
| **Billionaires** | [Forbes Billionaires](https://www.forbes.com/billionaires/) |
| **K-pop Artists** | [Billboard K-pop 100](https://www.billboard.com/lists/k-pop-artist-100-list-2024-ranked/) |
| **NBA Players** | [NBA Stats Leaders](https://www.nba.com/stats/leaders) |
| **US Artists** | [Billboard Artist 100](https://www.billboard.com/charts/artist-100/) |
| **US Politicians** | [YouGov US Politicians](https://today.yougov.com/ratings/politics/popularity/politicians/all) |

### 2. Image Generator
This module retrieves celebrity images using Google Image Search and processes them by:
- Scraping images from search results
- Detecting and evaluating the facial region in each image
- Cropping and saving the images

It uses the list generated by the **List Generator** to find relevant images, as sketched below.
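Both stages can also be driven directly from Python instead of the CLI. A minimal sketch (assuming you run from the repository root so that `src` is importable, as with the test commands below; `lists` and `datasets` are the same default directories that `src/main.py` uses, and the counts 20 and 10 are arbitrary examples):

```python
from src.list_generator.list_generator import list_generator
from src.image_generator.image_generator import image_generator

# Stage 1: fetch the top 20 K-pop artist names and write lists/kpop.xlsx
list_generator("kpop", 20, "lists")

# Stage 2: scrape Google Images for each listed name, keep up to 10
# face crops per person, and save them under datasets/kpop/<name>/
image_generator("kpop", 10, "lists", "datasets")
```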

## Dataset Download
The dataset is available on Kaggle:
[Download Celebrity Face Dataset](https://www.kaggle.com/datasets/leochen66/celebrity-face-dataset)

## Installation
Ensure you have Python installed, then install the required dependencies:
```sh
pip install -r requirements.txt
```

## Usage

### Generate a Celebrity List
Run the following command to generate a celebrity list for a given category:
```sh
cd celebrity-face-dataset
python -m src.main list -c kpop -n 200
```

#### Parameters
| Flag | Description |
|------|------------|
| `-c`, `--category` | Category for generating the list (e.g., `kpop`, `nba`) |
| `-n`, `--number` | Maximum number of names to generate |

### Generate Celebrity Images
Run the following command to scrape and process images:
```sh
python -m src.main image -c kpop -n 50
```

#### Parameters
| Flag | Description |
|------|------------|
| `-c`, `--category` | Category to fetch images for (must match a generated list) |
| `-n`, `--number` | Maximum number of face images to keep per person |

## Running Tests
To execute all tests:
```sh
cd celebrity-face-dataset
PYTHONPATH=./ pytest
```

To run a specific unit test:
```sh
cd celebrity-face-dataset
PYTHONPATH=./ pytest -s tests/image_generator/test_google_image_scraper.py
```
--------------------------------------------------------------------------------
/src/image_generator/google_image_scraper.py:
--------------------------------------------------------------------------------
import os
import requests
import re
from tqdm import tqdm
from urllib.parse import quote, unquote

from src.utils import save_image_to_file


def get_image_extension(img_url, response):
    if img_url.lower().endswith(".png"):
        return ".png"
    elif img_url.lower().endswith(".jpg") or img_url.lower().endswith(".jpeg"):
        return ".jpg"
    else:
        content_type = response.headers.get("content-type", "")
        if "png" in content_type.lower():
            return ".png"
        else:
            return ".jpg"


def google_image_scraper(search_query, output_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
        "Accept-Language": "en-US,en;q=0.5",
    }

    image_pattern = r'https?://[^"\'<>\s]+?(?:\.jpg|\.jpeg|\.png)(?:[^"\'<>\s])*'
    downloaded = 0
    all_image_urls = set()

    # Collect candidate image URLs from the first two result pages
    for page in range(2):
        encoded_query = quote(search_query)
        start_index = page * 20
        url = f"https://www.google.com/search?q={encoded_query}&tbm=isch&start={start_index}"

        try:
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            image_urls = set(re.findall(image_pattern, response.text, re.IGNORECASE))
            image_urls = [unquote(url) for url in image_urls]
            all_image_urls.update(image_urls)
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Failed to fetch image links: {str(e)}") from e

    num_images = len(all_image_urls)
    print(f"Found {num_images} image links for query: {search_query}")

    for img_url in tqdm(
        all_image_urls, desc="Downloading Images", unit="image", total=num_images
    ):
        try:
            if not img_url.startswith("http"):
                continue

            response = requests.get(img_url, headers=headers, timeout=10)
            if response.status_code == 200:
                ext = get_image_extension(img_url, response)
                image_path = os.path.join(output_dir, f"{downloaded+1}{ext}")
                save_image_to_file(response.content, image_path)
                downloaded += 1
        except Exception as e:
            print(f"Failed to download image from {img_url}: {e}")

    print(f"Successfully downloaded {downloaded} images")
    return downloaded
--------------------------------------------------------------------------------
/src/image_generator/image_face_processor.py:
--------------------------------------------------------------------------------
import os
import cv2
import dlib
from tqdm import tqdm


class ImageFaceProcessor:
    def __init__(
        self,
        min_face_size_ratio=0.2,
        sharpness_threshold=100,
        max_images=50,
        max_image_side=400,
        padding_ratio=0.5,
    ):
        self.detector = dlib.get_frontal_face_detector()
        self.min_face_size_ratio = min_face_size_ratio
        self.sharpness_threshold = sharpness_threshold
        self.max_images = max_images
        self.max_image_side = max_image_side
        self.padding_ratio = padding_ratio

    def face_detector(self, image_path):
        # Returns False if the image is rejected, or (True, cropped_face) on success
        try:
            img = cv2.imread(image_path)
            if img is None:
                return False

            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            faces = self.detector(gray, 1)

            # Keep only images that contain exactly one face
            if len(faces) != 1:
                return False

            face = faces[0]
            face_width = face.right() - face.left()
            face_height = face.bottom() - face.top()

            min_face_size = min(img.shape[0], img.shape[1]) * self.min_face_size_ratio
            if face_width < min_face_size or face_height < min_face_size:
                return False

            face_region = gray[face.top() : face.bottom(), face.left() : face.right()]
            if face_region.size == 0:
                return False

            # Reject blurry faces using the variance of the Laplacian
            laplacian_var = cv2.Laplacian(face_region, cv2.CV_64F).var()
            if laplacian_var <= self.sharpness_threshold:
                return False

            # Calculate padded region
            padding_x = int(face_width * self.padding_ratio)
            padding_y = int(face_height * self.padding_ratio)

            left = max(0, face.left() - padding_x)
            right = min(img.shape[1], face.right() + padding_x)
            top = max(0, face.top() - padding_y)
            bottom = min(img.shape[0], face.bottom() + padding_y)

            face_image = img[top:bottom, left:right]
            return True, face_image

        except Exception as e:
            raise RuntimeError(f"Face detection error for {image_path}: {str(e)}")

    def resize(self, image):
        # Downscale so the longer side is at most max_image_side; never upscale
        h, w = image.shape[:2]
        scale = self.max_image_side / max(h, w)

        if scale >= 1:
            return image

        new_w = int(w * scale)
        new_h = int(h * scale)
        return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)

    def process_images(self, input_dir, output_dir):
        os.makedirs(output_dir, exist_ok=True)
        valid_face_images = []

        image_files = [
            f
            for f in os.listdir(input_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp"))
        ]

        # Process each image
        for filename in tqdm(image_files, desc="Processing Images"):
            image_path = os.path.join(input_dir, filename)

            result = self.face_detector(image_path)
            if result:
                _, face_image = result
                valid_face_images.append(face_image)

        print(f"Total valid face images found: {len(valid_face_images)}")

        # Sort images by sharpness
        valid_face_images.sort(
            key=lambda x: cv2.Laplacian(
                cv2.cvtColor(x, cv2.COLOR_BGR2GRAY), cv2.CV_64F
            ).var(),
            reverse=True,
        )

        # Select top images and resize
        for i, face_image in enumerate(valid_face_images[: self.max_images]):
            resized_image = self.resize(face_image)

            output_path = os.path.join(output_dir, f"{i+1}.jpg")
            cv2.imwrite(output_path, resized_image)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc

datasets/*
lists/*
test_outputs/*
--------------------------------------------------------------------------------