├── scripts
│   ├── phishintention
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── awl_detector.py
│   │   │   ├── models2.py
│   │   │   └── logo_matching.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── web_utils.py
│   │   │   └── utils.py
│   │   ├── ocr_lib
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── stn_head.py
│   │   │   │   ├── resnet_aster.py
│   │   │   │   ├── tps_spatial_transformer.py
│   │   │   │   ├── model_builder.py
│   │   │   │   └── attention_recognition_head.py
│   │   │   ├── utils
│   │   │   │   ├── meters.py
│   │   │   │   ├── osutils.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── labelmaps.py
│   │   │   │   ├── serialization.py
│   │   │   │   ├── logging.py
│   │   │   │   └── visualization_utils.py
│   │   │   └── loss
│   │   │       └── sequenceCrossEntropyLoss.py
│   │   ├── configs
│   │   │   ├── configs.yaml
│   │   │   ├── faster_rcnn_web.yaml
│   │   │   └── faster_rcnn_login_lr0.001_finetune.yaml
│   │   ├── configs.py
│   │   └── setup.sh
│   ├── utils
│   │   ├── draw_utils.py
│   │   ├── logger_utils.py
│   │   ├── PhishIntentionWrapper.py
│   │   ├── utils.py
│   │   └── web_utils.py
│   └── infer
│       └── test.py
├── datasets
│   ├── test_sites
│   │   └── www.baidu.com
│   │       ├── info.txt
│   │       └── shot.png
│   └── hosting_blacklists.txt
├── figures
│   ├── phishllm.pdf
│   └── phishllm.png
├── requirements.txt
├── param_dict.yaml
├── .gitignore
├── README.md
└── prompts
    └── crp_trans_prompt.json

/scripts/phishintention/modules/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/scripts/phishintention/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/datasets/test_sites/www.baidu.com/info.txt:
--------------------------------------------------------------------------------
https://www.baidu.com/
--------------------------------------------------------------------------------
/figures/phishllm.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/code-philia/PhishVLM/HEAD/figures/phishllm.pdf
--------------------------------------------------------------------------------
/figures/phishllm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/code-philia/PhishVLM/HEAD/figures/phishllm.png
--------------------------------------------------------------------------------
/datasets/test_sites/www.baidu.com/shot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/code-philia/PhishVLM/HEAD/datasets/test_sites/www.baidu.com/shot.png
--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/utils/meters.py:
--------------------------------------------------------------------------------


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
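AverageMeter keeps a running sum so the mean never has to be recomputed from a stored history; update(val, n) folds in a value observed over n samples. A minimal usage sketch (the loss values are made up, and it assumes the repo root is on PYTHONPATH):

    from scripts.phishintention.ocr_lib.utils.meters import AverageMeter

    loss_meter = AverageMeter()
    for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
        loss_meter.update(batch_loss, n=batch_size)  # weight each batch by its size
    print(loss_meter.avg)   # running mean over all 80 samples
    loss_meter.reset()      # start fresh, e.g. at a new epoch
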
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.23.0
scikit-learn
scipy
pandas
matplotlib
nltk
spacy
tqdm
unidecode
fvcore
lxml
psutil
Pillow==8.4.0
requests
beautifulsoup4
tldextract
gdown
selenium
selenium-wire
helium
webdriver-manager
flask
flask-cors
google-cloud-vision
googletrans
editdistance
cryptography==38.0.4
httpcore==0.15.0
h11
h2
blinker==1.7.0
hyperframe
pycocotools
opencv-python
opencv-contrib-python
openai
--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/utils/osutils.py:
--------------------------------------------------------------------------------
import os
import errno


def mkdir_if_missing(dir_path):
    try:
        os.makedirs(dir_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def make_symlink_if_not_exists(real_path, link_path):
    '''
    :param real_path: str, the target path being linked to
    :param link_path: str, the symlink to create
    '''
    try:
        os.makedirs(real_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    # os.symlink is portable and avoids shelling out to `ln -s`;
    # honour the function name by skipping links that already exist
    if not os.path.lexists(link_path):
        os.symlink(real_path, link_path)
--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/utils/__init__.py:
--------------------------------------------------------------------------------

import torch


def to_numpy(tensor):
    if torch.is_tensor(tensor):
        return tensor.cpu().numpy()
    elif type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor


def to_torch(ndarray):
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray
--------------------------------------------------------------------------------
/scripts/phishintention/configs/configs.yaml:
--------------------------------------------------------------------------------
AWL_MODEL: # webpage element recognition model
  CFG_PATH: configs/faster_rcnn_web.yaml
  WEIGHTS_PATH: models/layout_detector.pth
  DETECT_THRE: 0.3

CRP_CLASSIFIER:
  WEIGHTS_PATH: models/crp_classifier.pth.tar
  MODEL_TYPE: 'mixed'

CRP_LOCATOR: # credential-requiring-page (login element) locator model
  CFG_PATH: configs/faster_rcnn_login_lr0.001_finetune.yaml
  WEIGHTS_PATH: models/crp_locator.pth
  DETECT_THRE: 0.05

SIAMESE_MODEL:
  NUM_CLASSES: 277 # number of brands; users don't need to modify this even if the target list is expanded
  WEIGHTS_PATH: models/ocr_siamese.pth.tar
  OCR_WEIGHTS_PATH: models/ocr_pretrained.pth.tar
  TARGETLIST_PATH: models/expand_targetlist.zip
  MATCH_THRE: 0.87 # FIXME: threshold is 0.87 in phish-discovery?
  DOMAIN_MAP_PATH: models/domain_map.pkl
--------------------------------------------------------------------------------
/scripts/utils/draw_utils.py:
--------------------------------------------------------------------------------
from PIL import Image
import numpy as np
import cv2
from numpy.typing import ArrayLike


def draw_annotated_image_box(
        image: Image.Image,
        predicted_domain: str,
        box: ArrayLike
) -> Image.Image:
    image = image.convert('RGB')
    # PIL gives RGB; flip the channel axis to BGR for OpenCV drawing
    screenshot_img_arr = np.asarray(image)
    screenshot_img_arr = np.flip(screenshot_img_arr, -1)
    screenshot_img_arr = screenshot_img_arr.astype(np.uint8)

    if box is not None:
        cv2.rectangle(screenshot_img_arr, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (69, 139, 0), 2)
        cv2.putText(screenshot_img_arr, 'Predicted phishing target: ' + predicted_domain, (int(box[0]), int(box[3])),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (0, 0, 255), 2)
    else:
        cv2.putText(screenshot_img_arr, 'Predicted phishing target: ' + predicted_domain, (10, 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (0, 0, 255), 2)
    # flip back to RGB before handing the array to PIL
    screenshot_img_arr = np.flip(screenshot_img_arr, -1)
    image = Image.fromarray(screenshot_img_arr)
    return image
--------------------------------------------------------------------------------
/param_dict.yaml:
--------------------------------------------------------------------------------
VLM_model: "gpt-4o-mini-2024-07-18" # OpenAI vision-language model backbone

brand_recog: # brand recognition model
  temperature: 0 # deterministic response
  max_tokens: 10 # limit the maximum number of generated tokens
  sleep_time: 0.5 # seconds to wait between API calls
  prompt_path: "./prompts/brand_recog_prompt.json" # path to the prompt

brand_valid:
  activate: True # whether to activate the brand validation
  k: 10 # look at the top-10 Google image results to check whether the webpage logo is similar to any one of them
  siamese_thre: 0.7 # similarity threshold for that logo comparison

crp_pred: # CRP prediction model
  temperature: 0 # deterministic response
  max_tokens: 200 # limit the maximum number of generated tokens
  sleep_time: 0.5
  prompt_path: "./prompts/crp_pred_prompt.json" # path to the prompt

rank: # CRP transition model
  temperature: 0 # deterministic response
  max_tokens: 10 # limit the maximum number of generated tokens
  max_uis_process: 30 # only look at the first k UI elements because the login UI is likely to be located near the top of the screenshot
  depth_limit: 1
  driver_sleep_time: 3
  script_timeout: 30
  page_load_timeout: 30
  prompt_path: "./prompts/crp_trans_prompt.json" # path to the prompt
--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/utils/labelmaps.py:
--------------------------------------------------------------------------------

import string

from . import to_numpy

def get_vocabulary(voc_type, EOS='EOS', PADDING='PADDING', UNKNOWN='UNKNOWN'):
    '''
    voc_type: str: one of 'LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS'
    '''
    voc = None
    types = ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS']
    if voc_type == 'LOWERCASE':
        voc = list(string.digits + string.ascii_lowercase)
    elif voc_type == 'ALLCASES':
        voc = list(string.digits + string.ascii_letters)
    elif voc_type == 'ALLCASES_SYMBOLS':
        voc = list(string.printable[:-6])
    else:
        raise KeyError('voc_type must be one of "LOWERCASE", "ALLCASES", "ALLCASES_SYMBOLS"')

    # append the special tokens to the vocabulary
    voc.append(EOS)
    voc.append(PADDING)
    voc.append(UNKNOWN)

    return voc

## param voc: the list of vocabulary
def char2id(voc):
    return dict(zip(voc, range(len(voc))))

def id2char(voc):
    return dict(zip(range(len(voc)), voc))

def labels2strs(labels, id2char, char2id):
    # labels: batch_size x len_seq
    if labels.ndimension() == 1:
        labels = labels.unsqueeze(0)
    assert labels.dim() == 2
    labels = to_numpy(labels)
    strings = []
    batch_size = labels.shape[0]

    for i in range(batch_size):
        label = labels[i]
        chars = []  # renamed from `string` to avoid shadowing the stdlib module
        for l in label:
            if l == char2id['EOS']:
                break
            else:
                chars.append(id2char[l])
        strings.append(''.join(chars))

    return strings
--------------------------------------------------------------------------------
/scripts/phishintention/configs.py:
--------------------------------------------------------------------------------
# Global configuration
import yaml
from .modules.awl_detector import config_rcnn
from .modules.logo_matching import siamese_model_config, ocr_model_config
import os

def get_absolute_path(relative_path):
    base_path = os.path.dirname(__file__)
    return os.path.abspath(os.path.join(base_path, relative_path))

def load_config(reload_targetlist=False):

    with open(os.path.join(os.path.dirname(__file__), 'configs/configs.yaml')) as file:
        configs = yaml.load(file, Loader=yaml.FullLoader)

    # Iterate through the configuration and update paths
    for section, settings in configs.items():
        for key, value in settings.items():
            if 'PATH' in key and isinstance(value, str):  # Check if the key indicates a path
                absolute_path = get_absolute_path(value)
                configs[section][key] = absolute_path

    AWL_MODEL = config_rcnn(
        cfg_path=configs['AWL_MODEL']['CFG_PATH'],
        weights_path=configs['AWL_MODEL']['WEIGHTS_PATH'],
        conf_threshold=configs['AWL_MODEL']['DETECT_THRE']
    )

    # siamese model
    SIAMESE_THRE = configs['SIAMESE_MODEL']['MATCH_THRE']

    SIAMESE_MODEL = siamese_model_config(
        num_classes=configs['SIAMESE_MODEL']['NUM_CLASSES'],
        weights_path=configs['SIAMESE_MODEL']['WEIGHTS_PATH']
    )

    OCR_MODEL = ocr_model_config(weights_path=configs['SIAMESE_MODEL']['OCR_WEIGHTS_PATH'])

    return AWL_MODEL, SIAMESE_MODEL, OCR_MODEL, SIAMESE_THRE
--------------------------------------------------------------------------------
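load_config wires the YAML into instantiated models in one call; scripts/infer/test.py consumes it exactly like this (sketch assumes detectron2 and the weights downloaded by setup.sh are in place):

    from scripts.phishintention.configs import load_config
    from scripts.utils.PhishIntentionWrapper import LayoutDetector, LogoDetector, LogoEncoder

    AWL_MODEL, SIAMESE_MODEL, OCR_MODEL, SIAMESE_THRE = load_config()
    logo_extractor = LogoDetector(AWL_MODEL)                            # detects logo boxes
    logo_encoder = LogoEncoder(SIAMESE_MODEL, OCR_MODEL, SIAMESE_THRE)  # embeds logo crops
    layout_extractor = LayoutDetector(AWL_MODEL)                        # full element layout
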
/scripts/phishintention/setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash

set -euo pipefail  # Safer bash behavior
IFS=$'\n\t'

# Set up model directory
FILEDIR="$(pwd)"
MODELS_DIR="$FILEDIR/models"
mkdir -p "$MODELS_DIR"
cd "$MODELS_DIR"

# Layout detector (RCNN) weights
if [ -f "layout_detector.pth" ]; then
    echo "layout_detector weights exist... skip"
else
    gdown --id "1HWjE5Fv-c3nCDzLCBc7I3vClP1IeuP_I" -O "layout_detector.pth"
fi

# CRP classifier weights
if [ -f "crp_classifier.pth.tar" ]; then
    echo "CRP classifier weights exist... skip"
else
    gdown --id "1igEMRz0vFBonxAILeYMRWTyd7A9sRirO" -O "crp_classifier.pth.tar"
fi

# CRP locator weights
if [ -f "crp_locator.pth" ]; then
    echo "crp_locator weights exist... skip"
else
    gdown --id "1_O5SALqaJqvWoZDrdIVpsZyCnmSkzQcm" -O "crp_locator.pth"
fi

# OCR model pretrained weights
if [ -f "ocr_pretrained.pth.tar" ]; then
    echo "OCR pretrained model weights exist... skip"
else
    gdown --id "15pfVWnZR-at46gqxd50cWhrXemP8oaxp" -O "ocr_pretrained.pth.tar"
fi

# OCR-siamese finetuned weights
if [ -f "ocr_siamese.pth.tar" ]; then
    echo "OCR-siamese weights exist... skip"
else
    gdown --id "1BxJf5lAcNEnnC0In55flWZ89xwlYkzPk" -O "ocr_siamese.pth.tar"
fi

# Reference list
if [ -f "expand_targetlist.zip" ]; then
    echo "Reference list exists... skip"
else
    gdown --id "1fr5ZxBKyDiNZ_1B6rRAfZbAHBBoUjZ7I" -O "expand_targetlist.zip"
fi

# Domain map
if [ -f "domain_map.pkl" ]; then
    echo "Domain map exists... skip"
else
    gdown --id "1qSdkSSoCYUkZMKs44Rup_1DPBxHnEKl1" -O "domain_map.pkl"
fi

# Extract and flatten expand_targetlist
echo "Extracting expand_targetlist.zip..."
unzip -o expand_targetlist.zip -d expand_targetlist
# (the original called an undefined error_exit helper here)
cd expand_targetlist || { echo "Extraction directory missing." >&2; exit 1; }

if [ -d "expand_targetlist" ]; then
    echo "Flattening nested expand_targetlist/ directory..."
    mv expand_targetlist/* .
    rm -r expand_targetlist
fi

echo "Extraction completed successfully."
echo "All packages installed successfully!"
--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/loss/sequenceCrossEntropyLoss.py:
--------------------------------------------------------------------------------

import torch
from torch import nn
import torch.nn.functional as F

def to_contiguous(tensor):
    if tensor.is_contiguous():
        return tensor
    else:
        return tensor.contiguous()

def _assert_no_grad(variable):
    assert not variable.requires_grad, \
        "nn criterions don't compute the gradient w.r.t. targets - please " \
        "mark these variables as not requiring gradients"

class SequenceCrossEntropyLoss(nn.Module):
    def __init__(self,
                 weight=None,
                 size_average=True,
                 ignore_index=-100,
                 sequence_normalize=False,
                 sample_normalize=True):
        super(SequenceCrossEntropyLoss, self).__init__()
        self.weight = weight
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.sequence_normalize = sequence_normalize
        self.sample_normalize = sample_normalize

        # at most one normalization mode may be enabled
        assert not (sequence_normalize and sample_normalize)

    def forward(self, input, target, length):
        _assert_no_grad(target)
        # length to mask
        batch_size, def_max_length = target.size(0), target.size(1)
        mask = torch.zeros(batch_size, def_max_length)
        for i in range(batch_size):
            mask[i, :length[i]].fill_(1)
        mask = mask.type_as(input)
        # truncate to the same size
        max_length = max(length)
        assert max_length == input.size(1)
        target = target[:, :max_length]
        mask = mask[:, :max_length]
        input = to_contiguous(input).view(-1, input.size(2))
        input = F.log_softmax(input, dim=1)
        target = to_contiguous(target).view(-1, 1)
        mask = to_contiguous(mask).view(-1, 1)
        output = - input.gather(1, target.long()) * mask
        # if self.size_average:
        #     output = torch.sum(output) / torch.sum(mask)
        # elif self.reduce:
        #     output = torch.sum(output)
        ##
        output = torch.sum(output)
        if self.sequence_normalize:
            output = output / torch.sum(mask)
        if self.sample_normalize:
            output = output / batch_size

        return output
--------------------------------------------------------------------------------
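The criterion expects per-step logits, integer targets, and per-sample lengths; the mask zeroes out loss beyond each sequence's length, and max(length) must equal the logit sequence length. A minimal self-contained sketch with dummy tensors (vocabulary size 38 is arbitrary here):

    import torch
    # SequenceCrossEntropyLoss as defined in the module above
    criterion = SequenceCrossEntropyLoss()
    logits = torch.randn(2, 5, 38)              # batch=2, seq_len=5, vocab=38
    targets = torch.randint(0, 38, (2, 5))      # character ids, no gradients
    lengths = [5, 3]                            # real lengths; max must equal seq_len
    loss = criterion(logits, targets, lengths)  # scalar, averaged over the batch
    print(loss.item())
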
/scripts/phishintention/configs/faster_rcnn_web.yaml:
--------------------------------------------------------------------------------

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"
    FREEZE_AT: 2 # Default 2
  RESNETS:
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
    DEPTH: 50 # ResNet50
  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
  ANCHOR_GENERATOR:
    SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
    PRE_NMS_TOPK_TEST: 1000 # Per FPN level
    # Detectron1 uses 2000 proposals per-batch,
    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]
    NUM_CLASSES: 5 # Change to suit own task
    BATCH_SIZE_PER_IMAGE: 512
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
  # COCO ResNet50 weights
  WEIGHTS: "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"
  MASK_ON: False # Not doing segmentation

DATASETS:
  TRAIN: ("web_train",)
  TEST: ("web_test",)
DATALOADER:
  NUM_WORKERS: 0
SOLVER:
  IMS_PER_BATCH: 8 # Batch size; Default 16
  BASE_LR: 0.00001
  # (2/3, 8/9)
  STEPS: (16341, 21788) # The iteration numbers at which to decrease the learning rate by GAMMA.
  MAX_ITER: 24512 # Number of training iterations
  CHECKPOINT_PERIOD: 4000 # Save a checkpoint every this many steps
INPUT:
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) # Image input sizes
TEST:
  # The period (in terms of steps) to evaluate the model during training.
  # Set to 0 to disable.
  EVAL_PERIOD: 1000
OUTPUT_DIR: "./output/website" # Specify output directory
--------------------------------------------------------------------------------
/scripts/phishintention/configs/faster_rcnn_login_lr0.001_finetune.yaml:
--------------------------------------------------------------------------------


MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"
    FREEZE_AT: 2 # Default 2
  RESNETS:
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
    DEPTH: 50 # ResNet50
  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
  ANCHOR_GENERATOR:
    SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
    PRE_NMS_TOPK_TEST: 1000 # Per FPN level
    # Detectron1 uses 2000 proposals per-batch,
    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]
    NUM_CLASSES: 1 # Change to suit own task
    # Can reduce this for lower memory/faster training; Default 512
    BATCH_SIZE_PER_IMAGE: 512
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
  # COCO ResNet50 weights
  WEIGHTS: "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"
  MASK_ON: False # Not doing segmentation

DATASETS:
  TRAIN: ("login_train",)
  TEST: ("login_test",)
DATALOADER:
  NUM_WORKERS: 0
SOLVER:
  IMS_PER_BATCH: 8 # Batch size; Default 16
  BASE_LR: 0.001
  # (2/3, 8/9)
  STEPS: (12000, 16000) # The iteration numbers at which to decrease the learning rate by GAMMA.
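  # With MAX_ITER at 18000 (below), these decay points land at 12000/18000 = 2/3
  # and 16000/18000 = 8/9 of training, which is what the (2/3, 8/9) note above refers to.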
  MAX_ITER: 18000 # Number of training iterations
  CHECKPOINT_PERIOD: 4000 # Save a checkpoint every this many steps
INPUT:
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) # Image input sizes
TEST:
  # The period (in terms of steps) to evaluate the model during training.
  # Set to 0 to disable.
  EVAL_PERIOD: 2000

--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/utils/serialization.py:
--------------------------------------------------------------------------------
import json
import os
import sys
import os.path as osp
import shutil

import torch
from torch.nn import Parameter

from .osutils import mkdir_if_missing

# from config import get_args
# global_args = get_args(sys.argv[1:])

# if global_args.run_on_remote:
#     import moxing as mox

# Fallback so the module stays importable while the remote (moxing) config
# above is disabled: behave as a local run.
class _LocalArgs:
    run_on_remote = False
global_args = _LocalArgs()


def read_json(fpath):
    with open(fpath, 'r') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    print('=> saving checkpoint ', fpath)
    if global_args.run_on_remote:
        dir_name = osp.dirname(fpath)
        if not mox.file.exists(dir_name):
            mox.file.make_dirs(dir_name)
            print('=> making dir ', dir_name)
        local_path = "local_checkpoint.pth.tar"
        torch.save(state, local_path)
        mox.file.copy(local_path, fpath)
        if is_best:
            mox.file.copy(local_path, osp.join(dir_name, 'model_best.pth.tar'))
    else:
        mkdir_if_missing(osp.dirname(fpath))
        torch.save(state, fpath)
        if is_best:
            shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))


def load_checkpoint(fpath):
    if global_args.run_on_remote:
        mox.file.shift('os', 'mox')
        checkpoint = torch.load(fpath)
        print("=> Loaded checkpoint '{}'".format(fpath))
        return checkpoint
    else:
        load_path = fpath

        if osp.isfile(load_path):
            checkpoint = torch.load(load_path)
            print("=> Loaded checkpoint '{}'".format(load_path))
            return checkpoint
        else:
            raise ValueError("=> No checkpoint found at '{}'".format(load_path))


def copy_state_dict(state_dict, model, strip=None):
    tgt_state = model.state_dict()
    copied_names = set()
    for name, param in state_dict.items():
        if strip is not None and name.startswith(strip):
            name = name[len(strip):]
        if name not in tgt_state:
            continue
        if isinstance(param, Parameter):
            param = param.data
        if param.size() != tgt_state[name].size():
            print('mismatch:', name, param.size(), tgt_state[name].size())
            continue
        tgt_state[name].copy_(param)
        copied_names.add(name)

    missing = set(tgt_state.keys()) - copied_names
    if len(missing) > 0:
        print("missing keys in state_dict:", missing)

    return model
--------------------------------------------------------------------------------
/scripts/utils/logger_utils.py:
--------------------------------------------------------------------------------
import logging
import re

class TxtColors:
    OK = '\033[92m'
    DEBUG = '\033[94m'
    WARNING = "\033[93m"
    FATAL = '\033[91m'
    EXCEPTION = '\033[100m'
    ENDC = '\033[0m'

'''Logging Utils'''
class PhishLLMLogger():
    _caller_prefix = "PhishLLMLogger"
    _verbose = True
    _logfile = None
    _debug = False # Off by default
    _warning = True

    @classmethod
    def set_verbose(cls, verbose):
        cls._verbose = verbose

    @classmethod
    def set_logfile(cls, logfile):
        PhishLLMLogger._logfile = logfile

    @classmethod
    def unset_logfile(cls):
        PhishLLMLogger.set_logfile(None)

    @classmethod
    def set_debug_on(cls):
        PhishLLMLogger._debug = True

    @classmethod
    def set_debug_off(cls):  # Call if debug messages need to be turned off
        PhishLLMLogger._debug = False

    @classmethod
    def set_warning_on(cls):
        PhishLLMLogger._warning = True

    @classmethod
    def set_warning_off(cls):  # Call if warnings need to be turned off
        PhishLLMLogger._warning = False

    @classmethod
    def spit(cls, msg, warning=False, debug=False, error=False, exception=False, caller_prefix=""):
        logging.basicConfig(level=logging.DEBUG if PhishLLMLogger._debug else logging.WARNING)
        caller_prefix = f"[{caller_prefix}]" if caller_prefix else ""
        prefix = "[FATAL]" if error else "[DEBUG]" if debug else "[WARNING]" if warning else "[EXCEPTION]" if exception else ""
        logger = logging.getLogger("custom_logger")  # Choose an appropriate logger name
        if PhishLLMLogger._logfile:
            log_msg = re.sub(r"\033\[\d+m", "", msg)  # strip ANSI color codes before writing to file
            log_handler = logging.FileHandler(PhishLLMLogger._logfile, mode='a')
            log_formatter = logging.Formatter('%(message)s')
            log_handler.setFormatter(log_formatter)
            logger.addHandler(log_handler)
            logger.propagate = False
            logger.setLevel(logging.DEBUG if PhishLLMLogger._debug else logging.WARNING)
            logger.debug("%s%s %s" % (caller_prefix, prefix, log_msg))
            logger.removeHandler(log_handler)
        else:
            if PhishLLMLogger._verbose:
                # fixed: the exception branch previously assigned the literal string
                # "[EXCEPTION]" instead of the EXCEPTION color code
                txtcolor = TxtColors.FATAL if error else TxtColors.DEBUG if debug else TxtColors.WARNING if warning else TxtColors.EXCEPTION if exception else TxtColors.OK
                # if not debug or Logger._debug:
                if (not debug and not warning) or (debug and PhishLLMLogger._debug) or (warning and PhishLLMLogger._warning):
                    print("%s%s%s %s" % (txtcolor, caller_prefix, prefix, msg))
--------------------------------------------------------------------------------
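A quick sketch of how the logger is driven (mirrors scripts/infer/test.py; console output is colorized unless a logfile is set, in which case messages are appended there instead):

    from scripts.utils.logger_utils import PhishLLMLogger

    PhishLLMLogger.set_debug_on()
    PhishLLMLogger.set_verbose(True)
    PhishLLMLogger.spit("brand recognized", debug=True, caller_prefix="BrandRecog")
    PhishLLMLogger.set_logfile("run.log")  # subsequent messages go to run.log
    PhishLLMLogger.spit("driver restarted", warning=True)
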
/scripts/utils/PhishIntentionWrapper.py:
--------------------------------------------------------------------------------
from ..phishintention.modules.awl_detector import pred_rcnn, find_element_type
from ..phishintention.modules.logo_matching import ocr_main, l2_norm
import torch
import torch.nn as nn
from PIL import Image, ImageOps
from torchvision import transforms
from typing import Tuple, Union
from numpy.typing import NDArray

class LayoutDetector(nn.Module):
    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def forward(self, screenshot_path: str) -> Tuple[NDArray, NDArray]:
        # Run detection with the RCNN predictor
        pred_boxes, pred_classes, _ = pred_rcnn(
            im=screenshot_path,
            predictor=self.predictor
        )
        pred_boxes = pred_boxes.numpy()
        pred_classes = pred_classes.numpy()
        return pred_boxes, pred_classes

class LogoDetector(nn.Module):
    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def forward(self, screenshot_path: str) -> NDArray:
        # Run detection with the RCNN predictor
        pred_boxes, pred_classes, _ = pred_rcnn(
            im=screenshot_path,
            predictor=self.predictor
        )

        # Filter to the "logo" class
        logo_pred_boxes, _ = find_element_type(
            pred_boxes, pred_classes, bbox_type="logo"
        )
        logo_pred_boxes = logo_pred_boxes.numpy()
        return logo_pred_boxes

class LogoEncoder(nn.Module):
    def __init__(self, siamese_model, ocr_model, matching_threshold, img_size: int = 224):
        super().__init__()
        self.siamese_model = siamese_model
        self.ocr_model = ocr_model
        self.img_size = img_size
        self.matching_threshold = matching_threshold
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Transformation pipeline
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def preprocess_image(self, img: Union[str, Image.Image]) -> Image.Image:
        img = Image.open(img) if isinstance(img, str) else img
        img = img.convert("RGBA").convert("RGB")
        # Pad to square with white borders
        pad_color = (255, 255, 255)
        img = ImageOps.expand(
            img,
            (
                (max(img.size) - img.size[0]) // 2,
                (max(img.size) - img.size[1]) // 2,
                (max(img.size) - img.size[0]) // 2,
                (max(img.size) - img.size[1]) // 2,
            ),
            fill=pad_color,
        )
        # Resize
        img = img.resize((self.img_size, self.img_size))
        return img

    def forward(self, img: Image.Image) -> NDArray:
        img = self.preprocess_image(img)

        ocr_emb = ocr_main(image_path=img, model=self.ocr_model, height=None, width=None)[0]
        ocr_emb = ocr_emb[None, ...].to(self.device)
        img_tensor = self.img_transforms(img)[None, ...].to(self.device)
        logo_feat = self.siamese_model.features(img_tensor, ocr_emb)

        # L2 normalize
        logo_feat = l2_norm(logo_feat).squeeze(0).detach().cpu().numpy()
        return logo_feat
--------------------------------------------------------------------------------
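End to end, the wrappers chain like this on one screenshot (a sketch reusing the logo_extractor/logo_encoder instances from the load_config sketch above; the box indexing assumes (x1, y1, x2, y2) as documented in awl_detector.py, and the screenshot must be readable):

    from PIL import Image

    shot = "datasets/test_sites/www.baidu.com/shot.png"
    boxes = logo_extractor(shot)                 # Nx4 array of logo boxes
    if len(boxes) > 0:
        x1, y1, x2, y2 = boxes[0]
        logo = Image.open(shot).crop((x1, y1, x2, y2))
        embedding = logo_encoder(logo)           # 1-D L2-normalized feature vector
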
/scripts/phishintention/ocr_lib/models/stn_head.py:
--------------------------------------------------------------------------------

import math
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F


def conv3x3_block(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    # fixed: the stride argument was previously ignored (hardcoded to 1)
    conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)

    block = nn.Sequential(
        conv_layer,
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
    return block


class STNHead(nn.Module):
    def __init__(self, in_planes, num_ctrlpoints, activation='none'):
        super(STNHead, self).__init__()

        self.in_planes = in_planes
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation
        self.stn_convnet = nn.Sequential(
            conv3x3_block(in_planes, 32),  # 32*64
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(32, 64),  # 16*32
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(64, 128),  # 8*16
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(128, 256),  # 4*8
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256),  # 2*4
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv3x3_block(256, 256))  # 1*2

        self.stn_fc1 = nn.Sequential(
            nn.Linear(2*256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True))
        self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2)

        self.init_weights(self.stn_convnet)
        self.init_weights(self.stn_fc1)
        self.init_stn(self.stn_fc2)

    def init_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()

    def init_stn(self, stn_fc2):
        # initialize the control points as two horizontal rows just inside the image border
        margin = 0.01
        sampling_num_per_side = int(self.num_ctrlpoints / 2)
        ctrl_pts_x = np.linspace(margin, 1.-margin, sampling_num_per_side)
        ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
        ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1-margin)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
        if self.activation == 'none':  # fixed: was `is 'none'`, an identity comparison with a literal
            pass
        elif self.activation == 'sigmoid':
            ctrl_points = -np.log(1. / ctrl_points - 1.)
        stn_fc2.weight.data.zero_()
        stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)

    def forward(self, x):
        x = self.stn_convnet(x)
        batch_size, _, h, w = x.size()
        x = x.view(batch_size, -1)
        img_feat = self.stn_fc1(x)
        x = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            x = torch.sigmoid(x)
        x = x.view(-1, self.num_ctrlpoints, 2)
        return img_feat, x


if __name__ == "__main__":
    in_planes = 3
    num_ctrlpoints = 20
    activation = 'none'  # 'sigmoid'
    stn_head = STNHead(in_planes, num_ctrlpoints, activation)
    input = torch.randn(10, 3, 32, 64)
    # forward returns a (image features, control points) tuple
    img_feat, control_points = stn_head(input)
    print(control_points.size())
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.zip
*.pkl
*.pt
./datasets/*
./scripts/phishintention/models/*

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/scripts/phishintention/modules/awl_detector.py:
--------------------------------------------------------------------------------
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
import cv2
import numpy as np
import torch


def cv_imread(filePath):
    '''
    cv2.imread fails when the image path contains non-English characters;
    decoding from a byte buffer avoids that.
    :param filePath:
    :return:
    '''
    cv_img = cv2.imdecode(np.fromfile(filePath, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
    return cv_img


def config_rcnn(cfg_path, weights_path, conf_threshold):
    '''
    Configure weights and confidence threshold
    :param cfg_path:
    :param weights_path:
    :param conf_threshold:
    :return:
    '''
    cfg = get_cfg()
    cfg.merge_from_file(cfg_path)
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = conf_threshold
    # fall back to CPU when CUDA is unavailable (e.g. a CPU-only detectron2 install)
    if not torch.cuda.is_available():
        cfg.MODEL.DEVICE = 'cpu'

    # Initialize model
    predictor = DefaultPredictor(cfg)
    return predictor

def pred_rcnn(im, predictor):
    '''
    Perform inference for RCNN
    :param im:
    :param predictor:
    :return:
    '''
    im = cv2.imread(im)

    if im is not None:
        if im.shape[-1] == 4:
            im = cv2.cvtColor(im, cv2.COLOR_BGRA2BGR)
    else:
        return None, None, None

    outputs = predictor(im)

    instances = outputs['instances']
    pred_classes = instances.pred_classes.detach().cpu()  # tensor
    pred_boxes = instances.pred_boxes.tensor.detach().cpu()  # Boxes object
    pred_scores = instances.scores  # tensor

    return pred_boxes, pred_classes, pred_scores


def find_element_type(pred_boxes, pred_classes, bbox_type='button'):
    '''
    Filter bboxes by type
    :param pred_boxes: torch.Tensor of shape Nx4, bounding box coordinates in (x1, y1, x2, y2)
    :param pred_classes: torch.Tensor of shape Nx1, 0 for logo, 1 for input, 2 for button, 3 for label (text near input), 4 for block
    :param bbox_type: the type of box we want to find
    :return pred_boxes_after: pred_boxes after filtering
    :return pred_classes_after: pred_classes after filtering
    '''
    # global dict
    class_dict = {0: 'logo', 1: 'input', 2: 'button', 3: 'label', 4: 'block'}
    inv_class_dict = {v: k for k, v in class_dict.items()}
    assert bbox_type in ['logo', 'input', 'button', 'label', 'block']
    pred_boxes_after = pred_boxes[pred_classes == inv_class_dict[bbox_type]]
    pred_classes_after = pred_classes[pred_classes == inv_class_dict[bbox_type]]
    return pred_boxes_after, pred_classes_after

def vis(img_path, pred_boxes, pred_classes):
    '''
    Visualize rcnn predictions
    :param img_path: str
    :param pred_boxes: torch.Tensor of shape Nx4, bounding box coordinates in (x1, y1, x2, y2)
    :param pred_classes: torch.Tensor of shape Nx1, 0 for logo, 1 for input, 2 for button, 3 for label (text near input), 4 for block
    :return None
    '''
    class_dict = {0: 'logo', 1: 'input', 2: 'button', 3: 'label', 4: 'block'}
    check = cv2.imread(img_path)
    if pred_boxes is None or len(pred_boxes) == 0:
        return check
    pred_boxes = pred_boxes.numpy() if not isinstance(pred_boxes, np.ndarray) else pred_boxes
    pred_classes = pred_classes.numpy() if not isinstance(pred_classes, np.ndarray) else pred_classes

    # draw rectangles
    for j, box in enumerate(pred_boxes):
        cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (36, 255, 12), 2)
        cv2.putText(check, class_dict[pred_classes[j].item()], (int(box[0]), int(box[1])), cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (0, 0, 255), 2)

    return check
--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/utils/logging.py:
--------------------------------------------------------------------------------
import os
import sys
import numpy as np
import tensorflow as tf
import scipy.misc
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x

from .osutils import mkdir_if_missing

# from config import get_args
# global_args = get_args(sys.argv[1:])

# if global_args.run_on_remote:
#     import moxing as mox
#     mox.file.shift("os", "mox")

# Fallback so the module works while the remote (moxing) config above is
# disabled: behave as a local run.
class _LocalArgs:
    run_on_remote = False
global_args = _LocalArgs()

class Logger(object):
    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            if global_args.run_on_remote:
                dir_name = os.path.dirname(fpath)
                if not mox.file.exists(dir_name):
                    mox.file.make_dirs(dir_name)
                    print('=> making dir ', dir_name)
                self.file = mox.file.File(fpath, 'w')
                # self.file = open(fpath, 'w')
            else:
                mkdir_if_missing(os.path.dirname(fpath))
                self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        pass

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        # only the log file is ours to close; closing sys.stdout would break later prints
        if self.file is not None:
            self.file.close()


class TFLogger(object):
    def __init__(self, log_dir=None):
        """Create a summary writer logging to log_dir."""
        if log_dir is not None:
            mkdir_if_missing(log_dir)
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string
            try:
                s = StringIO()
            except:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)
        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def close(self):
        self.writer.close()
--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/models/resnet_aster.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

# from config import get_args
# global_args = get_args(sys.argv[1:])


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
    # [n_position]
    positions = torch.arange(0, n_position)  # .cuda()
    # [feat_dim]
    dim_range = torch.arange(0, feat_dim)  # .cuda()
    dim_range = torch.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
    # [n_position, feat_dim]
    angles = positions.unsqueeze(1) / dim_range.unsqueeze(0)
    angles = angles.float()
    angles[:, 0::2] = torch.sin(angles[:, 0::2])
    angles[:, 1::2] = torch.cos(angles[:, 1::2])
    return angles


class AsterBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(AsterBlock, self).__init__()
        self.conv1 = conv1x1(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet_ASTER(nn.Module):
    """For aster or crnn"""

    def __init__(self, with_lstm=False, n_group=1):
        super(ResNet_ASTER, self).__init__()
        self.with_lstm = with_lstm
        self.n_group = n_group

        in_channels = 3
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True))

        self.inplanes = 32
        self.layer1 = self._make_layer(32, 3, [2, 2])   # [16, 50]
        self.layer2 = self._make_layer(64, 4, [2, 2])   # [8, 25]
        self.layer3 = self._make_layer(128, 6, [2, 1])  # [4, 25]
        self.layer4 = self._make_layer(256, 6, [2, 1])  # [2, 25]
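        # Each [h_stride, w_stride] pair halves the feature-map height per stage,
        # while layers 3-5 keep the width: a 32x100 input collapses to a 1x25 map
        # (the bracketed comments), i.e. one feature column per horizontal position.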
        self.layer5 = self._make_layer(512, 3, [2, 1])  # [1, 25]

        if with_lstm:
            self.rnn = nn.LSTM(512, 256, bidirectional=True, num_layers=2, batch_first=True)
            self.out_planes = 2 * 256
        else:
            self.out_planes = 512

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, planes, blocks, stride):
        downsample = None
        if stride != [1, 1] or self.inplanes != planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes, stride),
                nn.BatchNorm2d(planes))

        layers = []
        layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(AsterBlock(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x0 = self.layer0(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x5 = self.layer5(x4)

        cnn_feat = x5.squeeze(2)  # [N, c, w]
        cnn_feat = cnn_feat.transpose(2, 1)
        if self.with_lstm:
            rnn_feat, _ = self.rnn(cnn_feat)
            return rnn_feat
        else:
            return cnn_feat


if __name__ == "__main__":
    x = torch.randn(3, 3, 32, 100)
    # fixed: the constructor only accepts with_lstm/n_group; the previous
    # use_self_attention/use_position_embedding kwargs raised a TypeError
    net = ResNet_ASTER(with_lstm=True)
    encoder_feat = net(x)
    print(encoder_feat.size())
--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/utils/visualization_utils.py:
--------------------------------------------------------------------------------

from PIL import Image
import os
import numpy as np
# from scipy.misc import imresize
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from io import BytesIO
from multiprocessing import Pool

from . import to_numpy
from ..evaluation_metrics.metrics import get_str_list


def recognition_vis(images, preds, targets, scores, dataset, vis_dir):
    images = images.permute(0, 2, 3, 1)
    images = to_numpy(images)
    images = (images * 0.5 + 0.5) * 255  # undo the [-1, 1] normalization
    pred_list, targ_list = get_str_list(preds, targets, dataset)
    for id, (image, pred, target, score) in enumerate(zip(images, pred_list, targ_list, scores)):
        if pred.lower() == target.lower():
            flag = 'right'
        else:
            flag = 'error'
        file_name = '{:}_{:}_{:}_{:}_{:.3f}.jpg'.format(flag, id, pred, target, score)
        file_path = os.path.join(vis_dir, file_name)
        image = Image.fromarray(np.uint8(image))
        image.save(file_path)


# save to disk sub process
def _save_plot_pool(vis_image, save_file_path):
    vis_image = Image.fromarray(np.uint8(vis_image))
    vis_image.save(save_file_path)


def stn_vis(raw_images, rectified_images, ctrl_points, preds, targets, real_scores, pred_scores, dataset, vis_dir):
    """
    raw_images: images without rectification
    rectified_images: rectified images with stn
    ctrl_points: predicted ctrl points
    preds: predicted label sequences
    targets: target label sequences
    real_scores: scores of recognition model
    pred_scores: predicted scores by the score branch
    dataset: dataset used to decode the label sequences
    vis_dir: output directory; if None, the rendered images are returned instead
    """
    if raw_images.ndimension() == 3:
        raw_images = raw_images.unsqueeze(0)
        rectified_images = rectified_images.unsqueeze(0)
    batch_size, _, raw_height, raw_width = raw_images.size()

    # translate the coordinates of ctrl points to image size
    ctrl_points = to_numpy(ctrl_points)
    ctrl_points[:, :, 0] = ctrl_points[:, :, 0] * (raw_width - 1)
    ctrl_points[:, :, 1] = ctrl_points[:, :, 1] * (raw_height - 1)
    ctrl_points = ctrl_points.astype(int)  # np.int is deprecated/removed in newer numpy

    # tensors to pil images
    raw_images = raw_images.permute(0, 2, 3, 1)
    raw_images = to_numpy(raw_images)
    raw_images = (raw_images * 0.5 + 0.5) * 255
    rectified_images = rectified_images.permute(0, 2, 3, 1)
    rectified_images = to_numpy(rectified_images)
    rectified_images = (rectified_images * 0.5 + 0.5) * 255

    # draw images on canvas
    vis_images = []
    num_sub_plot = 2
    raw_images = raw_images.astype(np.uint8)
    rectified_images = rectified_images.astype(np.uint8)
    for i in range(batch_size):
        fig = plt.figure()
        ax = [fig.add_subplot(num_sub_plot, 1, i+1) for i in range(num_sub_plot)]
        for a in ax:
            a.set_xticklabels([])
            a.set_yticklabels([])
            a.axis('off')
        ax[0].imshow(raw_images[i])
        ax[0].scatter(ctrl_points[i, :, 0], ctrl_points[i, :, 1], marker='+', s=5)
        ax[1].imshow(rectified_images[i])
        # plt.subplots_adjust(wspace=0, hspace=0)
        plt.show()
        buffer_ = BytesIO()
        plt.savefig(buffer_, format='png', bbox_inches='tight', pad_inches=0)
        plt.close()
        buffer_.seek(0)
        dataPIL = Image.open(buffer_)
        data = np.asarray(dataPIL).astype(np.uint8)
        buffer_.close()

        vis_images.append(data)

    # save to disk
    if vis_dir is None:
        return vis_images
    else:
        pred_list, targ_list = get_str_list(preds, targets, dataset)
        file_path_list = []
        for id, (image, pred, target, real_score) in enumerate(zip(vis_images, pred_list, targ_list, real_scores)):
            if pred.lower() == target.lower():
                flag = 'right'
            else:
                flag = 'error'
            if pred_scores is None:
                file_name = '{:}_{:}_{:}_{:}_{:.3f}.png'.format(flag, id, pred, target, real_score)
            else:
                file_name = '{:}_{:}_{:}_{:}_{:.3f}_{:.3f}.png'.format(flag, id, pred, target, real_score, pred_scores[id])
            file_path = os.path.join(vis_dir, file_name)
            file_path_list.append(file_path)

        with Pool(os.cpu_count()) as pool:
            pool.starmap(_save_plot_pool, zip(vis_images, file_path_list))
--------------------------------------------------------------------------------
/scripts/infer/test.py:
--------------------------------------------------------------------------------
import os
import logging
import argparse
from datetime import date
import yaml
import openai
from tqdm import tqdm
from selenium.common.exceptions import *
from scripts.phishintention.configs import load_config
from scripts.pipeline.test_llm import *
from scripts.utils.PhishIntentionWrapper import LogoDetector, LogoEncoder, LayoutDetector

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ['OPENAI_API_KEY'] = open('./datasets/openai_key.txt').read().strip()
logging.getLogger("httpcore").setLevel(logging.WARNING)

if __name__ == '__main__':

argparse.ArgumentParser() 18 | parser.add_argument("--folder", default="./datasets/field_study/2023-09-02/") 19 | parser.add_argument("--config", default='./param_dict.yaml', help="Config .yaml path") 20 | args = parser.parse_args() 21 | 22 | PhishLLMLogger.set_debug_on() 23 | PhishLLMLogger.set_verbose(True) 24 | 25 | # load hyperparameters 26 | with open(args.config) as file: 27 | param_dict = yaml.load(file, Loader=yaml.FullLoader) 28 | 29 | AWL_MODEL, SIAMESE_MODEL, OCR_MODEL, SIAMESE_THRE = load_config() 30 | logo_extractor = LogoDetector(AWL_MODEL) 31 | logo_encoder = LogoEncoder(SIAMESE_MODEL, OCR_MODEL, SIAMESE_THRE) 32 | layout_extractor = LayoutDetector(AWL_MODEL) 33 | 34 | # PhishLLM 35 | llm_cls = TestVLM(param_dict=param_dict, 36 | logo_encoder=logo_encoder, 37 | logo_extractor=logo_extractor, 38 | layout_extractor=layout_extractor) 39 | openai.api_key = os.getenv("OPENAI_API_KEY") 40 | openai.proxy = os.getenv("http_proxy", "") # set openai proxy 41 | 42 | driver = boot_driver() 43 | 44 | day = date.today().strftime("%Y-%m-%d") 45 | result_txt = '{}_phishllm.txt'.format(day) 46 | 47 | if not os.path.exists(result_txt): 48 | with open(result_txt, "w+") as f: 49 | f.write("folder" + "\t") 50 | f.write("phish_prediction" + "\t") 51 | f.write("target_prediction" + "\t") # write top1 prediction only 52 | f.write("brand_recog_time" + "\t") 53 | f.write("crp_prediction_time" + "\t") 54 | f.write("crp_transition_time" + "\n") 55 | 56 | 57 | for ct, folder in tqdm(enumerate(os.listdir(args.folder))): 58 | if folder in [x.split('\t')[0] for x in open(result_txt, encoding='ISO-8859-1').readlines()]: 59 | continue 60 | 61 | info_path = os.path.join(args.folder, folder, 'info.txt') 62 | html_path = os.path.join(args.folder, folder, 'html.txt') 63 | shot_path = os.path.join(args.folder, folder, 'shot.png') 64 | predict_path = os.path.join(args.folder, folder, 'predict.png') 65 | if not os.path.exists(shot_path): 66 | continue 67 | 68 | try: 69 | if len(open(info_path, encoding='ISO-8859-1').read()) > 0: 70 | url = open(info_path, encoding='ISO-8859-1').read() 71 | else: 72 | url = 'https://' + folder 73 | except FileNotFoundError: 74 | url = 'https://' + folder 75 | 76 | logo_box, reference_logo = llm_cls.detect_logo(shot_path) 77 | try: 78 | pred, brand, brand_recog_time, crp_prediction_time, crp_transition_time, plotvis = llm_cls.test(url=url, 79 | reference_logo=reference_logo, 80 | logo_box=logo_box, 81 | shot_path=shot_path, 82 | html_path=html_path, 83 | driver=driver, 84 | ) 85 | driver.delete_all_cookies() 86 | except (WebDriverException) as e: 87 | print(f"Driver crashed or encountered an error: {e}. 
Restarting driver.") 88 | driver = restart_driver(driver) 89 | continue 90 | 91 | try: 92 | with open(result_txt, "a+", encoding='ISO-8859-1') as f: 93 | f.write(f"{folder}\t{pred}\t{brand}\t{brand_recog_time}\t{crp_prediction_time}\t{crp_transition_time}\n") 94 | if pred == 'phish': 95 | plotvis.save(predict_path) 96 | except UnicodeEncodeError: 97 | continue 98 | 99 | 100 | driver.quit() 101 | -------------------------------------------------------------------------------- /datasets/hosting_blacklists.txt: -------------------------------------------------------------------------------- 1 | lws.fr 2 | escrow.com 3 | networksolutions.com 4 | zone.ee 5 | domain.com 6 | virualmin.com 7 | reg.ru 8 | epik.com 9 | dan.com 10 | phpmyadmin.net 11 | phpbb.com 12 | php.net 13 | api.platform 14 | tumblr.com 15 | bigbluebutton.org 16 | nginx.com 17 | tautulli.com 18 | caddyserver.com 19 | rebrandly.com 20 | keycloak.org 21 | zabbix.com 22 | djangoproject.com 23 | control-webpanel.com 24 | appwrite.io 25 | afihost.co.za 26 | afrihost.com 27 | primehost.ca 28 | parallels.com 29 | zyro.com 30 | hostinger.com 31 | diviproject.org 32 | ispmanager.com 33 | ovhcloud.com 34 | directadmin.com 35 | ispconfig.org 36 | zoner.com 37 | nextcloud.com 38 | plesk.com 39 | fornex.com 40 | alfahosting.de 41 | dhosting.pl 42 | drupal.org 43 | ininet.hu 44 | joomla.org 45 | swagger.io 46 | roundcube.net 47 | vistaprint.com 48 | vtiger.com 49 | webmin.com 50 | e-shot.net 51 | mailchimp.com 52 | portainer.io 53 | storybook.js.org 54 | redmine.org 55 | freshrss.org 56 | namecheap.com 57 | debian.org 58 | strapi.io 59 | traefik.io 60 | bitwarden.com 61 | wordpress.org 62 | hostnet.nl 63 | nuxtjs.org 64 | horde.org 65 | hoststar.ch 66 | cloudns.com 67 | domaindiscount24.com 68 | raiolanetworks.es 69 | netsite.dk 70 | webgo.de 71 | hostingwedos.cz 72 | beget.ru 73 | register.com 74 | blogger.com 75 | hostpoint.ch 76 | 1jabber.com 77 | strato.de 78 | zimbra.com 79 | mailrelay.com 80 | onlyoffice.org 81 | onlyoffice.com 82 | yunohost.org 83 | vodien.com 84 | myjobous 85 | domainname.de 86 | laravel.com 87 | autods.com 88 | codingest.net 89 | vhlcentral.com 90 | windsorbrokers.com 91 | logmein.com 92 | xng88.xyz 93 | cpalead.com 94 | profreehost.com 95 | mekari.com 96 | rediffmailpro.com 97 | wix.com 98 | whmcs.com 99 | draytek.com 100 | zyxel.com 101 | adminlte.io 102 | emby.media 103 | invoiceninja.com 104 | jxt.com.au 105 | mantisbt.org 106 | seafile.com 107 | bookstackapp.com 108 | apache.org 109 | liveconfig 110 | firezone.com 111 | wordpress.com 112 | dolibarr.org 113 | cpanel.net 114 | godaddy.com 115 | zimbra.com 116 | owncloud.com 117 | phpmyadmin.net 118 | wordpress.com 119 | draytek.com 120 | afrihost.co.za 121 | 15.vip 122 | okta.com 123 | prohoster.info 124 | matomo.org 125 | hostneel.com 126 | o2switch.fr 127 | froxlor.com 128 | contabo.com 129 | freshrss.org 130 | harmonweb.com 131 | modoboa.org 132 | antagonist.nl 133 | oderland.com 134 | strato 135 | luno 136 | luno.com 137 | bet365 138 | bet365.com 139 | manhattantrust.com 140 | nginxproxymanager.com 141 | bitrix24.com 142 | owncloud.org 143 | websitehosting.ca 144 | fiberhost.pl 145 | smartlife.com 146 | freescout.net 147 | froxlor.com 148 | gestsup.com 149 | cloudflare.com 150 | sonicpanel.com 151 | material.io 152 | invoices.com 153 | inmotionhosting.com 154 | radarr.video 155 | eshop.it 156 | horde.net 157 | moodle.org 158 | mautic.org 159 | pritunl.com 160 | chatwoot.com 161 | sentry.io 162 | pfsense.org 163 | bluehost.com 164 | siteground.com 165 | 
a2hosting.com 166 | digitalocean.com 167 | linode.com 168 | heroku.com 169 | vultr.com 170 | openshift.com 171 | kubernetes.io 172 | jitsi.org 173 | opnsense.org 174 | untangle.com 175 | smoothwall.org 176 | snort.org 177 | ossec.net 178 | sugarcrm.com 179 | suitecrm.com 180 | zoho.com 181 | zendesk.com 182 | hubspot.com 183 | elastic.co 184 | splunk.com 185 | grafana.com 186 | prometheus.io 187 | influxdata.com 188 | wireshark.org 189 | gimp.org 190 | ansible.com 191 | puppet.com 192 | chef.io 193 | terraform.io 194 | fastly.com 195 | akamai.com 196 | squarespace.com 197 | weebly.com 198 | shopify.com 199 | magento.com 200 | oscommerce.com 201 | prestashop.com 202 | w3schools.com 203 | mozilla.org 204 | sqlite.org 205 | postgresql.org 206 | mysql.com 207 | mariadb.org 208 | redis.io 209 | mongodb.com 210 | hostgator.com 211 | greengeeks.com 212 | midphase.com 213 | westhost.com 214 | lunarpages.com 215 | arvixe.com 216 | liquidweb.com 217 | kinsta.com 218 | wpengine.com 219 | pantheon.io 220 | enom.com 221 | nic.ru 222 | 101domain.com 223 | namesilo.com 224 | eurodns.com 225 | whois.com 226 | wireguard.com 227 | openvpn.net 228 | strongswan.org 229 | freeswan.org 230 | softether.org 231 | endian.com 232 | ipfire.org 233 | clearos.com 234 | sophos.com 235 | shorewall.org 236 | odoo.com 237 | erpnext.com 238 | yetiforce.com 239 | espocrm.com 240 | crm.zoho.com 241 | flarepointcrm.com 242 | hubspot.com 243 | pipedrive.com 244 | fuxnoten.de 245 | horde.com 246 | petrosoftinc.com 247 | mikrotik.com 248 | vodia.com 249 | ccc.de 250 | yourlogo.com 251 | wallabag.org 252 | login.gov 253 | zulip.com 254 | keyhelp.de 255 | akaunting.com 256 | plausible.io 257 | mautic.com 258 | postfixadmin.com 259 | password.com 260 | ingresar.com 261 | example.com 262 | meshcentral.com 263 | glpi-project.org 264 | facturascripts.com 265 | keyweb.de 266 | blizzard.com 267 | meiguo.com 268 | ovh.com 269 | login.com 270 | zibll.com 271 | jellyfin.org 272 | myasp.jp 273 | KEYCLOAK 274 | WrapUp Resources LLC 275 | bitwarden 276 | transip 277 | DOMAINNAME.DE 278 | Dolibarr 279 | ONLYOFFICE 280 | websupport 281 | Webmail Provider 282 | Strato AG 283 | sysPass Systems Password Manager 284 | mailu.io 285 | moje-stranky.sk 286 | tirzokcloud.com -------------------------------------------------------------------------------- /scripts/phishintention/utils/web_utils.py: -------------------------------------------------------------------------------- 1 | from selenium.common.exceptions import NoSuchElementException, TimeoutException, MoveTargetOutOfBoundsException, StaleElementReferenceException 2 | from selenium import webdriver 3 | from selenium.webdriver.common.desired_capabilities import DesiredCapabilities 4 | import helium 5 | import time 6 | import re 7 | from selenium.webdriver.chrome.service import Service 8 | from selenium.webdriver.common.by import By 9 | from selenium.webdriver.chrome.service import Service as ChromeService 10 | 11 | def initialize_chrome_settings(): 12 | ''' 13 | initialize chrome settings 14 | ''' 15 | options = webdriver.ChromeOptions() 16 | 17 | options.add_argument('--no-sandbox') 18 | options.add_argument('--disable-dev-shm-usage') 19 | options.add_argument('--disable-gpu') 20 | options.add_argument('--ignore-certificate-errors') # ignore errors 21 | options.add_argument('--ignore-ssl-errors') 22 | options.add_argument("--headless") # FIXME: do not disable browser (have some issues: https://github.com/mherrmann/selenium-python-helium/issues/47) 23 | options.add_argument('--no-proxy-server') 
24 | options.add_argument("--proxy-server='direct://'") 25 | options.add_argument("--proxy-bypass-list=*") 26 | 27 | options.add_argument("--start-maximized") 28 | options.add_argument('--window-size=1920,1080') # fix screenshot size 29 | options.add_argument("--disable-blink-features=AutomationControlled") 30 | options.add_experimental_option('useAutomationExtension', False) 31 | options.add_argument( 32 | 'user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36') 33 | options.set_capability('unhandledPromptBehavior', 'dismiss') # dismiss 34 | 35 | 36 | return options 37 | 38 | def click_button(button_text): 39 | helium.Config.implicit_wait_secs = 2 # this is the implicit timeout for helium 40 | helium.get_driver().implicitly_wait(2) 41 | try: 42 | helium.click(helium.Button(button_text)) 43 | return True 44 | except: 45 | return False 46 | 47 | def get_page_text(driver): 48 | ''' 49 | get body text from html 50 | :param driver: chromdriver 51 | :return: text 52 | ''' 53 | try: 54 | body = driver.find_element(By.TAG_NAME, value='body').text 55 | except NoSuchElementException as e: # if no body tag, just get all text 56 | print(e) 57 | try: 58 | body = driver.page_source 59 | except Exception as e: 60 | print(e) 61 | body = '' 62 | return body 63 | 64 | def click_text(text): 65 | ''' 66 | click the text's region 67 | :param text: 68 | :return: 69 | ''' 70 | helium.Config.implicit_wait_secs = 2 # this is the implicit timeout for helium 71 | helium.get_driver().implicitly_wait(2) # this is the implicit timeout for selenium 72 | body = get_page_text(helium.get_driver()) 73 | try: 74 | helium.highlight(text) # highlight text for debugging 75 | time.sleep(1) 76 | if re.search(text, body, flags=re.I): 77 | helium.click(text) 78 | time.sleep(2) # wait until website is completely loaded 79 | except TimeoutException as e: 80 | print(e) 81 | except LookupError as e: 82 | print(e) 83 | except Exception as e: 84 | print(e) 85 | 86 | def click_point(x, y): 87 | ''' 88 | click on coordinate (x,y) 89 | :param x: 90 | :param y: 91 | :return: 92 | ''' 93 | helium.Config.implicit_wait_secs = 2 # this is the implicit timeout for helium 94 | helium.get_driver().implicitly_wait(2) # this the implicit timeout for selenium 95 | try: 96 | helium.click(helium.Point(x, y)) 97 | time.sleep(2) # wait until website is completely loaded 98 | # click_popup() 99 | except TimeoutException as e: 100 | print(e) 101 | except MoveTargetOutOfBoundsException as e: 102 | print(e) 103 | except LookupError as e: 104 | print(e) 105 | except AttributeError as e: 106 | print(e) 107 | except Exception as e: 108 | print(e) 109 | 110 | def visit_url(driver, orig_url): 111 | ''' 112 | Visit a URL 113 | :param driver: chromedriver 114 | :param orig_url: URL to visit 115 | :param popup: click popup window or not 116 | :param sleep: need sleep time or not 117 | :return: load url successful or not 118 | ''' 119 | try: 120 | driver.get(orig_url) 121 | time.sleep(2) 122 | driver.switch_to.alert.dismiss() 123 | return True, driver 124 | except TimeoutException as e: 125 | print(str(e)) 126 | return False, driver 127 | except Exception as e: 128 | print(str(e)) 129 | print("no alert") 130 | return True, driver 131 | 132 | 133 | def driver_loader(): 134 | 135 | options = initialize_chrome_settings() 136 | service = ChromeService(executable_path="./chromedriver-linux64/chromedriver") 137 | driver = webdriver.Chrome(service=service, options=options) 138 | 
driver.set_page_load_timeout(60) # set timeout to avoid wasting time 139 | driver.set_script_timeout(60) # set timeout to avoid wasting time 140 | helium.set_driver(driver) 141 | 142 | return driver 143 | 144 | 145 | -------------------------------------------------------------------------------- /scripts/phishintention/ocr_lib/models/tps_spatial_transformer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | import itertools 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | def grid_sample(input, grid, canvas = None): 11 | output = F.grid_sample(input, grid) 12 | if canvas is None: 13 | return output 14 | else: 15 | input_mask = input.data.new(input.size()).fill_(1) 16 | output_mask = F.grid_sample(input_mask, grid) 17 | padded_output = output * output_mask + canvas * (1 - output_mask) 18 | return padded_output 19 | 20 | 21 | # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 22 | def compute_partial_repr(input_points, control_points): 23 | N = input_points.size(0) 24 | M = control_points.size(0) 25 | pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2) 26 | # original implementation, very slow 27 | # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance 28 | pairwise_diff_square = pairwise_diff * pairwise_diff 29 | pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1] 30 | repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist) 31 | # fix numerical error for 0 * log(0), substitute all nan with 0 32 | mask = repr_matrix != repr_matrix 33 | repr_matrix.masked_fill_(mask, 0) 34 | return repr_matrix 35 | 36 | 37 | # output_ctrl_pts are specified, according to our task. 
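# They are the fixed *target* positions: two horizontal rows of
# num_control_points // 2 points each, inset from the unit square by
# (margin_x, margin_y). The STN predicts matching *source* points on the
# distorted text, and the TPS warps between the two point sets.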
38 | def build_output_control_points(num_control_points, margins): 39 | margin_x, margin_y = margins 40 | num_ctrl_pts_per_side = num_control_points // 2 41 | ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) 42 | ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y 43 | ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) 44 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 45 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 46 | # ctrl_pts_top = ctrl_pts_top[1:-1,:] 47 | # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:] 48 | output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 49 | output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr) 50 | return output_ctrl_pts 51 | 52 | 53 | # demo: ~/test/models/test_tps_transformation.py 54 | class TPSSpatialTransformer(nn.Module): 55 | 56 | def __init__(self, output_image_size=None, num_control_points=None, margins=None): 57 | super(TPSSpatialTransformer, self).__init__() 58 | self.output_image_size = output_image_size 59 | self.num_control_points = num_control_points 60 | self.margins = margins 61 | 62 | self.target_height, self.target_width = output_image_size 63 | target_control_points = build_output_control_points(num_control_points, margins) 64 | N = num_control_points 65 | # N = N - 4 66 | 67 | # create padded kernel matrix 68 | forward_kernel = torch.zeros(N + 3, N + 3) 69 | target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points) 70 | forward_kernel[:N, :N].copy_(target_control_partial_repr) 71 | forward_kernel[:N, -3].fill_(1) 72 | forward_kernel[-3, :N].fill_(1) 73 | forward_kernel[:N, -2:].copy_(target_control_points) 74 | forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1)) 75 | # compute inverse matrix 76 | device = forward_kernel.device 77 | # inverse_kernel = torch.inverse(forward_kernel) 78 | inverse_kernel = torch.inverse(forward_kernel.to("cpu")).to(device) 79 | 80 | # create target cordinate matrix 81 | HW = self.target_height * self.target_width 82 | target_coordinate = list(itertools.product(range(self.target_height), range(self.target_width))) 83 | target_coordinate = torch.Tensor(target_coordinate) # HW x 2 84 | Y, X = target_coordinate.split(1, dim = 1) 85 | Y = Y / (self.target_height - 1) 86 | X = X / (self.target_width - 1) 87 | target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y) 88 | target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points) 89 | target_coordinate_repr = torch.cat([ 90 | target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate 91 | ], dim = 1) 92 | 93 | # register precomputed matrices 94 | self.register_buffer('inverse_kernel', inverse_kernel) 95 | self.register_buffer('padding_matrix', torch.zeros(3, 2)) 96 | self.register_buffer('target_coordinate_repr', target_coordinate_repr) 97 | self.register_buffer('target_control_points', target_control_points) 98 | 99 | def forward(self, input, source_control_points): 100 | assert source_control_points.ndimension() == 3 101 | assert source_control_points.size(1) == self.num_control_points 102 | assert source_control_points.size(2) == 2 103 | batch_size = source_control_points.size(0) 104 | 105 | Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1) 106 | mapping_matrix = torch.matmul(self.inverse_kernel, Y) 107 | source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix) 108 | 109 | 
grid = source_coordinate.view(-1, self.target_height, self.target_width, 2) # for each output pixel, the (x, y) location in the input to sample from; this is the grid F.grid_sample expects 110 | grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1]. 111 | # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] 112 | grid = 2.0 * grid - 1.0 113 | output_maps = grid_sample(input, grid, canvas=None) 114 | return output_maps, source_coordinate 115 | -------------------------------------------------------------------------------- /scripts/phishintention/ocr_lib/models/model_builder.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | from collections import OrderedDict 4 | import sys 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.nn import init 10 | 11 | from .attention_recognition_head import AttentionRecognitionHead 12 | from ..loss.sequenceCrossEntropyLoss import SequenceCrossEntropyLoss 13 | from .tps_spatial_transformer import TPSSpatialTransformer 14 | from .stn_head import STNHead 15 | from .resnet_aster import ResNet_ASTER 16 | 17 | __factory = { 18 | 'ResNet_ASTER': ResNet_ASTER, 19 | } 20 | 21 | 22 | def names(): 23 | return sorted(__factory.keys()) 24 | 25 | 26 | def create(name, *args, **kwargs): 27 | """Create a model instance. 28 | 29 | Parameters 30 | ---------- 31 | name: str 32 | Model name. One of __factory 33 | pretrained: bool, optional 34 | If True, will use ImageNet pretrained model. Default: True 35 | num_classes: int, optional 36 | If positive, will replace the original classifier with a new one that fits num_classes. 37 | with_words: bool, optional 38 | If True, the input of this model is the combination of image and word. Default: False 39 | """ 40 | if name not in __factory: 41 | raise KeyError('Unknown model:', name) 42 | return __factory[name](*args, **kwargs) 43 | 44 | # from config import get_args 45 | # global_args = get_args(sys.argv[1:]) 46 | 47 | 48 | class ModelBuilder(nn.Module): 49 | """ 50 | This is the integrated model. 51 | """ 52 | def __init__(self, arch, rec_num_classes, sDim, attDim, max_len_labels, eos, STN_ON=False): 53 | super(ModelBuilder, self).__init__() 54 | 55 | self.arch = arch 56 | self.rec_num_classes = rec_num_classes 57 | self.sDim = sDim 58 | self.attDim = attDim 59 | self.max_len_labels = max_len_labels 60 | self.eos = eos 61 | self.STN_ON = STN_ON 62 | self.tps_inputsize = [32, 64] 63 | 64 | self.encoder = create(self.arch, 65 | with_lstm=True, 66 | n_group=1) 67 | encoder_out_planes = self.encoder.out_planes 68 | 69 | self.decoder = AttentionRecognitionHead( 70 | num_classes=rec_num_classes, 71 | in_planes=encoder_out_planes, 72 | sDim=sDim, 73 | attDim=attDim, 74 | max_len_labels=max_len_labels) 75 | self.rec_crit = SequenceCrossEntropyLoss() 76 | 77 | if self.STN_ON: 78 | self.tps = TPSSpatialTransformer( 79 | output_image_size=tuple([32, 100]), 80 | num_control_points=20, 81 | margins=tuple([0.05,0.05])) 82 | 83 | self.stn_head = STNHead( 84 | in_planes=3, 85 | num_ctrlpoints=20, 86 | activation='none') 87 | 88 | def forward(self, input_dict): 89 | return_dict = {} 90 | return_dict['losses'] = {} 91 | return_dict['output'] = {} 92 | 93 | x, rec_targets, rec_lengths = input_dict['images'], \ 94 | input_dict['rec_targets'], \ 95 | input_dict['rec_lengths'] 96 | 97 | # rectification 98 | if self.STN_ON: 99 | # input images are downsampled before being fed into stn_head.
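# (a coarse, fixed-size view is sufficient for control-point regression; see tps_inputsize = [32, 64] above)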
100 | stn_input = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True) 101 | stn_img_feat, ctrl_points = self.stn_head(stn_input) 102 | x, _ = self.tps(x, ctrl_points) 103 | if not self.training: 104 | # save for visualization 105 | return_dict['output']['ctrl_points'] = ctrl_points 106 | return_dict['output']['rectified_images'] = x 107 | 108 | encoder_feats = self.encoder(x) 109 | encoder_feats = encoder_feats.contiguous() 110 | 111 | if self.training: 112 | rec_pred = self.decoder([encoder_feats, rec_targets, rec_lengths]) 113 | loss_rec = self.rec_crit(rec_pred, rec_targets, rec_lengths) 114 | return_dict['losses']['loss_rec'] = loss_rec 115 | else: 116 | rec_pred, rec_pred_scores = self.decoder.beam_search(encoder_feats, 5, self.eos) 117 | rec_pred_ = self.decoder([encoder_feats, rec_targets, rec_lengths]) 118 | loss_rec = self.rec_crit(rec_pred_, rec_targets, rec_lengths) 119 | return_dict['losses']['loss_rec'] = loss_rec 120 | return_dict['output']['pred_rec'] = rec_pred 121 | return_dict['output']['pred_rec_score'] = rec_pred_scores 122 | 123 | # pytorch0.4 bug on gathering scalar(0-dim) tensors 124 | for k, v in return_dict['losses'].items(): 125 | return_dict['losses'][k] = v.unsqueeze(0) 126 | 127 | return return_dict 128 | 129 | def features(self, input_dict): 130 | 131 | x, rec_targets, rec_lengths = input_dict['images'], \ 132 | input_dict['rec_targets'], \ 133 | input_dict['rec_lengths'] 134 | 135 | # rectification 136 | if self.STN_ON: 137 | # input images are downsampled before being fed into stn_head. 138 | stn_input = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True) 139 | stn_img_feat, ctrl_points = self.stn_head(stn_input) 140 | x, _ = self.tps(x, ctrl_points) 141 | 142 | encoder_feats = self.encoder(x) 143 | encoder_feats = encoder_feats.contiguous() 144 | 145 | rec_pred_ = self.decoder([encoder_feats, rec_targets, rec_lengths]) 146 | 147 | return encoder_feats, rec_pred_ 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # PhishVLM 3 | 4 | An extension from our work "Less Defined Knowledge and More True Alarms: Reference-based Phishing Detection without a Pre-defined Reference List". 5 | Published in USENIX Security 2024. 6 | 7 |

8 | 9 | • Read our Paper • 10 | 11 | • Visit our Website • 12 | 13 | • Download our Datasets • 14 | 15 | • Cite our Paper • 16 | 17 |

18 | 19 | ## Introduction 20 | Existing reference-based phishing detection: 21 | 22 | - :x: Relies on a pre-defined reference list, which lacks comprehensiveness and incurs high maintenance costs 23 | - :x: Does not fully make use of the textual semantics present on the webpage 24 | 25 | In our PhishVLM, we build a reference-based phishing detection framework: 26 | 27 | - ✅ **Without the pre-defined reference list**: Modern VLMs have encoded far more extensive brand-domain information than any predefined list 28 | - ✅ **Chain-of-thought credential-taking prediction**: Reasoning about the credential-taking status step by step from the screenshot 29 | 30 | ## Framework 31 | 32 | 33 | ```Input```: a URL and its screenshot, ```Output```: Phish/Benign, Phishing target 34 | 35 | - **Step 1: Brand recognition model** 36 | - Input: Logo Screenshot 37 | - Output: VLM's predicted brand 38 | 39 | - **Step 2: Credential-Requiring-Page classification model** 40 | - Input: Webpage Screenshot 41 | - Output: VLM chooses from A. Credential-Taking Page or B. Non-Credential-Taking Page 42 | - Go to Step 4 if the VLM chooses 'A'; otherwise go to Step 3. 43 | 44 | - **Step 3: Credential-Requiring-Page transition model (activates if the VLM chooses 'B' in the last step)** 45 | - Input: Screenshots of all clickable UI elements 46 | - Intermediate Output: Top-1 most likely login UI 47 | - Output: Webpage after clicking that UI; **go back to Step 1** with the updated webpage and URL 48 | 49 | - **Step 4: Output step** 50 | - _Case 1_: If the domain is a web hosting domain: it is flagged as **phishing** if 51 | (i) the VLM predicts a targeted brand inconsistent with the webpage's domain 52 | and (ii) the VLM chooses 'A' in Step 2 53 | 54 | - _Case 2_: If the domain is not a web hosting domain: it is flagged as **phishing** if 55 | (i) the VLM predicts a targeted brand inconsistent with the webpage's domain, 56 | (ii) the VLM chooses 'A' in Step 2, 57 | and (iii) the domain is not a popular domain indexed by Google 58 | 59 | - _Otherwise_: reported as **benign** 60 | 61 | ## Project structure 62 | 63 |
 64 | scripts/ 
 65 | ├── infer/
 66 | │   └── test.py            # inference script
 67 | ├── pipeline/             
 68 | │   └── test_llm.py        # TestVLM class
 69 | └── utils/                 # other utilities such as web-interaction helper functions 
 70 | 
 71 | prompts/ 
 72 | ├── brand_recog_prompt.json 
 73 | ├── crp_pred_prompt.json
 74 | └── crp_trans_prompt.json
 75 | 
 76 | 
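For orientation, the sketch below condenses how `scripts/infer/test.py` wires these modules together. It is a sketch, not a drop-in script: `TestVLM` and `boot_driver` are assumed to be exported by `scripts/pipeline/test_llm.py` (as the wildcard import in `test.py` suggests), and error handling plus result logging are omitted.

```python
# Condensed flow of scripts/infer/test.py (sketch; error handling omitted).
import yaml
from scripts.phishintention.configs import load_config
from scripts.utils.PhishIntentionWrapper import LogoDetector, LogoEncoder, LayoutDetector
from scripts.pipeline.test_llm import TestVLM, boot_driver  # assumed exports

param_dict = yaml.safe_load(open('./param_dict.yaml'))             # hyperparameters
AWL_MODEL, SIAMESE_MODEL, OCR_MODEL, SIAMESE_THRE = load_config()  # detector/encoder weights
llm_cls = TestVLM(param_dict=param_dict,
                  logo_encoder=LogoEncoder(SIAMESE_MODEL, OCR_MODEL, SIAMESE_THRE),
                  logo_extractor=LogoDetector(AWL_MODEL),
                  layout_extractor=LayoutDetector(AWL_MODEL))
driver = boot_driver()

logo_box, reference_logo = llm_cls.detect_logo('shot.png')         # Step 1 input: the logo
pred, brand, *timings_and_vis = llm_cls.test(url='https://example.com',
                                             reference_logo=reference_logo,
                                             logo_box=logo_box,
                                             shot_path='shot.png',
                                             html_path='html.txt',
                                             driver=driver)
print(pred, brand)  # 'phish'/'benign' and the predicted target brand's domain
```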
77 | 78 | ## Setup 79 | 80 | ### Step 1: **Install Requirements**. 81 | 82 | Tested on Ubuntu with CUDA 11. 83 | 84 | - A new conda environment "phishllm" will be created after this step 85 | ```bash 86 | conda create -n phishllm python=3.10 87 | conda activate phishllm 88 | pip install -r requirements.txt 89 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 90 | pip install --no-build-isolation git+https://github.com/facebookresearch/detectron2.git 91 | cd scripts/phishintention 92 | chmod +x setup.sh 93 | ./setup.sh 94 | ``` 95 | 96 | 97 | ### Step 2: **Install Chrome** 98 | ```bash 99 | wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb && sudo apt install ./google-chrome-stable_current_amd64.deb 100 | ``` 101 | 102 | ### Step 3: Register **Two API Keys**. 103 | 104 | - 🔑 **OpenAI API key**, [See Tutorial here](https://platform.openai.com/docs/quickstart). Paste the API key into ``./datasets/openai_key.txt``. 105 | 106 | - 🔑 **Google Programmable Search API Key**, [See Tutorial here](https://meta.discourse.org/t/google-search-for-discourse-ai-programmable-search-engine-and-custom-search-api/307107). 107 | Paste your API Key (in the first line) and Search Engine ID (in the second line) into ``./datasets/google_api_key.txt``: 108 | ```text 109 | [API_KEY] 110 | [SEARCH_ENGINE_ID] 111 | ``` 112 | 113 | ## Prepare the Dataset 114 | To test on your own dataset, you need to prepare the dataset in the following structure: 115 |
116 | testing_dir/
117 | ├── aaa.com/
118 | │   ├── shot.png  # save the webpage screenshot
119 | │   ├── info.txt  # save the webpage URL
120 | │   └── html.txt  # save the webpage HTML source
121 | ├── bbb.com/
122 | │   ├── shot.png  # save the webpage screenshot
123 | │   ├── info.txt  # save the webpage URL
124 | │   └── html.txt  # save the webpage HTML source
125 | ├── ccc.com/
126 | │   ├── shot.png  # save the webpage screenshot
127 | │   ├── info.txt  # save the webpage URL
128 | │   └── html.txt  # save the webpage HTML source
129 | 
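If you collect pages yourself, a minimal sketch for saving one site into this layout is shown below. It assumes a Selenium Chrome driver like the one configured in `scripts/phishintention/utils/web_utils.py`; the folder naming and error handling are deliberately simplified.

```python
# Sketch: capture one live site into the testing_dir/ layout above.
import os
from selenium import webdriver

def save_site(driver: webdriver.Chrome, url: str, out_root: str = 'testing_dir') -> None:
    folder = os.path.join(out_root, url.split('//')[-1].strip('/'))  # e.g. testing_dir/aaa.com
    os.makedirs(folder, exist_ok=True)
    driver.get(url)
    driver.save_screenshot(os.path.join(folder, 'shot.png'))   # webpage screenshot
    with open(os.path.join(folder, 'info.txt'), 'w') as f:     # webpage URL
        f.write(url)
    with open(os.path.join(folder, 'html.txt'), 'w', encoding='utf-8') as f:
        f.write(driver.page_source)                            # webpage HTML source
```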
130 | 131 | 132 | ## Inference: Run PhishLLM 133 | ```bash 134 | conda activate phishllm 135 | python scripts/infer/test.py --folder [folder to test, e.g., ./datasets/test_sites] 136 | ``` 137 | 138 | ## Understand the Output 139 | - You will see the console printing logs like the following:
Expand to see the sample log 140 |

141 |     [PhishLLMLogger][DEBUG] Folder ./datasets/field_study/2023-09-01/device-862044b2-5124-4735-b6d5-f114eea4a232.remotewd.com
142 |     [PhishLLMLogger][DEBUG] Time taken for LLM brand prediction: 0.9699530601501465 Detected brand: sonicwall.com
143 |     [PhishLLMLogger][DEBUG] Domain sonicwall.com is valid and alive
144 |     [PhishLLMLogger][DEBUG] Time taken for LLM CRP classification: 2.9195783138275146 	 CRP prediction: A. This is a credential-requiring page.
145 |     [❗️] Phishing discovered, phishing target is sonicwall.com
146 |   
147 | 148 | - Meanwhile, a txt file named "[today's date]_phishllm.txt" is created; it has the following columns: 149 | - "folder": name of the folder 150 | - "phish_prediction": "phish" | "benign" 151 | - "target_prediction": phishing target brand's domain, e.g. paypal.com, meta.com 152 | - "brand_recog_time": time taken for brand recognition 153 | - "crp_prediction_time": time taken for CRP prediction 154 | - "crp_transition_time": time taken for CRP transition 155 | 156 | ## Citations 157 | ```bibtex 158 | @inproceedings {299838, 159 | author = {Ruofan Liu and Yun Lin and Xiwen Teoh and Gongshen Liu and Zhiyong Huang and Jin Song Dong}, 160 | title = {Less Defined Knowledge and More True Alarms: Reference-based Phishing Detection without a Pre-defined Reference List}, 161 | booktitle = {33rd USENIX Security Symposium (USENIX Security 24)}, 162 | year = {2024}, 163 | isbn = {978-1-939133-44-1}, 164 | address = {Philadelphia, PA}, 165 | pages = {523--540}, 166 | url = {https://www.usenix.org/conference/usenixsecurity24/presentation/liu-ruofan}, 167 | publisher = {USENIX Association}, 168 | month = aug 169 | } 170 | ``` 171 | If you have any issues running our code, you can raise a GitHub issue or email us at liu.ruofan16@u.nus.edu, lin_yun@sjtu.edu.cn, dcsdjs@nus.edu.sg. -------------------------------------------------------------------------------- /scripts/phishintention/modules/models2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Bottleneck ResNet v2 with GroupNorm and Weight Standardization.""" 17 | 18 | from collections import OrderedDict # pylint: disable=g-importing-member 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | 24 | 25 | class StdConv2d(nn.Conv2d): 26 | 27 | def forward(self, x): 28 | w = self.weight 29 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) 30 | w = (w - m) / torch.sqrt(v + 1e-10) 31 | return F.conv2d(x, w, self.bias, self.stride, self.padding, 32 | self.dilation, self.groups) 33 | 34 | 35 | def conv3x3(cin, cout, stride=1, groups=1, bias=False): 36 | return StdConv2d(cin, cout, kernel_size=3, stride=stride, 37 | padding=1, bias=bias, groups=groups) 38 | 39 | 40 | def conv1x1(cin, cout, stride=1, bias=False): 41 | return StdConv2d(cin, cout, kernel_size=1, stride=stride, 42 | padding=0, bias=bias) 43 | 44 | 45 | def tf2th(conv_weights): 46 | """Possibly convert HWIO to OIHW.""" 47 | if conv_weights.ndim == 4: 48 | conv_weights = conv_weights.transpose([3, 2, 0, 1]) 49 | return torch.from_numpy(conv_weights) 50 | 51 | 52 | class PreActBottleneck(nn.Module): 53 | """Pre-activation (v2) bottleneck block.
54 | 55 | Follows the implementation of "Identity Mappings in Deep Residual Networks": 56 | https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua 57 | 58 | Except it puts the stride on 3x3 conv when available. 59 | """ 60 | 61 | def __init__(self, cin, cout=None, cmid=None, stride=1): 62 | super().__init__() 63 | cout = cout or cin 64 | cmid = cmid or cout//4 65 | 66 | self.gn1 = nn.GroupNorm(32, cin) 67 | self.conv1 = conv1x1(cin, cmid) 68 | self.gn2 = nn.GroupNorm(32, cmid) 69 | self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!! 70 | self.gn3 = nn.GroupNorm(32, cmid) 71 | self.conv3 = conv1x1(cmid, cout) 72 | self.relu = nn.ReLU(inplace=True) 73 | 74 | if (stride != 1 or cin != cout): 75 | # Projection also with pre-activation according to paper. 76 | self.downsample = conv1x1(cin, cout, stride) 77 | 78 | def forward(self, x): 79 | out = self.relu(self.gn1(x)) 80 | 81 | # Residual branch 82 | residual = x 83 | if hasattr(self, 'downsample'): 84 | residual = self.downsample(out) 85 | 86 | # Unit's branch 87 | out = self.conv1(out) 88 | out = self.conv2(self.relu(self.gn2(out))) 89 | out = self.conv3(self.relu(self.gn3(out))) 90 | 91 | return out + residual 92 | 93 | def load_from(self, weights, prefix=''): 94 | convname = 'standardized_conv2d' 95 | with torch.no_grad(): 96 | self.conv1.weight.copy_(tf2th(weights[f'{prefix}a/{convname}/kernel'])) 97 | self.conv2.weight.copy_(tf2th(weights[f'{prefix}b/{convname}/kernel'])) 98 | self.conv3.weight.copy_(tf2th(weights[f'{prefix}c/{convname}/kernel'])) 99 | self.gn1.weight.copy_(tf2th(weights[f'{prefix}a/group_norm/gamma'])) 100 | self.gn2.weight.copy_(tf2th(weights[f'{prefix}b/group_norm/gamma'])) 101 | self.gn3.weight.copy_(tf2th(weights[f'{prefix}c/group_norm/gamma'])) 102 | self.gn1.bias.copy_(tf2th(weights[f'{prefix}a/group_norm/beta'])) 103 | self.gn2.bias.copy_(tf2th(weights[f'{prefix}b/group_norm/beta'])) 104 | self.gn3.bias.copy_(tf2th(weights[f'{prefix}c/group_norm/beta'])) 105 | if hasattr(self, 'downsample'): 106 | w = weights[f'{prefix}a/proj/{convname}/kernel'] 107 | self.downsample.weight.copy_(tf2th(w)) 108 | 109 | 110 | class ResNetV2(nn.Module): 111 | """Implementation of Pre-activation (v2) ResNet mode.""" 112 | 113 | def __init__(self, block_units, width_factor, head_size=21843, zero_head=False, ocr_emb_size=512): 114 | super().__init__() 115 | wf = width_factor 116 | self.wf = wf 117 | # The following will be unreadable if we split lines. 
118 | # pylint: disable=line-too-long 119 | self.root = nn.Sequential(OrderedDict([ 120 | ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)), 121 | ('pad', nn.ConstantPad2d(1, 0)), 122 | ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), 123 | ])) 124 | 125 | self.body = nn.Sequential(OrderedDict([ 126 | ('block1', nn.Sequential(OrderedDict( 127 | [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] + 128 | [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)], 129 | ))), 130 | ('block2', nn.Sequential(OrderedDict( 131 | [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] + 132 | [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)], 133 | ))), 134 | ('block3', nn.Sequential(OrderedDict( 135 | [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] + 136 | [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)], 137 | ))), 138 | ('block4', nn.Sequential(OrderedDict( 139 | [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] + 140 | [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)], 141 | ))), 142 | ])) 143 | # pylint: enable=line-too-long 144 | 145 | self.zero_head = zero_head 146 | self.head = nn.Sequential(OrderedDict([ 147 | ('gn', nn.GroupNorm(32, 2048*wf)), 148 | ('relu', nn.ReLU(inplace=True)), 149 | ('avg', nn.AdaptiveAvgPool2d(output_size=1)), 150 | ])) 151 | 152 | self.additionalfc = nn.Sequential(OrderedDict([ 153 | ('conv_add', nn.Linear(2048*wf+ocr_emb_size, head_size)), 154 | ])) 155 | 156 | def features(self, x, ocr_emb): 157 | x = self.head(self.body(self.root(x))) 158 | x = x.view(-1, 2048*self.wf) 159 | x = torch.cat((x, ocr_emb), dim=1) 160 | return x.squeeze(-1).squeeze(-1) 161 | 162 | def forward(self, x, ocr_emb): 163 | x = self.head(self.body(self.root(x))) 164 | x = x.view(-1, 2048*self.wf) 165 | x = torch.cat((x, ocr_emb), dim=1) 166 | x = self.additionalfc(x) 167 | print(x.shape) 168 | 169 | return x 170 | 171 | def load_from(self, weights, prefix='resnet/'): 172 | with torch.no_grad(): 173 | self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) # pylint: disable=line-too-long 174 | self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma'])) 175 | self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta'])) 176 | for bname, block in self.body.named_children(): 177 | for uname, unit in block.named_children(): 178 | unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/') 179 | 180 | 181 | KNOWN_MODELS = OrderedDict([ 182 | ('BiT-M-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), 183 | ('BiT-M-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), 184 | ('BiT-M-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), 185 | ('BiT-M-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), 186 | ('BiT-M-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), 187 | ('BiT-M-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), 188 | ('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), 189 | ('BiT-S-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), 190 | ('BiT-S-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), 191 | ('BiT-S-R101x3', lambda *a, 
**kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), 192 | ('BiT-S-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), 193 | ('BiT-S-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), 194 | ]) -------------------------------------------------------------------------------- /scripts/phishintention/ocr_lib/models/attention_recognition_head.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import sys 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.nn import init 9 | 10 | 11 | class AttentionRecognitionHead(nn.Module): 12 | """ 13 | input: [b x 16 x 64 x in_planes] 14 | output: probability sequence: [b x T x num_classes] 15 | """ 16 | def __init__(self, num_classes, in_planes, sDim, attDim, max_len_labels): 17 | super(AttentionRecognitionHead, self).__init__() 18 | self.num_classes = num_classes # this is the output classes. So it includes the . 19 | self.in_planes = in_planes 20 | self.sDim = sDim 21 | self.attDim = attDim 22 | self.max_len_labels = max_len_labels 23 | 24 | self.decoder = DecoderUnit(sDim=sDim, xDim=in_planes, yDim=num_classes, attDim=attDim) 25 | 26 | def forward(self, x): 27 | x, targets, lengths = x 28 | batch_size = x.size(0) 29 | # Decoder 30 | state = torch.zeros(1, batch_size, self.sDim) 31 | outputs = [] 32 | 33 | for i in range(max(lengths)): 34 | if i == 0: 35 | y_prev = torch.zeros((batch_size)).fill_(self.num_classes) # the last one is used as the . 36 | else: 37 | y_prev = targets[:,i-1] 38 | 39 | output, state = self.decoder(x, state, y_prev) 40 | outputs.append(output) 41 | outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1) 42 | return outputs 43 | 44 | # inference stage. 
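# Greedy decoding: at each step the argmax symbol is fed back as y_prev until
# max_len_labels; beam_search() below instead keeps the top `beam_width`
# hypotheses per image and backtracks through stored predecessors.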
45 | def sample(self, x): 46 | x, _, _ = x 47 | batch_size = x.size(0) 48 | # Decoder 49 | state = torch.zeros(1, batch_size, self.sDim) 50 | 51 | predicted_ids, predicted_scores = [], [] 52 | for i in range(self.max_len_labels): 53 | if i == 0: 54 | y_prev = torch.zeros((batch_size)).fill_(self.num_classes) 55 | else: 56 | y_prev = predicted 57 | 58 | output, state = self.decoder(x, state, y_prev) 59 | output = F.softmax(output, dim=1) 60 | score, predicted = output.max(1) 61 | predicted_ids.append(predicted.unsqueeze(1)) 62 | predicted_scores.append(score.unsqueeze(1)) 63 | predicted_ids = torch.cat(predicted_ids, 1) 64 | predicted_scores = torch.cat(predicted_scores, 1) 65 | # return predicted_ids.squeeze(), predicted_scores.squeeze() 66 | return predicted_ids, predicted_scores 67 | 68 | def beam_search(self, x, beam_width, eos): 69 | 70 | def _inflate(tensor, times, dim): 71 | repeat_dims = [1] * tensor.dim() 72 | repeat_dims[dim] = times 73 | return tensor.repeat(*repeat_dims) 74 | 75 | # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py 76 | batch_size, l, d = x.size() 77 | # inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC 78 | inflated_encoder_feats = x.unsqueeze(1).permute((1,0,2,3)).repeat((beam_width,1,1,1)).permute((1,0,2,3)).contiguous().view(-1, l, d) 79 | 80 | # Initialize the decoder 81 | state = torch.zeros(1, batch_size * beam_width, self.sDim) 82 | pos_index = (torch.Tensor(range(batch_size)) * beam_width).long().view(-1, 1) 83 | 84 | # Initialize the scores 85 | sequence_scores = torch.Tensor(batch_size * beam_width, 1) 86 | sequence_scores.fill_(-float('Inf')) 87 | sequence_scores.index_fill_(0, torch.Tensor([i * beam_width for i in range(0, batch_size)]).long(), 0.0) 88 | # sequence_scores.fill_(0.0) 89 | 90 | # Initialize the input vector 91 | y_prev = torch.zeros((batch_size * beam_width)).fill_(self.num_classes) 92 | 93 | # Store decisions for backtracking 94 | stored_scores = list() 95 | stored_predecessors = list() 96 | stored_emitted_symbols = list() 97 | 98 | for i in range(self.max_len_labels): 99 | output, state = self.decoder(inflated_encoder_feats, state, y_prev) 100 | log_softmax_output = F.log_softmax(output, dim=1) 101 | 102 | sequence_scores = _inflate(sequence_scores, self.num_classes, 1) 103 | sequence_scores += log_softmax_output 104 | scores, candidates = sequence_scores.view(batch_size, -1).topk(beam_width, dim=1) 105 | 106 | # Reshape input = (bk, 1) and sequence_scores = (bk, 1) 107 | y_prev = (candidates % self.num_classes).view(batch_size * beam_width) 108 | sequence_scores = scores.view(batch_size * beam_width, 1) 109 | 110 | # Update fields for next timestep 111 | predecessors = (torch.div(candidates, self.num_classes, rounding_mode='floor') + pos_index.expand_as(candidates)).view(batch_size * beam_width, 1) # floor division keeps integer indices; plain '/' yields floats that break index_select on modern PyTorch 112 | state = state.index_select(1, predecessors.squeeze()) 113 | 114 | # Update sequence scores and erase scores for the EOS symbol so that they aren't expanded 115 | stored_scores.append(sequence_scores.clone()) 116 | eos_indices = y_prev.view(-1, 1).eq(eos) 117 | if eos_indices.nonzero().dim() > 0: 118 | sequence_scores.masked_fill_(eos_indices, -float('inf')) 119 | 120 | # Cache results for backtracking 121 | stored_predecessors.append(predecessors) 122 | stored_emitted_symbols.append(y_prev) 123 | 124 | # Do backtracking to return the optimal values 125 | #====== backtrack ======# 126 | # Initialize return variables given different types 127 | p = list() 128 | l = 
[[self.max_len_labels] * beam_width for _ in range(batch_size)] # Placeholder for lengths of top-k sequences 129 | 130 | # the last step output of the beams are not sorted 131 | # thus they are sorted here 132 | sorted_score, sorted_idx = stored_scores[-1].view(batch_size, beam_width).topk(beam_width) 133 | # initialize the sequence scores with the sorted last step beam scores 134 | s = sorted_score.clone() 135 | 136 | batch_eos_found = [0] * batch_size # the number of EOS found 137 | # in the backward loop below for each batch 138 | t = self.max_len_labels - 1 139 | # initialize the back pointer with the sorted order of the last step beams. 140 | # add pos_index for indexing variable with b*k as the first dimension. 141 | t_predecessors = (sorted_idx + pos_index.expand_as(sorted_idx)).view(batch_size * beam_width) 142 | while t >= 0: 143 | # Re-order the variables with the back pointer 144 | current_symbol = stored_emitted_symbols[t].index_select(0, t_predecessors) 145 | t_predecessors = stored_predecessors[t].index_select(0, t_predecessors).squeeze() 146 | eos_indices = stored_emitted_symbols[t].eq(eos).nonzero() 147 | if eos_indices.dim() > 0: 148 | for i in range(eos_indices.size(0)-1, -1, -1): 149 | # Indices of the EOS symbol for both variables 150 | # with b*k as the first dimension, and b, k for 151 | # the first two dimensions 152 | idx = eos_indices[i] 153 | b_idx = int(idx[0] / beam_width) 154 | # The indices of the replacing position 155 | # according to the replacement strategy noted above 156 | res_k_idx = beam_width - (batch_eos_found[b_idx] % beam_width) - 1 157 | batch_eos_found[b_idx] += 1 158 | res_idx = b_idx * beam_width + res_k_idx 159 | 160 | # Replace the old information in return variables 161 | # with the new ended sequence information 162 | t_predecessors[res_idx] = stored_predecessors[t][idx[0]] 163 | current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]] 164 | s[b_idx, res_k_idx] = stored_scores[t][idx[0], [0]] 165 | l[b_idx][res_k_idx] = t + 1 166 | 167 | # record the back tracked results 168 | p.append(current_symbol) 169 | 170 | t -= 1 171 | 172 | # Sort and re-order again as the added ended sequences may change 173 | # the order (very unlikely) 174 | s, re_sorted_idx = s.topk(beam_width) 175 | for b_idx in range(batch_size): 176 | l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx,:]] 177 | 178 | re_sorted_idx = (re_sorted_idx + pos_index.expand_as(re_sorted_idx)).view(batch_size*beam_width) 179 | 180 | # Reverse the sequences and re-order at the same time 181 | # It is reversed because the backtracking happens in reverse time order 182 | p = [step.index_select(0, re_sorted_idx).view(batch_size, beam_width, -1) for step in reversed(p)] 183 | p = torch.cat(p, -1)[:,0,:] 184 | return p, torch.ones_like(p) 185 | 186 | 187 | class AttentionUnit(nn.Module): 188 | def __init__(self, sDim, xDim, attDim): 189 | super(AttentionUnit, self).__init__() 190 | 191 | self.sDim = sDim 192 | self.xDim = xDim 193 | self.attDim = attDim 194 | 195 | self.sEmbed = nn.Linear(sDim, attDim) 196 | self.xEmbed = nn.Linear(xDim, attDim) 197 | self.wEmbed = nn.Linear(attDim, 1) 198 | 199 | # self.init_weights() 200 | 201 | def init_weights(self): 202 | init.normal_(self.sEmbed.weight, std=0.01) 203 | init.constant_(self.sEmbed.bias, 0) 204 | init.normal_(self.xEmbed.weight, std=0.01) 205 | init.constant_(self.xEmbed.bias, 0) 206 | init.normal_(self.wEmbed.weight, std=0.01) 207 | init.constant_(self.wEmbed.bias, 0) 208 | 209 | def forward(self, x, sPrev): 210 | 
batch_size, T, _ = x.size() # [b x T x xDim] 211 | x = x.view(-1, self.xDim) # [(b x T) x xDim] 212 | xProj = self.xEmbed(x) # [(b x T) x attDim] 213 | xProj = xProj.view(batch_size, T, -1) # [b x T x attDim] 214 | 215 | sPrev = sPrev.squeeze(0) 216 | sProj = self.sEmbed(sPrev) # [b x attDim] 217 | sProj = torch.unsqueeze(sProj, 1) # [b x 1 x attDim] 218 | sProj = sProj.expand(batch_size, T, self.attDim) # [b x T x attDim] 219 | 220 | sumTanh = torch.tanh(sProj + xProj) 221 | sumTanh = sumTanh.view(-1, self.attDim) 222 | 223 | vProj = self.wEmbed(sumTanh) # [(b x T) x 1] 224 | vProj = vProj.view(batch_size, T) 225 | 226 | alpha = F.softmax(vProj, dim=1) # attention weights for each sample in the minibatch 227 | 228 | return alpha 229 | 230 | 231 | class DecoderUnit(nn.Module): 232 | def __init__(self, sDim, xDim, yDim, attDim): 233 | super(DecoderUnit, self).__init__() 234 | self.sDim = sDim 235 | self.xDim = xDim 236 | self.yDim = yDim 237 | self.attDim = attDim 238 | self.emdDim = attDim 239 | 240 | self.attention_unit = AttentionUnit(sDim, xDim, attDim) 241 | self.tgt_embedding = nn.Embedding(yDim+1, self.emdDim) # the last is used for 242 | self.gru = nn.GRU(input_size=xDim+self.emdDim, hidden_size=sDim, batch_first=True) 243 | self.fc = nn.Linear(sDim, yDim) 244 | 245 | # self.init_weights() 246 | 247 | def init_weights(self): 248 | init.normal_(self.tgt_embedding.weight, std=0.01) 249 | init.normal_(self.fc.weight, std=0.01) 250 | init.constant_(self.fc.bias, 0) 251 | 252 | def forward(self, x, sPrev, yPrev): 253 | # x: feature sequence from the image decoder. 254 | batch_size, T, _ = x.size() 255 | alpha = self.attention_unit(x, sPrev) 256 | context = torch.bmm(alpha.unsqueeze(1), x).squeeze(1) 257 | yProj = self.tgt_embedding(yPrev.long()) 258 | # self.gru.flatten_parameters() 259 | output, state = self.gru(torch.cat([yProj, context], 1).unsqueeze(1), sPrev) 260 | output = output.squeeze(1) 261 | 262 | output = self.fc(output) 263 | return output, state -------------------------------------------------------------------------------- /scripts/phishintention/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | import os 5 | import math 6 | 7 | 8 | def coord_reshape(coords, image_shape, reshaped_size=(256, 512)): 9 | ''' 10 | Revise coordinates when the image is resized 11 | ''' 12 | height, width = image_shape 13 | new_coords = [] 14 | for c in coords: 15 | x1, y1, x2, y2 = c 16 | x1n, y1n, x2n, y2n = reshaped_size[1] * x1 / width, reshaped_size[0] * y1 / height, \ 17 | reshaped_size[1] * x2 / width, reshaped_size[0] * y2 / height 18 | new_coords.append([x1n, y1n, x2n, y2n]) 19 | 20 | return np.asarray(new_coords) 21 | 22 | 23 | def coord2pixel_reverse(img_path, coords, types, num_types=5, reshaped_size=(256, 512)) -> torch.Tensor: 24 | ''' 25 | Convert coordinate to multi-hot encodings for coordinate class 26 | ''' 27 | img = cv2.imread(img_path) if not isinstance(img_path, np.ndarray) else img_path 28 | coords = coords.numpy() if not isinstance(coords, np.ndarray) else coords 29 | coords = coord_reshape(coords, img.shape[:2], reshaped_size) # reshape coordinates 30 | types = types.numpy() if not isinstance(types, np.ndarray) else types 31 | 32 | # Incorrect path/empty image 33 | if img is None: 34 | raise AttributeError('Image is None') 35 | height, width = img.shape[:2] 36 | # Empty image 37 | if height == 0 or width == 0: 38 | raise AttributeError('Empty image') 39 | 40 | # grid 
array of shape ClassxHxW 41 | grid_arrs = np.zeros((num_types, reshaped_size[0], reshaped_size[1])) 42 | 43 | for j, coord in enumerate(coords): 44 | x1, y1, x2, y2 = coord 45 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 46 | if x2 - x1 <= 0 or y2 - y1 <= 0: 47 | continue # ignore 48 | 49 | # multi-hot encoding for type? 50 | class_position = types[j] 51 | grid_arrs[class_position, y1:y2, x1:x2] = 1. 52 | 53 | return torch.from_numpy(grid_arrs) 54 | 55 | 56 | def coord2pixel(img_path, coords, types, num_types=5, reshaped_size=(256, 512)) -> torch.Tensor: 57 | ''' 58 | Convert coordinate to multi-hot encodings for coordinate class 59 | ''' 60 | img = cv2.imread(img_path) if not isinstance(img_path, np.ndarray) else img_path 61 | coords = coords.numpy() if not isinstance(coords, np.ndarray) else coords 62 | coords = coord_reshape(coords, img.shape[:2], reshaped_size) # reshape coordinates 63 | types = types.numpy() if not isinstance(types, np.ndarray) else types 64 | 65 | # Incorrect path/empty image 66 | if img is None: 67 | raise AttributeError('Image is None') 68 | height, width = img.shape[:2] 69 | # Empty image 70 | if height == 0 or width == 0: 71 | raise AttributeError('Empty image') 72 | 73 | # grid array of shape ClassxHxW = 5xHxW 74 | grid_arrs = np.zeros((num_types, reshaped_size[0], reshaped_size[1])) 75 | type_dict = {'logo': 1, 'input': 2, 'button': 3, 'label': 4, 'block': 5} 76 | 77 | for j, coord in enumerate(coords): 78 | x1, y1, x2, y2 = coord 79 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 80 | if x2 - x1 <= 0 or y2 - y1 <= 0: 81 | continue # ignore 82 | 83 | # multi-hot encoding for type? 84 | class_position = type_dict[types[j]] - 1 85 | grid_arrs[class_position, y1:y2, x1:x2] = 1. 86 | 87 | return torch.from_numpy(grid_arrs) 88 | 89 | 90 | def topo2pixel(img_path, coords, knn_matrix, reshaped_size=(256, 512)) -> torch.Tensor: 91 | ''' 92 | Convert coordinate to multi-hot encodings for coordinate class 93 | ''' 94 | img = cv2.imread(img_path) if not isinstance(img_path, np.ndarray) else img_path 95 | coords = coords.numpy() if not isinstance(coords, np.ndarray) else coords 96 | coords = coord_reshape(coords, img.shape[:2], reshaped_size) # reshape coordinates 97 | knn_matrix = knn_matrix.numpy() if not isinstance(knn_matrix, np.ndarray) else knn_matrix 98 | 99 | # Incorrect path/empty image 100 | if img is None: 101 | raise AttributeError('Image is None') 102 | height, width = img.shape[:2] 103 | # Empty image 104 | if height == 0 or width == 0: 105 | raise AttributeError('Empty image') 106 | 107 | # grid array of shape (KxZ)xHxW = 12xHxW 108 | topo_arrs = np.zeros((12, reshaped_size[0], reshaped_size[1])) 109 | if len(coords) <= 1: # num of components smaller than 2 110 | return torch.from_numpy(topo_arrs) 111 | 112 | for j, coord in enumerate(coords): 113 | x1, y1, x2, y2 = coord 114 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 115 | if x2 - x1 <= 0 or y2 - y1 <= 0: 116 | continue # ignore 117 | 118 | # fill in topological info (zero padding if number of neighbors is less than 3) 119 | topo_arrs[:min(len(knn_matrix[j]), 12), y1:y2, x1:x2] = knn_matrix[j][:, np.newaxis][:, np.newaxis] 120 | 121 | return torch.from_numpy(topo_arrs) 122 | 123 | 124 | def read_img_reverse(img, coords, types, num_types=5, grid_num=10) -> torch.Tensor: 125 | ''' 126 | Convert image with bbox predictions as into grid format 127 | :param img: image path in str or image in np.ndarray 128 | :param coords: Nx4 tensor/np.ndarray for box coords 129 | :param types: Nx1 
tensor/np.ndarray for box types (logo, input etc.) 130 | :param num_types: total number of box types 131 | :param grid_num: number of grids needed 132 | :return: grid tensor 133 | ''' 134 | 135 | img = cv2.imread(img) if not isinstance(img, np.ndarray) else img 136 | coords = coords.numpy() if not isinstance(coords, np.ndarray) else coords 137 | types = types.numpy() if not isinstance(types, np.ndarray) else types 138 | 139 | # Incorrect path/empty image 140 | if img is None: 141 | raise AttributeError('Image is None') 142 | 143 | height, width = img.shape[:2] 144 | 145 | # Empty image 146 | if height == 0 or width == 0: 147 | raise AttributeError('Empty image') 148 | 149 | # grid array of shape CxHxW 150 | grid_arrs = np.zeros((4 + num_types, grid_num, grid_num)) # Must be [0, 1], use rel_x, rel_y, rel_w, rel_h 151 | 152 | for j, coord in enumerate(coords): 153 | x1, y1, x2, y2 = coord 154 | w = max(0, x2 - x1) 155 | h = max(0, y2 - y1) 156 | if w == 0 or h == 0: 157 | continue # ignore 158 | 159 | # get the assigned grid index 160 | assigned_grid_w, assigned_grid_h = int(((x1 + x2) / 2) // (width // grid_num)), int( 161 | ((y1 + y2) / 2) // (height // grid_num)) 162 | 163 | # bound above 164 | assigned_grid_w = min(grid_num - 1, assigned_grid_w) 165 | assigned_grid_h = min(grid_num - 1, assigned_grid_h) 166 | 167 | # if this grid has been assigned before, check whether need to re-assign 168 | if grid_arrs[0, assigned_grid_h, assigned_grid_w] != 0: # visted 169 | exist_type = np.where(grid_arrs[:, assigned_grid_h, assigned_grid_w] == 1)[0][0] - 4 170 | new_type = types[j] 171 | if new_type > exist_type: # if new type has lower priority than existing type 172 | continue 173 | 174 | # fill in rel_xywh 175 | grid_arrs[0, assigned_grid_h, assigned_grid_w] = float(x1 / width) 176 | grid_arrs[1, assigned_grid_h, assigned_grid_w] = float(y1 / height) 177 | grid_arrs[2, assigned_grid_h, assigned_grid_w] = float(w / width) 178 | grid_arrs[3, assigned_grid_h, assigned_grid_w] = float(h / height) 179 | 180 | # one-hot encoding for type 181 | cls_arr = np.zeros(num_types) 182 | cls_arr[types[j]] = 1 183 | 184 | grid_arrs[4:, assigned_grid_h, assigned_grid_w] = cls_arr 185 | 186 | return torch.from_numpy(grid_arrs) 187 | 188 | 189 | import torch.nn.functional as F 190 | from PIL import Image 191 | import math 192 | 193 | def resolution_alignment(img1, img2): 194 | ''' 195 | Resize two images according to the minimum resolution between the two 196 | :param img1: first image in PIL.Image 197 | :param img2: second image in PIL.Image 198 | :return: resized img1 in PIL.Image, resized img2 in PIL.Image 199 | ''' 200 | w1, h1 = img1.size 201 | w2, h2 = img2.size 202 | w_min, h_min = min(w1, w2), min(h1, h2) 203 | if w_min == 0 or h_min == 0: ## something wrong, stop resizing 204 | return img1, img2 205 | if w_min < h_min: 206 | img1_resize = img1.resize((int(w_min), math.ceil(h1 * (w_min/w1)))) # ceiling to prevent rounding to 0 207 | img2_resize = img2.resize((int(w_min), math.ceil(h2 * (w_min/w2)))) 208 | else: 209 | img1_resize = img1.resize((math.ceil(w1 * (h_min/h1)), int(h_min))) 210 | img2_resize = img2.resize((math.ceil(w2 * (h_min/h2)), int(h_min))) 211 | return img1_resize, img2_resize 212 | 213 | def brand_converter(brand_name): 214 | ''' 215 | Helper function to deal with inconsistency in brand naming 216 | ''' 217 | if brand_name == 'Adobe Inc.' 
or brand_name == 'Adobe Inc': 218 | return 'Adobe' 219 | elif brand_name == 'ADP, LLC' or brand_name == 'ADP, LLC.': 220 | return 'ADP' 221 | elif brand_name == 'Amazon.com Inc.' or brand_name == 'Amazon.com Inc': 222 | return 'Amazon' 223 | elif brand_name == 'Americanas.com S,A Comercio Electrnico': 224 | return 'Americanas.com S' 225 | elif brand_name == 'AOL Inc.' or brand_name == 'AOL Inc': 226 | return 'AOL' 227 | elif brand_name == 'Apple Inc.' or brand_name == 'Apple Inc': 228 | return 'Apple' 229 | elif brand_name == 'AT&T Inc.' or brand_name == 'AT&T Inc': 230 | return 'AT&T' 231 | elif brand_name == 'Banco do Brasil S.A.': 232 | return 'Banco do Brasil S.A' 233 | elif brand_name == 'Credit Agricole S.A.': 234 | return 'Credit Agricole S.A' 235 | elif brand_name == 'DGI (French Tax Authority)': 236 | return 'DGI French Tax Authority' 237 | elif brand_name == 'DHL Airways, Inc.' or brand_name == 'DHL Airways, Inc' or brand_name == 'DHL': 238 | return 'DHL Airways' 239 | elif brand_name == 'Dropbox, Inc.' or brand_name == 'Dropbox, Inc': 240 | return 'Dropbox' 241 | elif brand_name == 'eBay Inc.' or brand_name == 'eBay Inc': 242 | return 'eBay' 243 | elif brand_name == 'Facebook, Inc.' or brand_name == 'Facebook, Inc': 244 | return 'Facebook' 245 | elif brand_name == 'Free (ISP)': 246 | return 'Free ISP' 247 | elif brand_name == 'Google Inc.' or brand_name == 'Google Inc': 248 | return 'Google' 249 | elif brand_name == 'Mastercard International Incorporated': 250 | return 'Mastercard International' 251 | elif brand_name == 'Netflix Inc.' or brand_name == 'Netflix Inc': 252 | return 'Netflix' 253 | elif brand_name == 'PayPal Inc.' or brand_name == 'PayPal Inc': 254 | return 'PayPal' 255 | elif brand_name == 'Royal KPN N.V.': 256 | return 'Royal KPN N.V' 257 | elif brand_name == 'SF Express Co.': 258 | return 'SF Express Co' 259 | elif brand_name == 'SNS Bank N.V.': 260 | return 'SNS Bank N.V' 261 | elif brand_name == 'Square, Inc.' or brand_name == 'Square, Inc': 262 | return 'Square' 263 | elif brand_name == 'Webmail Providers': 264 | return 'Webmail Provider' 265 | elif brand_name == 'Yahoo! Inc' or brand_name == 'Yahoo! Inc.': 266 | return 'Yahoo!' 267 | elif brand_name == 'Microsoft OneDrive' or brand_name == 'Office365' or brand_name == 'Outlook': 268 | return 'Microsoft' 269 | elif brand_name == 'Global Sources (HK)': 270 | return 'Global Sources HK' 271 | elif brand_name == 'T-Online': 272 | return 'Deutsche Telekom' 273 | elif brand_name == 'Airbnb, Inc': 274 | return 'Airbnb, Inc.' 275 | elif brand_name == 'azul': 276 | return 'Azul' 277 | elif brand_name == 'Raiffeisen Bank S.A': 278 | return 'Raiffeisen Bank S.A.' 279 | elif brand_name == 'Twitter, Inc' or brand_name == 'Twitter': 280 | return 'Twitter, Inc.' 281 | elif brand_name == 'capital_one': 282 | return 'Capital One Financial Corporation' 283 | elif brand_name == 'la_banque_postale': 284 | return 'La Banque postale' 285 | elif brand_name == 'db': 286 | return 'Deutsche Bank AG' 287 | elif brand_name == 'Swiss Post' or brand_name == 'PostFinance': 288 | return 'PostFinance' 289 | elif brand_name == 'grupo_bancolombia': 290 | return 'Bancolombia' 291 | elif brand_name == 'barclays': 292 | return 'Barclays Bank Plc' 293 | elif brand_name == 'gov_uk': 294 | return 'Government of the United Kingdom' 295 | elif brand_name == 'Aruba S.p.A': 296 | return 'Aruba S.p.A.' 
297 | elif brand_name == 'TSB Bank Plc': 298 | return 'TSB Bank Limited' 299 | elif brand_name == 'strato': 300 | return 'Strato AG' 301 | elif brand_name == 'cogeco': 302 | return 'Cogeco' 303 | elif brand_name == 'Canada Revenue Agency': 304 | return 'Government of Canada' 305 | elif brand_name == 'UniCredit Bulbank': 306 | return 'UniCredit Bank Aktiengesellschaft' 307 | elif brand_name == 'ameli_fr': 308 | return 'French Health Insurance' 309 | elif brand_name == 'Banco de Credito del Peru': 310 | return 'bcp' 311 | else: 312 | return brand_name 313 | 314 | def l2_norm(x): 315 | """ 316 | Row-wise L2 normalization 317 | :param x: 2-D feature tensor of shape (N, D) 318 | :return: L2-normalized tensor of the same shape 319 | """ 320 | if len(x.shape) > 2: # flatten trailing dims to (N, D) first 321 | x = x.reshape((x.shape[0], -1)) 322 | return F.normalize(x, p=2, dim=1) -------------------------------------------------------------------------------- /scripts/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Union, List, Optional, Dict, Any 3 | from PIL import Image 4 | import io 5 | import base64 6 | 7 | from numpy.typing import ArrayLike, NDArray 8 | from typing import Sequence, Tuple 9 | Number = Union[int, float] 10 | 11 | '''prompt utils''' 12 | def image2base64(image: Union[str, Image.Image]) -> str: 13 | if isinstance(image, str): 14 | image = Image.open(image) 15 | img_byte_arr = io.BytesIO() 16 | image.save(img_byte_arr, format='PNG') # always encode as PNG so the data-URL MIME type below stays accurate 17 | img_bytes = img_byte_arr.getvalue() 18 | base64_encoded = base64.b64encode(img_bytes).decode('utf-8') # Convert bytes to base64 string and decode to UTF-8 19 | return base64_encoded 20 | 21 | def prepare_candidate_uis( 22 | candidate_uis_imgs: Sequence[Union[str, Image.Image]], 23 | candidate_uis_text: Sequence[str] 24 | ) -> Sequence[Dict[str, Any]]: 25 | 26 | candidate_uis_json = [] 27 | for ind, (img, text) in enumerate(zip(candidate_uis_imgs, candidate_uis_text)): 28 | base64_image = image2base64(img) 29 | candidate_uis_json.append({"type": "text", 30 | "text": f'Index {ind}: ' + text} 31 | ) 32 | candidate_uis_json.append({"type": "image_url", 33 | "image_url": { 34 | "url": f"data:image/png;base64,{base64_image}"} 35 | } 36 | ) 37 | 38 | return candidate_uis_json 39 | 40 | def vlm_question_template_transition( 41 | candidate_uis_imgs: Sequence[Union[str, Image.Image]], 42 | candidate_uis_text: Sequence[str] 43 | ) -> Dict[str, Any]: 44 | candidate_uis_json = prepare_candidate_uis(candidate_uis_imgs, candidate_uis_text) 45 | 46 | return { 47 | "role": "user", 48 | "content": candidate_uis_json 49 | } 50 | 51 | 52 | def vlm_question_template_prediction(screenshot_img: Image.Image) -> Dict[str, Any]: 53 | return \ 54 | {"role": "user", 55 | "content": [ 56 | {"type": "text", 57 | "text": "Given the HTML webpage screenshot, Question: A. This is a credential-requiring page. B. This is not a credential-requiring page. \n Answer:"}, 58 | { 59 | "type": "image_url", 60 | "image_url": { 61 | "url": f"data:image/png;base64,{image2base64(screenshot_img)}" 62 | }, 63 | }, 64 | ] 65 | } 66 | 67 | 68 | def vlm_question_template_brand(logo_img: Image.Image) -> Dict[str, Any]: 69 | return \ 70 | {"role": "user", 71 | "content": [ 72 | {"type": "text", 73 | "text": "Given the brand's logo, Question: What is the brand's domain? Answer: "}, 74 | { 75 | "type": "image_url", 76 | "image_url": { 77 | "url": f"data:image/png;base64,{image2base64(logo_img)}" 78 | }, 79 | }, 80 | ] 81 | }
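# A minimal sketch of sending one of these templates to an OpenAI-compatible
# chat-completions client (the client setup, the model name, and 'logo.png' are
# assumptions, not part of this module):
#   from openai import OpenAI
#   client = OpenAI()
#   msg = vlm_question_template_brand(Image.open('logo.png'))
#   resp = client.chat.completions.create(model='gpt-4o', messages=[msg])
#   print(resp.choices[0].message.content)  # e.g. 'paypal.com'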
Answer: "}, 74 | { 75 | "type": "image_url", 76 | "image_url": { 77 | "url": f"data:image/jpeg;base64,{image2base64(logo_img)}" 78 | }, 79 | }, 80 | ] 81 | } 82 | 83 | 84 | '''bounding box utils''' 85 | def pairwise_intersect_area( 86 | bboxes1: ArrayLike, 87 | bboxes2: ArrayLike, 88 | ) -> NDArray[np.float32]: 89 | # Convert bboxes lists to 3D arrays 90 | bboxes1 = np.array(bboxes1)[:, np.newaxis, :] 91 | bboxes2 = np.array(bboxes2) 92 | 93 | # Compute overlap for x and y axes separately 94 | overlap_x = np.maximum(0, np.minimum(bboxes1[:, :, 2], bboxes2[:, 2]) - np.maximum(bboxes1[:, :, 0], bboxes2[:, 0])) 95 | overlap_y = np.maximum(0, np.minimum(bboxes1[:, :, 3], bboxes2[:, 3]) - np.maximum(bboxes1[:, :, 1], bboxes2[:, 1])) 96 | 97 | # Compute overlapping areas for each pair 98 | overlap_areas = overlap_x * overlap_y 99 | return overlap_areas 100 | 101 | def expand_bbox( 102 | bbox: Sequence[Number], 103 | image_width: int, 104 | image_height: int, 105 | expand_ratio: Union[Number, Tuple[Number, Number]] = 1.2, 106 | ) -> list[Number]: 107 | # Extract the coordinates 108 | x1, y1, x2, y2 = bbox 109 | 110 | # Calculate the center 111 | center_x = (x1 + x2) / 2 112 | center_y = (y1 + y2) / 2 113 | 114 | # Calculate new width and height 115 | new_width = (x2 - x1) * expand_ratio 116 | new_height = (y2 - y1) * expand_ratio 117 | 118 | # Determine new coordinates 119 | new_x1 = center_x - new_width / 2 120 | new_y1 = center_y - new_height / 2 121 | new_x2 = center_x + new_width / 2 122 | new_y2 = center_y + new_height / 2 123 | 124 | # Ensure coordinates are legitimate 125 | new_x1 = max(0, new_x1) 126 | new_y1 = max(0, new_y1) 127 | new_x2 = min(image_width, new_x2) 128 | new_y2 = min(image_height, new_y2) 129 | 130 | return [new_x1, new_y1, new_x2, new_y2] 131 | 132 | class Regexes(): 133 | # e-mail 134 | EMAIL = r"(e(\-|_|\s)*)?mail(?!(\-|\_|\s)*(password|passwd|pass word|passcode|passwort))" 135 | 136 | # password 137 | PASSWORD = "password|passwd|pass word|passcode|passwort" 138 | # username 139 | USERNAME = "(u(s(e)?r)?|nick|display|profile)(\-|_|\s)*name" 140 | USERID = "^((u(s(e)?r)?|nick|display|profile|customer)(\-|_|\s)*)?id|identifi(ant|er)?|access(\-|_|\s)*code|account" 141 | 142 | # misc identifiers 143 | FULL_NAME = "full(\-|_|\s)?(name|nm|nom)|(celé jméno)" 144 | FIRST_NAME = "(f(irst|ore)?|m(iddle)?|pre)(\-|_|\s)*(name|nm|nom)" 145 | LAST_NAME = "(l(ast|st)?|s(u)?(r)?)(\-|_|\s)*(name|nm|nom)" 146 | NAME_PREFIX = "prefix" 147 | 148 | # Phones 149 | PHONE_AREA = "phone(\-|_|\s)*area|area(\-|_|\s)*code|phone(\-|_|\s)*(pfx|prefix|prfx)" 150 | PHONE = "mobile|phone|telephone|tel" 151 | 152 | # Dates 153 | MONTH = "month" 154 | DAY = "day" 155 | YEAR = "year" 156 | BIRTHDATE = "date|dob|birthdate|birthday|date(\-|_|\s)*of(\-|_|\s)*birth" 157 | 158 | # gender 159 | AGE = "(\-|_|\s)+age(\-|_|\s)+" 160 | GENDER = "gender|sex" 161 | 162 | # profile pics 163 | # FILE = "photo|picture" 164 | SMS = 'sms' 165 | 166 | # Addresses 167 | ADDRESS = "address" 168 | ZIPCODE = "(post(al)?|zip)(\-|_|\s)*(code|no|num)?" 169 | CITY = "city|town|location" 170 | COUNTRY = "countr" 171 | STATE = "stat|province" 172 | STREET = "street" 173 | BUILDING_NO = "(building|bldng|flat|apartment|apt|home|house)(\-|_|\s)*(num|no)" 174 | # SSN etc. 
175 | SSN = "(ssn|vat|social(\-|_|\s)*sec(urity)?(\-|_|\s)*(num|no)?)" 176 | 177 | # Credit cards 178 | CREDIT_CARD = "(xxxx xxxx xxxx xxxx)|(0000 0000 0000 0000)|(Número de tarjeta)|(Číslo karty)|(cc(\-|_|\s)*(no|num))|(card(\-|_|\s)*(no|num))|(credit(\-|_|\s)*(no|num|card))|(card$)" 179 | CREDIT_CARD_EXPIRE = "expire|expiration|expiry|expdate|((cc|card|credit)(\-|_|\s)*date)|^exp$" 180 | CREDIT_CARD_CVV = "(sec(urity)?(\-|_|\s)*)?(cvv|csc|cvn)" 181 | ATMPIN = "atmpin|pin" 182 | 183 | # Company stuff 184 | COMPANY_NAME = "company|organi(z|s)ation|institut(e|ion)" 185 | #### END SPECIFIC REGEXES - START GENERIC #### 186 | # NUMBER_COARSE = "num|code" 187 | USERNAME_COARSE = "us(e)?r|login" 188 | 189 | OTHER_FORM = "link|search" 190 | 191 | SSO_SIGNUP_BUTTONS = "((create|register|make)|(new))\s*(new\s*)?(user|account|profile)" 192 | 193 | VERIFY_ACCOUNT = "((verify|activate)(\syour)?\s(account|e(-|\s)*mail|info))|((verification|activation) (e(-|\s)*mail|message|link|code|number))" 194 | VERIFIED_ACCOUNT = "(user(-|\s))?(account|profile)\s+(was|is|has)?(been)?(verified|activated|attivo)|(verification|activation)\s+(was|is|has)(been)?\s+(completed|done|successful)?" 195 | VERIFY_VERBS = "verify|activate" 196 | 197 | IDENTIFIERS = "%s|%s|%s|%s|%s" % (FULL_NAME, FIRST_NAME, LAST_NAME, USERNAME, EMAIL) 198 | IDENTIFIERS_NO_EMAIL = "%s|%s|%s|%s" % (FULL_NAME, FIRST_NAME, LAST_NAME, USERNAME) 199 | 200 | SUBMIT = "submit" 201 | LOGIN = "(log|sign)([^0-9a-zA-Z]|\s)*(in|on)|authenticat(e|ion)|/(my([^0-9a-zA-Z]|\s)*)?(user|account|profile|dashboard)" 202 | SIGNUP = "sign([^0-9a-zA-Z]|\s)*up|regist(er|ration)?|(create|new)([^0-9a-zA-Z]|\s)*(new([^0-9a-zA-Z]|\s)*)?(acc(ount)?|us(e)?r|prof(ile)?)|(forg(et|ot)|reset)([^0-9a-zA-Z]|\s)*((my|the)([^0-9a-zA-Z]|\s)*)?(acc(ount)?|us(e)?r|prof(ile)?|password)" 203 | SSO = "[^0-9a-zA-Z]+sso[^0-9a-zA-Z]+|oauth|openid" 204 | AUTH = "%s|%s|%s|%s|%s|auth|(new|existing)([^0-9a-zA-Z]|\s)*(us(e)?r|acc(ount)?)|account|connect|profile|dashboard|next" % (LOGIN, SIGNUP, SSO, SUBMIT, VERIFY_VERBS) 205 | LOGOUT = "(log|sign)(-|_|\s)*(out|off)" 206 | BUTTON = "suivant|make([^0-9a-zA-Z]|\s)*payment|^OK$|go([^0-9a-zA-Z]|\s)*(in)?to|sign([^0-9a-zA-Z]|\s)*in(?! with| via| using)|log([^0-9a-zA-Z]|\s)*in(?! with| via| using)|log([^0-9a-zA-Z]|\s)*on(?! with| via| using)|verify(?! with| via| using)|verification|submit(?! with| via| using)|ent(er|rar|rer|rance|ra)(?! with| via| using)|acces(o|sar|s)(?! with| via| using)|continu(er|ar)?(?! with| via| using)|connect(er)?(?! with| via| using)|next|confirm|sign([^0-9a-zA-Z]|\s)*on(?! with| via| using)|complete|valid(er|ate)(?! 
with| via| using)|securipass|登入|登录|登錄|登録|签到|iniciar([^0-9a-zA-Z]|\s)*sesión|identifier|ログインする|サインアップ|ログイン|로그인|시작하기|войти|вход|accedered|gabung|masuk|girişi|Giriş|เข้าสู่ระบบ|Přihlásit|mein([^0-9a-zA-Z]|\s)*konto|anmelden|ingresa|accedi|мой([^0-9a-zA-Z]|\s)*профиль|حسابي|administrer|cadastre-se|είσοδος|accessibilité|accéder|zaloguj|đăng([^0-9a-zA-Z]|\s)*nhập|weitermachen|bestätigen|zověřit|ověřit|weiter" 207 | BUTTON_FORBIDDEN = "single sign-on|guest|here we go|seek|looking for|explore|save|clear|wipe off|(^[0-9]+$)|(^x$)|close|search|(sign|log|verify|submit|ent(er|rar|rer|rance|ra)|acces(o|sar|s)|continu(er|ar)?)?.*(github|microsoft|facebook|google|twitter|linkedin|instagram|line)|keep([^0-9a-zA-Z]|\s)*me([^0-9a-zA-Z]|\s)*(signed|logged)([^0-9a-zA-Z]|\s)*(in|on)|having([^0-9a-zA-Z]|\s)*trouble|remember|subscribe|send([^0-9a-zA-Z]|\s)*me([^0-9a-zA-Z]|\s)*(message|(e)?mail|newsletter|update)|follow([^0-9a-zA-Z]|\s)*us|新規会員|%s" % SIGNUP 208 | # CREDENTIAL_TAKING_KEYWORDS = "log(g)?([^0-9a-zA-Z]|\s)*in(n)?|log([^0-9a-zA-Z]|\s)*on|sign([^0-9a-zA-Z]|\s)*in|sign([^0-9a-zA-Z]|\s)*on|submit|(my|personal)([^0-9a-zA-Z]|\s)*(account|area)|come([^0-9a-zA-Z]|\s)*in|check([^0-9a-zA-Z]|\s)*in|customer([^0-9a-zA-Z]|\s)*centre|登入|登录|登錄|登録|iniciar([^0-9a-zA-Z]|\s)*sesión|identifier|(ログインする)|(サインアップ)|(ログイン)|(로그인)|(시작하기)|(войти)|(вход)|(accedered)|(gabung)|(masuk)|(girişi)|(Giriş)|(وارد)|(عضویت)|(acceso)|(acessar)|(entrar )|(เข้าสู่ระบบ)|(Přihlásit)|(mein konto)|(anmelden)|(me connecter)|(ingresa)|(accedi)|(мой профиль)|(حسابي)|(administrer)|(next)|(entre )|(cadastre-se)|(είσοδος)|(entrance)|(start now)|(accessibilité)|(accéder)|(zaloguj)|(đăng nhập)|weitermachen|bestätigen|zověřit|ověřit" 209 | CREDENTIAL_TAKING_KEYWORDS = r""" 210 | (?: 211 | log(?:g)?in| # Matches 'login', 'loggin' 212 | log(?:g)?on| # Matches 'logon', 'loggon' 213 | sign(?:-|\s)?(?:in|on)| # Matches 'sign in', 'sign on', 'signin', 'signon', 'sign-in', 'sign-on' 214 | submit|apply|continue|update| 215 | (?:my|personal)(?:\W+)(?:account|area)| 216 | come(?:\W+)in| # Matches 'come in' with any non-word delimiters 217 | customer(?:\W+)centre| # Matches 'customer centre' with any non-word delimiters 218 | identifier| 219 | (?:get(?:\W+)started) # Matches 'get started' with any non-word delimiters 220 | ) 221 | | # Alternatives in different languages 222 | 登入|登录|登錄|登録| 223 | iniciar(?:\W+)sesión| 224 | (?:ログインする)|(?:サインアップ)|(?:ログイン)| 225 | (?:로그인)|(?:시작하기)| 226 | (?:войти)|(?:вход)| 227 | (?:acceder(?:\W+)ed)|(?:gabung)|(?:masuk)| 228 | (?:giriş(?:i)?)|(?:وارد)|(?:عضویت)| 229 | (?:acceso)|(?:acessar)|(?:entrar)| 230 | (?:เข้าสู่ระบบ)|(?:Přihlásit)| 231 | (?:mein konto)|(?:anmelden)|(?:me connecter)| 232 | (?:ingresa)|(?:accedi)|(?:мой профиль)| 233 | (?:حسابي)|(?:administrer)|(?:next)| 234 | (?:entre)|(?:cadastre-se)|(?:είσοδος)| 235 | (?:entrance)|(?:start now)|(?:accessibilité)| 236 | (?:accéder)|(?:zaloguj)|(?:đăng nhập)| 237 | weitermachen|bestätigen|zověřit|ověřit 238 | """.strip() 239 | PROFILE = "account|profile|dashboard|settings" 240 | 241 | CAPTCHA = "(re)?captcha" 242 | CONSENT = "consent|gdp" 243 | COOKIES_CONSENT = "agree|accept" 244 | 245 | URL = 
"(?:(?:https?|ftp)://)(?:\S+(?::\S*)?@)?(?:(?!10(?:\.\d{1,3}){3})(?!127(?:\.\d{1,3}){3})(?!169\.254(?:\.\d{1,3}){2})(?!192\.168(?:\.\d{1,3}){2})(?!172\.(?:1[6-9]|2\d|3[0-1])(?:\.\d{1,3}){2})(?:[1-9]\d?|1\d\d|2[01]\d|22[0-3])(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))|(?:(?:[a-z\\x{00a1}\-\\x{ffff}0-9]+-?)*[a-z\\x{00a1}\-\\x{ffff}0-9]+)(?:\.(?:[a-z\\x{00a1}\-\\x{ffff}0-9]+-?)*[a-z\\x{00a1}\-\\x{ffff}0-9]+)*(?:\.(?:[a-z\\x{00a1}\-\\x{ffff}]{2,})))(?::\d{2,5})?(?:/[^\s]*)?" 246 | 247 | TIME = "([0-9]:){1,2}[0-9]" 248 | TIME_SCRIPT = "setHours|setMinutes|setSeconds" 249 | 250 | # try again error 251 | ERROR_TRY_AGAIN = ["try again|login failed|error logging in|login error|retry"] 252 | 253 | # username incorrect/not exist error 254 | ERROR_INCORRECT = ["(invalid|wrong|incorrect|unknown|no).*(id|credential|login|input|password|account|user(name)?|e(\-|_|\s)?mail|information|(pass)?code|(user([^0-9a-zA-Z]|\s)*)?id)(s)?", 255 | "(do(es)?|did)([^0-9a-zA-Z]|\s)*not match(([^0-9a-zA-Z]|\s)*our records)?", 256 | "limited access|verification failed|not registered|does not exist|access denied|coundn't find|you entered([^0-9a-zA-Z]|\s)*(isn't|doesn't)|(please)?([^0-9a-zA-Z]|\s)*enter a valid", 257 | "(account|password|user(name)?|e(\-|_|\s)?mail|credentials|sms|code)([^0-9a-zA-Z]|\s)*(provided|given|input([^0-9a-zA-Z]|\s)*)?((is incorrect)|(are incorrect)|(isn't right)|(isn't correct)|(doesn't exist)|(does not exist)|(not valid)|(is invalid)|(not recognized)|(were not found))", 258 | "(SMS-Code Fehler)|(SMS kód je neplatný)", 259 | "code incorrectly|no user found|username already taken", 260 | "(cannot|can't) be used|not allowed|must (contain|follow|specify)" 261 | "captcha was not answered correctly" 262 | ] 263 | 264 | # connection error 265 | ERROR_CONNECTION = ["connecting([^0-9a-zA-Z]|\s)*(to)?([^0-9a-zA-Z]|\s)*(mail)?([^0-9a-zA-Z]|\s)*server|connection is lost", 266 | # "(operation|page)([^0-9a-zA-Z]|\s)*((counldn't)|(could not)|cannot|(can not))([^0-9a-zA-Z]|\s)*be([^0-9a-zA-Z]|\s)*(completed|found)", 267 | # 'not found|forbidden|403|404|500|no permission|don\'t have permission' 268 | ] 269 | 270 | # File related 271 | ERROR_FILE = ["processing([^0-9a-zA-Z]|\s)*(your)?([^0-9a-zA-Z]|\s)*download", 272 | "file not found"] 273 | 274 | # anti-bot 275 | ERROR_BOT = ["(not a human)|captcha|(verify you are a human)|(press & hold)"] 276 | 277 | 278 | 279 | 280 | -------------------------------------------------------------------------------- /scripts/phishintention/modules/logo_matching.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageOps 2 | from torchvision import transforms 3 | import torch 4 | from torch.backends import cudnn 5 | import os 6 | import numpy as np 7 | from collections import OrderedDict 8 | from tqdm import tqdm 9 | from tldextract import tldextract 10 | import pickle 11 | 12 | from ..utils.utils import brand_converter, resolution_alignment, l2_norm 13 | from .models2 import KNOWN_MODELS 14 | from ..ocr_lib.models.model_builder import ModelBuilder 15 | from ..ocr_lib.utils.labelmaps import get_vocabulary 16 | 17 | COUNTRY_TLDs = [ 18 | ".af", 19 | ".ax", 20 | ".al", 21 | ".dz", 22 | ".as", 23 | ".ad", 24 | ".ao", 25 | ".ai", 26 | ".aq", 27 | ".ag", 28 | ".ar", 29 | ".am", 30 | ".aw", 31 | ".ac", 32 | ".au", 33 | ".at", 34 | ".az", 35 | ".bs", 36 | ".bh", 37 | ".bd", 38 | ".bb", 39 | ".eus", 40 | ".by", 41 | ".be", 42 | ".bz", 43 | ".bj", 44 | ".bm", 45 | ".bt", 46 | ".bo", 47 | 
".bq",".an",".nl", 48 | ".ba", 49 | ".bw", 50 | ".bv", 51 | ".br", 52 | ".io", 53 | ".vg", 54 | ".bn", 55 | ".bg", 56 | ".bf", 57 | ".mm", 58 | ".bi", 59 | ".kh", 60 | ".cm", 61 | ".ca", 62 | ".cv", 63 | ".cat", 64 | ".ky", 65 | ".cf", 66 | ".td", 67 | ".cl", 68 | ".cn", 69 | ".cx", 70 | ".cc", 71 | ".co", 72 | ".km", 73 | ".cd", 74 | ".cg", 75 | ".ck", 76 | ".cr", 77 | ".ci", 78 | ".hr", 79 | ".cu", 80 | ".cw", 81 | ".cy", 82 | ".cz", 83 | ".dk", 84 | ".dj", 85 | ".dm", 86 | ".do", 87 | ".tl",".tp", 88 | ".ec", 89 | ".eg", 90 | ".sv", 91 | ".gq", 92 | ".er", 93 | ".ee", 94 | ".et", 95 | ".eu", 96 | ".fk", 97 | ".fo", 98 | ".fm", 99 | ".fj", 100 | ".fi", 101 | ".fr", 102 | ".gf", 103 | ".pf", 104 | ".tf", 105 | ".ga", 106 | ".gal", 107 | ".gm", 108 | ".ps", 109 | ".ge", 110 | ".de", 111 | ".gh", 112 | ".gi", 113 | ".gr", 114 | ".gl", 115 | ".gd", 116 | ".gp", 117 | ".gu", 118 | ".gt", 119 | ".gg", 120 | ".gn", 121 | ".gw", 122 | ".gy", 123 | ".ht", 124 | ".hm", 125 | ".hn", 126 | ".hk", 127 | ".hu", 128 | ".is", 129 | ".in", 130 | ".id", 131 | ".ir", 132 | ".iq", 133 | ".ie", 134 | ".im", 135 | ".il", 136 | ".it", 137 | ".jm", 138 | ".jp", 139 | ".je", 140 | ".jo", 141 | ".kz", 142 | ".ke", 143 | ".ki", 144 | ".kw", 145 | ".kg", 146 | ".la", 147 | ".lv", 148 | ".lb", 149 | ".ls", 150 | ".lr", 151 | ".ly", 152 | ".li", 153 | ".lt", 154 | ".lu", 155 | ".mo", 156 | ".mk", 157 | ".mg", 158 | ".mw", 159 | ".my", 160 | ".mv", 161 | ".ml", 162 | ".mt", 163 | ".mh", 164 | ".mq", 165 | ".mr", 166 | ".mu", 167 | ".yt", 168 | ".mx", 169 | ".md", 170 | ".mc", 171 | ".mn", 172 | ".me", 173 | ".ms", 174 | ".ma", 175 | ".mz", 176 | ".mm", 177 | ".na", 178 | ".nr", 179 | ".np", 180 | ".nl", 181 | ".nc", 182 | ".nz", 183 | ".ni", 184 | ".ne", 185 | ".ng", 186 | ".nu", 187 | ".nf", 188 | ".nc",".tr", 189 | ".kp", 190 | ".mp", 191 | ".no", 192 | ".om", 193 | ".pk", 194 | ".pw", 195 | ".ps", 196 | ".pa", 197 | ".pg", 198 | ".py", 199 | ".pe", 200 | ".ph", 201 | ".pn", 202 | ".pl", 203 | ".pt", 204 | ".pr", 205 | ".qa", 206 | ".ro", 207 | ".ru", 208 | ".rw", 209 | ".re", 210 | ".bq",".an", 211 | ".bl",".gp",".fr", 212 | ".sh", 213 | ".kn", 214 | ".lc", 215 | ".mf",".gp",".fr", 216 | ".pm", 217 | ".vc", 218 | ".ws", 219 | ".sm", 220 | ".st", 221 | ".sa", 222 | ".sn", 223 | ".rs", 224 | ".sc", 225 | ".sl", 226 | ".sg", 227 | ".bq",".an",".nl", 228 | ".sx",".an", 229 | ".sk", 230 | ".si", 231 | ".sb", 232 | ".so", 233 | ".so", 234 | ".za", 235 | ".gs", 236 | ".kr", 237 | ".ss", 238 | ".es", 239 | ".lk", 240 | ".sd", 241 | ".sr", 242 | ".sj", 243 | ".sz", 244 | ".se", 245 | ".ch", 246 | ".sy", 247 | ".tw", 248 | ".tj", 249 | ".tz", 250 | ".th", 251 | ".tg", 252 | ".tk", 253 | ".to", 254 | ".tt", 255 | ".tn", 256 | ".tr", 257 | ".tm", 258 | ".tc", 259 | ".tv", 260 | ".ug", 261 | ".ua", 262 | ".ae", 263 | ".uk", 264 | ".us", 265 | ".vi", 266 | ".uy", 267 | ".uz", 268 | ".vu", 269 | ".va", 270 | ".ve", 271 | ".vn", 272 | ".wf", 273 | ".eh", 274 | ".ma", 275 | ".ye", 276 | ".zm", 277 | ".zw" 278 | ] 279 | 280 | class DataInfo(object): 281 | """ 282 | Save the info about the dataset. 
283 | This is a code snippet from dataset.py 284 | """ 285 | def __init__(self, voc_type): 286 | super(DataInfo, self).__init__() 287 | self.voc_type = voc_type 288 | 289 | assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS'] 290 | self.EOS = 'EOS' 291 | self.PADDING = 'PADDING' 292 | self.UNKNOWN = 'UNKNOWN' 293 | self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN) 294 | self.char2id = dict(zip(self.voc, range(len(self.voc)))) 295 | self.id2char = dict(zip(range(len(self.voc)), self.voc)) 296 | 297 | self.rec_num_classes = len(self.voc) 298 | 299 | def ocr_model_config(weights_path, height=None, width=None): 300 | np.random.seed(1234) 301 | torch.manual_seed(1234) 302 | torch.cuda.manual_seed(1234) 303 | torch.cuda.manual_seed_all(1234) 304 | cudnn.benchmark = True 305 | torch.backends.cudnn.deterministic = True 306 | 307 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 308 | if device == 'cuda': 309 | print('using cuda.') 310 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 311 | else: 312 | torch.set_default_tensor_type('torch.FloatTensor') 313 | 314 | # Default input resolution for the recognizer 315 | if height is None or width is None: 316 | height, width = (32, 100) 317 | 318 | dataset_info = DataInfo('ALLCASES_SYMBOLS') 319 | 320 | # Create model 321 | model = ModelBuilder(arch='ResNet_ASTER', rec_num_classes=dataset_info.rec_num_classes, 322 | sDim=512, attDim=512, max_len_labels=100, 323 | eos=dataset_info.char2id[dataset_info.EOS], STN_ON=True) 324 | 325 | # Load from checkpoint 326 | checkpoint = torch.load(weights_path, map_location='cpu') 327 | model.load_state_dict(checkpoint['state_dict']) 328 | 329 | if device == 'cuda': 330 | model = model.to(device) 331 | 332 | return model 333 | 334 | def siamese_model_config(num_classes: int, weights_path: str): 335 | # Initialize model 336 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 337 | model = KNOWN_MODELS["BiT-M-R50x1"](head_size=num_classes, zero_head=True) 338 | 339 | # Load weights 340 | weights = torch.load(weights_path, map_location='cpu') 341 | weights = weights['model'] if 'model' in weights.keys() else weights 342 | new_state_dict = OrderedDict() 343 | for k, v in weights.items(): 344 | if k.startswith('module'): 345 | name = k.split('module.')[1] 346 | else: 347 | name = k 348 | new_state_dict[name] = v 349 | 350 | model.load_state_dict(new_state_dict) 351 | model.to(device) 352 | model.eval() 353 | 354 | return model 355 | 356 | 357 | def image_process(image_path, imgH=32, imgW=100, keep_ratio=False, min_ratio=1): 358 | img = Image.open(image_path).convert('RGB') if isinstance(image_path, str) else image_path.convert('RGB') 359 | 360 | if keep_ratio: 361 | w, h = img.size 362 | ratio = w / float(h) 363 | imgW = int(np.floor(ratio * imgH)) 364 | imgW = max(imgH * min_ratio, imgW) 365 | 366 | img = img.resize((imgW, imgH), Image.BILINEAR) 367 | img = transforms.ToTensor()(img) 368 | img.sub_(0.5).div_(0.5) # normalize to [-1, 1] 369 | 370 | return img 371 | 372 | 373 | def ocr_main(image_path, model, height=None, width=None): 374 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 375 | # Evaluation 376 | model.eval() 377 | 378 | img = image_process(image_path) 379 | with torch.no_grad(): 380 | img = img.to(device) 381 | input_dict = {} 382 | input_dict['images'] = img.unsqueeze(0) 383 | 384 | dataset_info = DataInfo('ALLCASES_SYMBOLS') 385 | rec_targets = torch.IntTensor(1, 100).fill_(1) 386 | rec_targets[:, 100 - 1] = dataset_info.char2id[dataset_info.EOS]
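# rec_targets/rec_lengths form a dummy decoding target (a constant-id sequence of
# length 100 terminated by EOS); the recognizer's interface requires them, but only
# the averaged encoder features are returned below.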
387 | input_dict['rec_targets'] = rec_targets.to(device) 388 | input_dict['rec_lengths'] = [100] 389 | 390 | with torch.no_grad(): 391 | features, decoder_feat = model.features(input_dict) 392 | features = features.detach().cpu() 393 | decoder_feat = decoder_feat.detach().cpu() 394 | features = torch.mean(features, dim=1) 395 | 396 | return features 397 | 398 | @torch.no_grad() 399 | def get_ocr_aided_siamese_embedding(img, model, ocr_model, grayscale=False): 400 | ''' 401 | Inference for a single image 402 | :param img: image path in str or image in PIL.Image 403 | :param model: Siamese model to make inference 404 | :param ocr_model: OCR model 405 | :param grayscale: convert image to grayscale or not 406 | :return: feature embedding of shape (2560,), 407 | i.e. the 2048-d Siamese logo feature concatenated with 408 | the 512-d OCR feature, L2-normalized 409 | ''' 410 | img_size = 224 411 | mean = [0.5, 0.5, 0.5] 412 | std = [0.5, 0.5, 0.5] 413 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 414 | 415 | img_transforms = transforms.Compose( 416 | [transforms.ToTensor(), 417 | transforms.Normalize(mean=mean, std=std), 418 | ]) 419 | 420 | img = Image.open(img) if isinstance(img, str) else img 421 | img = img.convert("RGBA").convert("L").convert("RGB") if grayscale else img.convert("RGBA").convert("RGB") 422 | 423 | ## Pad to a square while keeping the original aspect ratio 424 | pad_color = 255 if grayscale else (255, 255, 255) 425 | img = ImageOps.expand(img, ( 426 | (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2, 427 | (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2), fill=pad_color) 428 | 429 | img = img.resize((img_size, img_size)) 430 | 431 | # Get the OCR embedding first 432 | # (from the pretrained ASTER recognizer loaded in ocr_model_config) 433 | with torch.no_grad(): 434 | ocr_emb = ocr_main(image_path=img, model=ocr_model, height=None, width=None) 435 | ocr_emb = ocr_emb[0] 436 | ocr_emb = ocr_emb[None, ...].to(device) # add back a batch dimension 437 | 438 | # Predict the embedding 439 | with torch.no_grad(): 440 | img = img_transforms(img) 441 | img = img[None, ...].to(device) 442 | logo_feat = model.features(img, ocr_emb) 443 | logo_feat = l2_norm(logo_feat).squeeze(0).cpu().numpy() # L2-normalization, final shape is (2560,) 444 | 445 | return logo_feat 446 | 447 | def chunked_dot(logo_feat_list, img_feat, chunk_size=128): 448 | sim_list = [] 449 | 450 | for start in range(0, logo_feat_list.shape[0], chunk_size): 451 | end = start + chunk_size 452 | chunk = logo_feat_list[start:end] 453 | sim_chunk = np.dot(chunk, img_feat.T) # shape: (chunk_size,) for a single (D,) query embedding 454 | sim_list.extend(sim_chunk) 455 | 456 | return sim_list
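# A small sketch of chunked_dot (shapes are illustrative; with L2-normalized
# embeddings the dot products are cosine similarities):
#   refs = np.random.rand(1000, 2560).astype(np.float32)  # cached reference matrix
#   q = np.random.rand(2560).astype(np.float32)           # one query embedding
#   sims = chunked_dot(refs, q)                           # list of 1000 scalars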
457 | 458 | def pred_brand(model, ocr_model, domain_map, logo_feat_list, file_name_list, shot_path: str, pred_bbox, t_s, grayscale=False): 459 | ''' 460 | Return predicted brand for one cropped image 461 | :param model: siamese model to make inference 462 | :param domain_map: brand-domain dictionary 463 | :param logo_feat_list: reference logo feature embeddings 464 | :param file_name_list: reference logo paths 465 | :param shot_path: path to the screenshot 466 | :param pred_bbox: 1x4 np.ndarray/list/tensor bounding box coords 467 | :param t_s: similarity threshold for siamese 468 | :param grayscale: convert image(cropped) to grayscale or not 469 | :return: predicted target, predicted target's domain, matching similarity 470 | ''' 471 | 472 | try: 473 | img = Image.open(shot_path) 474 | except OSError: # if the image cannot be identified, return nothing 475 | print('Screenshot cannot be opened') 476 | return None, None, None 477 | 478 | ## get predicted box --> crop from screenshot 479 | cropped = img.crop((pred_bbox[0], pred_bbox[1], pred_bbox[2], pred_bbox[3])) 480 | img_feat = get_ocr_aided_siamese_embedding(cropped, model, ocr_model, grayscale=grayscale) 481 | 482 | ## get cosine similarity with every protected logo 483 | sim_list = chunked_dot(logo_feat_list, img_feat) # dot product against every reference embedding (cosine similarity, since embeddings are L2-normalized) 484 | pred_brand_list = file_name_list 485 | 486 | assert len(sim_list) == len(pred_brand_list) 487 | 488 | ## get top 3 brands 489 | idx = np.argsort(sim_list)[::-1][:3] 490 | pred_brand_list = np.array(pred_brand_list)[idx] 491 | sim_list = np.array(sim_list)[idx] 492 | 493 | # top1,2,3 candidate logos 494 | top3_logolist = [Image.open(x) for x in pred_brand_list] 495 | top3_brandlist = [brand_converter(os.path.basename(os.path.dirname(x))) for x in pred_brand_list] 496 | top3_domainlist = [domain_map[x] for x in top3_brandlist] 497 | top3_simlist = sim_list 498 | 499 | for j in range(3): 500 | predicted_brand, predicted_domain = None, None 501 | 502 | ## A lower-ranked candidate must predict the same brand as the top-1 logo; otherwise it is likely a false positive 503 | if top3_brandlist[j] != top3_brandlist[0]: 504 | continue 505 | 506 | ## If the largest similarity exceeds threshold 507 | if top3_simlist[j] >= t_s: 508 | predicted_brand = top3_brandlist[j] 509 | predicted_domain = top3_domainlist[j] 510 | final_sim = top3_simlist[j] 511 | 512 | ## Otherwise, try resolution alignment and see whether the similarity improves 513 | else: 514 | cropped, candidate_logo = resolution_alignment(cropped, top3_logolist[j]) 515 | img_feat = get_ocr_aided_siamese_embedding(cropped, model, ocr_model, grayscale=grayscale) 516 | logo_feat = get_ocr_aided_siamese_embedding(candidate_logo, model, ocr_model, grayscale=grayscale) 517 | final_sim = logo_feat.dot(img_feat) 518 | if final_sim >= t_s: 519 | predicted_brand = top3_brandlist[j] 520 | predicted_domain = top3_domainlist[j] 521 | else: 522 | break # similarity is still below threshold; stop trying lower-ranked logos 523 | 524 | ## If there is a prediction, do aspect ratio check 525 | if predicted_brand is not None: 526 | ratio_crop = cropped.size[0] / cropped.size[1] 527 | ratio_logo = top3_logolist[j].size[0] / top3_logolist[j].size[1] 528 | # aspect ratios of matched pair must not deviate by more than factor of 2.5 529 | if max(ratio_crop, ratio_logo) / min(ratio_crop, ratio_logo) > 2.5: 530 | continue # failed the aspect ratio check, try the next candidate 531 | # If pass aspect ratio check, report a match 532 | else: 533 | return predicted_brand, predicted_domain, final_sim 534 | 535 | return None, None, top3_simlist[0] 536 | 537 | def cache_reference_list(model, ocr_model, targetlist_path: str, grayscale=False): 538 | ''' 539 | Cache the embeddings of the reference logo list 540 | ''' 541 | 542 | # Prediction for targetlists 543 | logo_feat_list = [] 544 | file_name_list = [] 545 | 546 | for target in tqdm(os.listdir(targetlist_path)): 547 | if target.startswith('.'): # skip hidden files 548 | continue 549 | for logo_path in os.listdir(os.path.join(targetlist_path, target)): 550 | if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith('.jpg') or logo_path.endswith('.PNG') \ 551 | or logo_path.endswith('.JPG') or logo_path.endswith('.JPEG'): 552 | if logo_path.startswith('loginpage') or logo_path.startswith('homepage'): # skip homepage/loginpage screenshots 553 | continue
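# Each reference logo is embedded once here; the cached (N, 2560) matrix and the
# parallel file-name array returned below are what pred_brand searches at inference time.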
554 | logo_feat_list.append(get_ocr_aided_siamese_embedding(img=os.path.join(targetlist_path, target, logo_path), 555 | model=model, 556 | ocr_model=ocr_model, 557 | grayscale=grayscale)) 558 | file_name_list.append(str(os.path.join(targetlist_path, target, logo_path))) 559 | 560 | return np.asarray(logo_feat_list), np.asarray(file_name_list) 561 | 562 | def check_domain_brand_inconsistency(logo_boxes, 563 | domain_map_path: str, 564 | model, 565 | ocr_model, 566 | logo_feat_list, 567 | file_name_list, 568 | shot_path: str, 569 | url: str, 570 | ts: float): 571 | 572 | # targetlist domain list 573 | with open(domain_map_path, 'rb') as handle: 574 | domain_map = pickle.load(handle) 575 | 576 | # look at boxes for logo class only 577 | print('number of logo boxes:', len(logo_boxes)) 578 | suffix_part = '.'+ tldextract.extract(url).suffix 579 | domain_part = tldextract.extract(url).domain 580 | extracted_domain = domain_part + suffix_part 581 | 582 | matched_target, matched_domain, matched_coord, this_conf = None, None, None, None 583 | 584 | 585 | # run logo matcher 586 | if len(logo_boxes) > 0: 587 | # siamese prediction for logo box 588 | for i, coord in enumerate(logo_boxes): 589 | min_x, min_y, max_x, max_y = coord 590 | bbox = [float(min_x), float(min_y), float(max_x), float(max_y)] 591 | matched_target, matched_domain, this_conf = pred_brand(model, ocr_model, domain_map, 592 | logo_feat_list, file_name_list, 593 | shot_path, bbox, t_s=ts, grayscale=False) 594 | 595 | # domain matcher to avoid FP 596 | # if matched_target is not None: 597 | # matched_coord = coord 598 | # # if tldextract.extract(url).domain+ '.'+tldextract.extract(url).suffix not in matched_domain: 599 | # if tldextract.extract(url).domain not in matched_domain: 600 | # # avoid fp due to godaddy domain parking, ignore webmail provider (ambiguous) 601 | # if matched_target == 'GoDaddy' or matched_target == "Webmail Provider" or matched_target == "Government of the United Kingdom": 602 | # matched_target = None # ignore the prediction 603 | # matched_domain = None # ignore the prediction 604 | # else: # benign, real target 605 | # matched_target = None # ignore the prediction 606 | # matched_domain = None # ignore the prediction 607 | # break # break if target is matched 608 | # break # only look at 1st logo 609 | if (matched_target is not None) and (matched_domain is not None): 610 | matched_coord = coord 611 | matched_domain_parts = [tldextract.extract(x).domain for x in matched_domain] 612 | matched_suffix_parts = [tldextract.extract(x).suffix for x in matched_domain] 613 | 614 | # If the webpage domain exactly aligns with the target website's domain => Benign 615 | if extracted_domain in matched_domain: 616 | matched_target, matched_domain = None, None # Clear if domains are consistent 617 | elif domain_part in matched_domain_parts: # if only the 2nd-level domains align and the TLD is regional => Benign 618 | if "." 
+ suffix_part.split('.')[-1] in COUNTRY_TLDs: 619 | matched_target, matched_domain = None, None 620 | else: 621 | break # Inconsistent domain found, break the loop 622 | else: 623 | break # Inconsistent domain found, break the loop 624 | break # only look at 1st logo 625 | 626 | return brand_converter(matched_target), matched_domain, matched_coord, this_conf 627 | 628 | 629 | 630 | -------------------------------------------------------------------------------- /scripts/utils/web_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Optional, Dict, Set 2 | from numpy.typing import ArrayLike, NDArray 3 | from typing import Sequence, Tuple, Union 4 | from PIL import Image 5 | from selenium import webdriver 6 | from selenium.webdriver.chrome.service import Service 7 | from webdriver_manager.chrome import ChromeDriverManager 8 | from selenium.webdriver.chrome.options import Options 9 | import re 10 | import requests 11 | from concurrent.futures import ThreadPoolExecutor 12 | from selenium.common.exceptions import ( 13 | NoSuchElementException, 14 | TimeoutException, 15 | StaleElementReferenceException, 16 | WebDriverException, 17 | JavascriptException 18 | ) 19 | import os 20 | import io 21 | import time 22 | from typing import Optional, Tuple 23 | from selenium.webdriver.remote.webdriver import WebDriver 24 | from selenium.webdriver.remote.webelement import WebElement 25 | from selenium.webdriver.common.by import By 26 | from selenium.webdriver.common.action_chains import ActionChains 27 | from selenium.webdriver.support.ui import WebDriverWait 28 | from selenium.webdriver.support import expected_conditions as EC 29 | import numpy as np 30 | from .logger_utils import PhishLLMLogger 31 | import torch.nn as nn 32 | from functools import partial 33 | import logging 34 | from logging.handlers import RotatingFileHandler 35 | from tldextract import tldextract 36 | 37 | '''webdriver utils''' 38 | def _enable_python_logging(log_path: str = "selenium-debug.log") -> None: 39 | # Root logger (console + rotating file) 40 | root = logging.getLogger() 41 | if not root.handlers: 42 | logging.basicConfig( 43 | level=logging.DEBUG, 44 | format="%(asctime)s %(levelname)s:%(name)s:%(message)s" 45 | ) 46 | fh = RotatingFileHandler(log_path, maxBytes=5_000_000, backupCount=3) 47 | fh.setLevel(logging.DEBUG) 48 | fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s:%(name)s:%(message)s")) 49 | root.addHandler(fh) 50 | logging.getLogger("selenium").setLevel(logging.DEBUG) 51 | logging.getLogger("urllib3").setLevel(logging.DEBUG) 52 | 53 | def boot_driver( 54 | python_log_file: Optional[str] = "selenium-debug.log", 55 | ) -> WebDriver: 56 | if python_log_file: 57 | _enable_python_logging(python_log_file) 58 | options = Options() 59 | options.add_argument("--headless") 60 | options.add_argument("--window-size=1920,1080") # set resolution 61 | options.add_argument("--no-sandbox") # (Linux) avoids sandbox issues 62 | options.add_argument("--disable-dev-shm-usage") # Fixes shared memory errors 63 | options.add_argument("--disable-gpu") # (Windows) GPU acceleration off in headless 64 | options.add_argument("--no-proxy-server") 65 | service = Service( 66 | ChromeDriverManager().install(), 67 | ) 68 | driver = webdriver.Chrome(service=service, options=options) 69 | return driver 70 | 71 | 72 | def restart_driver(driver: WebDriver) -> WebDriver: 73 | driver.quit() 74 | time.sleep(2) 75 | return boot_driver() 76 | 77 | def 
is_valid_domain(domain: Union[str, None]) -> bool: 78 | # Regular expression to check if the string is a valid domain without spaces 79 | if domain is None: 80 | return False 81 | pattern = re.compile( 82 | r'^(?!-)' # Cannot start with a hyphen 83 | r'(?!.*--)' # Cannot have two consecutive hyphens 84 | r'(?!.*\.\.)' # Cannot have two consecutive periods 85 | r'(?!.*\s)' # Cannot contain any spaces 86 | r'[a-zA-Z0-9-]{1,63}' # Valid characters are alphanumeric and hyphen 87 | r'(?:\.[a-zA-Z]{2,})+$' # Ends with a valid top-level domain 88 | ) 89 | it_is_a_domain = bool(pattern.fullmatch(domain)) 90 | return it_is_a_domain 91 | 92 | 93 | # -- Robust domain extraction from free-form answers -- 94 | def normalize_domain(text: str) -> Optional[str]: 95 | """ 96 | Extract and normalize a domain from model output. 97 | Accepts bare domains possibly wrapped with punctuation or code fences. 98 | Returns eTLD+1 style if valid, else None. 99 | """ 100 | if not text: 101 | return None 102 | 103 | # Common cleanup: strip code fences/quotes and trailing punctuation 104 | s = text.strip().strip("`'\" \t\r\n;,:.()[]{}") 105 | s = s.replace("http://", "").replace("https://", "").replace("www.", "") 106 | s = s.split()[0] # take the first token if multiple words 107 | 108 | # Prefer explicit domain-like substrings anywhere in the string 109 | candidates = re.findall(r'\b(?:[A-Za-z0-9-]+\.)+[A-Za-z]{2,}\b', text) 110 | if s not in candidates: 111 | candidates = [s] + candidates 112 | 113 | for cand in candidates: 114 | cand = cand.strip().lower().strip("`'\" \t\r\n;,:.()[]{}") 115 | # Validate via tldextract + the is_valid_domain helper above 116 | try: 117 | ext = tldextract.extract(cand) 118 | dom = '.'.join(p for p in (ext.domain, ext.suffix) if p) 119 | except Exception: 120 | continue 121 | if dom and is_valid_domain(dom): 122 | return dom 123 | return None 124 | 125 | def url2logo( 126 | driver: WebDriver, 127 | url: str, 128 | logo_extractor: nn.Module 129 | ) -> Optional[Image.Image]: 130 | 131 | reference_logo = None 132 | try: 133 | driver.get(url) # Visit the webpage 134 | time.sleep(2) 135 | screenshot_path = "tmp.png" 136 | driver.get_screenshot_as_file(screenshot_path) 137 | logo_boxes = logo_extractor(screenshot_path) 138 | if len(logo_boxes): 139 | logo_coord = logo_boxes[0] 140 | screenshot_img = Image.open(screenshot_path).convert("RGB") 141 | reference_logo = screenshot_img.crop((int(logo_coord[0]), int(logo_coord[1]), 142 | int(logo_coord[2]), int(logo_coord[3]))) 143 | os.remove(screenshot_path) 144 | except WebDriverException as e: 145 | print(f"Error accessing the webpage: {e}") 146 | except Exception as e: 147 | print(f"Failed to take screenshot: {e}") 148 | finally: 149 | driver = restart_driver(driver) # NOTE: this rebinds the local name only; the caller's driver has been quit and must be replaced by the caller 150 | return reference_logo 151 | 152 | 153 | def query2url( 154 | query: str, 155 | SEARCH_ENGINE_API: str, 156 | SEARCH_ENGINE_ID: str, 157 | num: int = 10, 158 | proxies: Optional[Dict] = None 159 | ) -> List[str]: 160 | ''' 161 | Google Search 162 | ''' 163 | if len(query) == 0: 164 | return [] 165 | 166 | num = int(num) 167 | URL = f"https://www.googleapis.com/customsearch/v1?key={SEARCH_ENGINE_API}&cx={SEARCH_ENGINE_ID}&q={query}&num={num}&filter=1" 168 | while True: # retry indefinitely on transient SSL errors 169 | try: 170 | data = requests.get(URL, proxies=proxies).json() 171 | break 172 | except requests.exceptions.SSLError as e: 173 | print(e) 174 | time.sleep(1) 175 | 176 | if data.get('error', {}).get('code') == 429: 177 | raise RuntimeError("Google search exceeds quota limit") 178 | 179 | search_items = data.get("items") 180 | if search_items is None: 181 | return [] 182 | 183 | returned_urls = [item.get("link") for item in search_items] 184 | 185 | return returned_urls 186 | 187 | 188 | 
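# A minimal sketch of the search helpers above and below (the API key and engine
# id are placeholders, not real credentials):
#   urls = query2url('paypal login', SEARCH_ENGINE_API='<api-key>',
#                    SEARCH_ENGINE_ID='<cx-id>', num=5)
#   urls[:1]  # e.g. ['https://www.paypal.com/signin']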
data.get("items") 180 | if search_items is None: 181 | return [] 182 | 183 | returned_urls = [item.get("link") for item in search_items] 184 | 185 | return returned_urls 186 | 187 | 188 | 189 | def query2image( 190 | query: str, 191 | SEARCH_ENGINE_API: str, 192 | SEARCH_ENGINE_ID: str, 193 | num: int = 10, 194 | proxies: Optional[Dict] = None 195 | ) -> List[str]: 196 | ''' 197 | Google Image Search 198 | ''' 199 | if len(query) == 0: 200 | return [] 201 | 202 | num = int(num) 203 | URL = f"https://www.googleapis.com/customsearch/v1?key={SEARCH_ENGINE_API}&cx={SEARCH_ENGINE_ID}&q={query}&searchType=image&num={num}&filter=1" 204 | while True: 205 | try: 206 | data = requests.get(URL, proxies=proxies).json() 207 | break 208 | except requests.exceptions.SSLError as e: 209 | print(e) 210 | time.sleep(1) 211 | 212 | if data.get('error', {}).get('code') == 429: 213 | raise RuntimeError("Google search exceeds quota limit") 214 | 215 | returned_urls = [item.get("image")["thumbnailLink"] for item in data.get("items", [])] 216 | 217 | return returned_urls 218 | 219 | 220 | def download_image( 221 | url: str, 222 | proxies: Optional[Dict] = None 223 | ) -> Optional[Image.Image]: 224 | 225 | try: 226 | response = requests.get(url, proxies=proxies, timeout=5) 227 | if response.status_code == 200: 228 | img = Image.open(io.BytesIO(response.content)) 229 | return img 230 | except requests.exceptions.Timeout: 231 | print("Request timed out after", 5, "seconds.") 232 | except requests.exceptions.RequestException as e: 233 | print(f"An error occurred while downloading image: {e}") 234 | 235 | return None 236 | 237 | 238 | def get_images( 239 | image_urls: List[str], 240 | proxies: Optional[Dict] = None 241 | ) -> List[Image.Image]: 242 | 243 | images = [] 244 | if len(image_urls) > 0: 245 | with ThreadPoolExecutor(max_workers=len(image_urls)) as executor: 246 | futures = [executor.submit(download_image, url, proxies) for url in image_urls] 247 | for future in futures: 248 | img = future.result() 249 | if img: 250 | images.append(img) 251 | 252 | return images 253 | 254 | 255 | def is_alive_domain( 256 | domain: str, 257 | proxies: Optional[Dict] = None 258 | ) -> bool: 259 | try: 260 | response = requests.head('https://www.' 
261 | PhishLLMLogger.spit(f'Domain {domain}, status code {response.status_code}', 262 | caller_prefix=PhishLLMLogger._caller_prefix, debug=True) 263 | if response.status_code < 400 or response.status_code in [405, 429] or response.status_code >= 500: # 2xx/3xx, method-not-allowed, rate-limited, or server-side errors all imply the host exists 264 | PhishLLMLogger.spit(f'Domain {domain} is valid and alive', caller_prefix=PhishLLMLogger._caller_prefix, 265 | debug=True) 266 | return True 267 | elif response.history and any([r.status_code < 400 for r in response.history]): 268 | PhishLLMLogger.spit(f'Domain {domain} is valid and alive', caller_prefix=PhishLLMLogger._caller_prefix, 269 | debug=True) 270 | return True 271 | 272 | except Exception as err: 273 | PhishLLMLogger.spit(f'Error {err} when checking the aliveness of domain {domain}', 274 | caller_prefix=PhishLLMLogger._caller_prefix, debug=True) 275 | return False 276 | 277 | PhishLLMLogger.spit(f'Domain {domain} is invalid or dead', caller_prefix=PhishLLMLogger._caller_prefix, debug=True) 278 | return False 279 | 280 | def has_page_content_changed( 281 | curr_screenshot_elements: List[int], 282 | prev_screenshot_elements: List[int] 283 | ) -> bool: 284 | bincount_prev_elements = np.bincount(prev_screenshot_elements) 285 | bincount_curr_elements = np.bincount(curr_screenshot_elements) 286 | set_of_elements = min(len(bincount_prev_elements), len(bincount_curr_elements)) 287 | screenshot_ele_change_ts = np.sum( 288 | bincount_prev_elements) // 2 # threshold: more than half of the previous UI-element distribution must change 289 | 290 | if np.sum(np.abs(bincount_curr_elements[:set_of_elements] - bincount_prev_elements[ 291 | :set_of_elements])) > screenshot_ele_change_ts: 292 | PhishLLMLogger.spit(f"Webpage content has changed", caller_prefix=PhishLLMLogger._caller_prefix, debug=True) 293 | return True 294 | else: 295 | PhishLLMLogger.spit(f"Webpage content didn't change", caller_prefix=PhishLLMLogger._caller_prefix, debug=True) 296 | return False 297 | 298 | 299 | def screenshot_element( 300 | elem: WebElement, 301 | dom: str, 302 | driver: WebDriver 303 | ) -> Tuple[Optional[str], 304 | Optional[Image.Image], 305 | Optional[str]]: 306 | """ 307 | Returns: 308 | (candidate_ui, ele_screenshot_img, candidate_ui_text) 309 | - candidate_ui: the dom string you passed in (or None on failure) 310 | - ele_screenshot_img: PIL.Image.Image of the element (or None on failure) 311 | - candidate_ui_text: element text/value (or None) 312 | """ 313 | candidate_ui = None 314 | ele_screenshot_img = None 315 | candidate_ui_text = None 316 | 317 | try: 318 | # Scroll to top (plain Selenium) 319 | driver.execute_script("window.scrollTo(0, 0);") 320 | 321 | # Ensure the element is in view (center it to reduce cropping issues) 322 | try: 323 | driver.execute_script("arguments[0].scrollIntoView({block:'center', inline:'center'});", elem) 324 | except Exception: 325 | pass 326 | 327 | # Basic visibility by rect 328 | rect = elem.rect # {'x','y','width','height'} in CSS pixels 329 | w, h = rect.get("width", 0), rect.get("height", 0) 330 | if w <= 0 or h <= 0: 331 | return candidate_ui, ele_screenshot_img, candidate_ui_text 332 | 333 | # Preferred path: Selenium can screenshot elements directly 334 | try: 335 | png = elem.screenshot_as_png # bytes 336 | ele_screenshot_img = Image.open(io.BytesIO(png)) 337 | candidate_ui = dom 338 | etext = (elem.text or "") # visible text 339 | if not etext: 340 | etext = elem.get_attribute("value") or "" 341 | candidate_ui_text = etext 342 | return candidate_ui, 
ele_screenshot_img, candidate_ui_text 343 | 344 | except (WebDriverException, StaleElementReferenceException): 345 | pass 346 | 347 | try: 348 | # Scroll offsets + device pixel ratio for accurate cropping 349 | sx, sy, dpr = driver.execute_script( 350 | "return [window.scrollX, window.scrollY, window.devicePixelRatio || 1];" 351 | ) 352 | 353 | # Re-fetch rect in case it changed after scroll 354 | rect = elem.rect 355 | x, y, w, h = rect["x"], rect["y"], rect["width"], rect["height"] 356 | 357 | # Convert page coords -> viewport coords, then scale by DPR 358 | left = int((x - sx) * dpr) 359 | top = int((y - sy) * dpr) 360 | right = int((x - sx + w) * dpr) 361 | bottom = int((y - sy + h) * dpr) 362 | 363 | # Take a viewport screenshot and crop 364 | viewport_png = driver.get_screenshot_as_png() 365 | image = Image.open(io.BytesIO(viewport_png)) 366 | 367 | # Clamp to image bounds 368 | left = max(0, min(left, image.width)) 369 | top = max(0, min(top, image.height)) 370 | right = max(0, min(right, image.width)) 371 | bottom = max(0, min(bottom, image.height)) 372 | 373 | if right > left and bottom > top: 374 | ele_screenshot_img = image.crop((left, top, right, bottom)) 375 | candidate_ui = dom 376 | etext = (elem.text or "") 377 | if not etext: 378 | etext = elem.get_attribute("value") or "" 379 | candidate_ui_text = etext 380 | 381 | except Exception as e2: 382 | print(f"Error processing element {dom} (crop fallback): {e2}") 383 | 384 | except Exception as e: 385 | print(f"Error accessing element {dom}: {e}") 386 | 387 | return candidate_ui, ele_screenshot_img, candidate_ui_text 388 | 389 | 390 | def get_all_clickable_elements( 391 | driver: WebDriver 392 | ) -> Tuple[Tuple[List[WebElement], List[str]], 393 | Tuple[List[WebElement], List[str]], 394 | Tuple[List[WebElement], List[str]], 395 | Tuple[List[WebElement], List[str]]]: 396 | """ 397 | Collect clickable elements using plain Selenium: 398 | - Buttons (