18 |
19 | ## Introduction
20 | Existing reference-based phishing detection approaches:
21 |
22 | - :x: Relies on a pre-defined reference list, which lacks comprehensiveness and incurs high maintenance costs
23 | - :x: Does not fully make use of the textual semantics present on the webpage
24 |
25 | In PhishVLM, we build a reference-based phishing detection framework:
26 |
27 | - ✅ **Without the pre-defined reference list**: Modern VLMs have encoded far more extensive brand-domain information than any predefined list
28 | - ✅ **Chain-of-thought credential-taking prediction**: Reasons about the credential-taking status step by step by examining the screenshot
29 |
30 | ## Framework
31 |
32 |
33 | `Input`: a URL and its screenshot; `Output`: Phish/Benign, Phishing target
34 |
35 | - **Step 1: Brand recognition model**
36 | - Input: Logo Screenshot
37 | - Output: VLM's predicted brand
38 |
39 | - **Step 2: Credential-Requiring-Page classification model**
40 | - Input: Webpage Screenshot
41 | - Output: VLM chooses from A. Credential-Taking Page or B. Non-Credential-Taking Page
42 | - Go to Step 4 if the VLM chooses 'A'; otherwise go to Step 3.
43 |
44 | - **Step 3: Credential-Requiring-Page transition model (activates if VLM chooses 'B' from the last step)**
45 | - Input: All clickable UI elements screenshots
46 | - Intermediate Output: Top-1 most likely login UI
47 | - Output: Webpage after clicking that UI, **go back to Step 1** with the updated webpage and URL
48 |
49 | - **Step 4: Output step**
50 | - _Case 1_: If the webpage is hosted on a web hosting domain, it is flagged as **phishing** if
51 | (i) the VLM predicts a targeted brand inconsistent with the webpage's domain,
52 | and (ii) the VLM chooses 'A' in Step 2.
53 |
54 | - _Case 2_: If the webpage is not hosted on a web hosting domain, it is flagged as **phishing** if
55 | (i) the VLM predicts a targeted brand inconsistent with the webpage's domain,
56 | (ii) the VLM chooses 'A' in Step 2,
57 | and (iii) the domain is not a popular domain indexed by Google.
58 |
59 | - _Otherwise_: the webpage is reported as **benign**. A sketch of this decision rule follows below.
60 |
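Below is a minimal sketch of this Step-4 decision rule. The helper names (`is_web_hosting_domain`, `brand_matches_domain`, `is_google_indexed`) are hypothetical placeholders, not functions from this repo:

```python
def final_verdict(domain: str, predicted_brand: str, crp_answer: str) -> str:
    """Sketch of the Step-4 verdict logic described above."""
    brand_mismatch = not brand_matches_domain(predicted_brand, domain)  # condition (i)
    credential_taking = crp_answer == 'A'                               # condition (ii)
    if is_web_hosting_domain(domain):  # Case 1
        if brand_mismatch and credential_taking:
            return 'phish'
    elif brand_mismatch and credential_taking and not is_google_indexed(domain):  # Case 2
        return 'phish'
    return 'benign'
```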
61 | ## Project structure
62 |
63 |
77 |
78 | ## Setup
79 |
80 | ### Step 1: **Install Requirements**
81 |
82 | Tested on Ubuntu, CUDA 11
83 |
84 | - A new conda environment "phishllm" will be created after this step
85 | ```bash
86 | conda create -n phishllm python=3.10
87 | conda activate phishllm
88 | pip install -r requirements.txt
89 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
90 | pip install --no-build-isolation git+https://github.com/facebookresearch/detectron2.git
91 | cd scripts/phishintention
92 | chmod +x setup.sh
93 | ./setup.sh
94 | ```
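
Optionally, a quick sanity check (our suggestion, not part of the original setup) that the GPU build of PyTorch and Detectron2 installed correctly:

```python
import torch
import detectron2

print(torch.__version__)          # expect 1.11.0+cu113
print(torch.cuda.is_available())  # expect True on a CUDA 11 machine
print(detectron2.__version__)
```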
95 |
96 |
97 | ### Step 2: **Install Chrome**
98 | ```bash
99 | wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
100 | sudo apt install ./google-chrome-stable_current_amd64.deb
101 | ```
101 |
102 | ### Step 3: Register **Two API Keys**
103 |
104 | - 🔑 **OpenAI API key**, [See Tutorial here](https://platform.openai.com/docs/quickstart). Paste the API key to ``./datasets/openai_key.txt``.
105 |
106 | - 🔑 **Google Programmable Search API Key**, [See Tutorial here](https://meta.discourse.org/t/google-search-for-discourse-ai-programmable-search-engine-and-custom-search-api/307107).
107 | Paste your API Key (in the first line) and Search Engine ID (in the second line) to ``./datasets/google_api_key.txt``:
108 | ```text
109 | [API_KEY]
110 | [SEARCH_ENGINE_ID]
111 | ```
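
For reference, a minimal sketch of reading the two key files in the layout above (illustrative only, not the repo's actual loader):

```python
from pathlib import Path

openai_key = Path('./datasets/openai_key.txt').read_text().strip()
google_api_key, search_engine_id = (
    Path('./datasets/google_api_key.txt').read_text().splitlines()[:2]
)
```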
112 |
113 | ## Prepare the Dataset
114 | To test on your own dataset, you need to prepare the dataset in the following structure:
115 |
116 | ```text
117 | testing_dir/
118 | ├── aaa.com/
119 | │ ├── shot.png # save the webpage screenshot
120 | │ ├── info.txt # save the webpage URL
121 | │ └── html.txt # save the webpage HTML source
122 | ├── bbb.com/
123 | │ ├── shot.png # save the webpage screenshot
124 | │ ├── info.txt # save the webpage URL
125 | │ └── html.txt # save the webpage HTML source
126 | └── ccc.com/
127 |   ├── shot.png # save the webpage screenshot
128 |   ├── info.txt # save the webpage URL
129 |   └── html.txt # save the webpage HTML source
130 | ```
130 |
131 |
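If you collect webpages yourself, here is a small helper sketch for writing one sample in this layout (assuming you already captured the URL, HTML, and a PNG screenshot; the function name is ours, not the repo's):

```python
from pathlib import Path

def save_sample(root: str, domain: str, url: str, html: str, screenshot_png: bytes) -> None:
    """Write one webpage sample in the folder layout expected above."""
    folder = Path(root) / domain
    folder.mkdir(parents=True, exist_ok=True)
    (folder / 'shot.png').write_bytes(screenshot_png)  # webpage screenshot
    (folder / 'info.txt').write_text(url)              # webpage URL
    (folder / 'html.txt').write_text(html)             # webpage HTML source
```
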
132 | ## Inference: Run PhishLLM
133 | ```bash
134 | conda activate phishllm
135 | python scripts/infer/test.py --folder [folder to test, e.g., ./datasets/test_sites]
136 | ```
137 |
138 | ## Understand the Output
139 | - You will see the console printing logs like the following sample:
140 |
141 | ```text
142 | [PhishLLMLogger][DEBUG] Folder ./datasets/field_study/2023-09-01/device-862044b2-5124-4735-b6d5-f114eea4a232.remotewd.com
143 | [PhishLLMLogger][DEBUG] Time taken for LLM brand prediction: 0.9699530601501465 Detected brand: sonicwall.com
144 | [PhishLLMLogger][DEBUG] Domain sonicwall.com is valid and alive
145 | [PhishLLMLogger][DEBUG] Time taken for LLM CRP classification: 2.9195783138275146 CRP prediction: A. This is a credential-requiring page.
146 | [❗️] Phishing discovered, phishing target is sonicwall.com
147 | ```
147 |
148 | - Meanwhile, a txt file named "[today's date]_phishllm.txt" is created; it has the following columns (a loading sketch follows the list):
149 | - "folder": name of the folder
150 | - "phish_prediction": "phish" | "benign"
151 | - "target_prediction": phishing target brand's domain, e.g. paypal.com, meta.com
152 | - "brand_recog_time": time taken for brand recognition
153 | - "crp_prediction_time": time taken for CRP prediction
154 | - "crp_transition_time": time taken for CRP transition
155 |
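A small sketch for loading this results file with pandas (the tab separator is an assumption; adjust it to the delimiter your file actually uses):

```python
import pandas as pd

results = pd.read_csv(
    './2023-09-01_phishllm.txt',  # i.e., "[today's date]_phishllm.txt"
    sep='\t',                     # assumed delimiter
    names=['folder', 'phish_prediction', 'target_prediction',
           'brand_recog_time', 'crp_prediction_time', 'crp_transition_time'],
)
print(results['phish_prediction'].value_counts())
```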
156 | ## Citations
157 | ```bibtex
158 | @inproceedings {299838,
159 | author = {Ruofan Liu and Yun Lin and Xiwen Teoh and Gongshen Liu and Zhiyong Huang and Jin Song Dong},
160 | title = {Less Defined Knowledge and More True Alarms: Reference-based Phishing Detection without a Pre-defined Reference List},
161 | booktitle = {33rd USENIX Security Symposium (USENIX Security 24)},
162 | year = {2024},
163 | isbn = {978-1-939133-44-1},
164 | address = {Philadelphia, PA},
165 | pages = {523--540},
166 | url = {https://www.usenix.org/conference/usenixsecurity24/presentation/liu-ruofan},
167 | publisher = {USENIX Association},
168 | month = aug
169 | }
170 | ```
171 | If you have any issues running our code, you can raise a GitHub issue or email us at liu.ruofan16@u.nus.edu, lin_yun@sjtu.edu.cn, or dcsdjs@nus.edu.sg.
--------------------------------------------------------------------------------
/scripts/phishintention/modules/models2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Lint as: python3
16 | """Bottleneck ResNet v2 with GroupNorm and Weight Standardization."""
17 |
18 | from collections import OrderedDict # pylint: disable=g-importing-member
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 |
24 |
25 | class StdConv2d(nn.Conv2d):
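# Conv2d with Weight Standardization: the kernel is normalized to zero mean and
# unit variance before every forward pass. BiT pairs these standardized convs
# with GroupNorm in place of BatchNorm.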
26 |
27 | def forward(self, x):
28 | w = self.weight
29 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
30 | w = (w - m) / torch.sqrt(v + 1e-10)
31 | return F.conv2d(x, w, self.bias, self.stride, self.padding,
32 | self.dilation, self.groups)
33 |
34 |
35 | def conv3x3(cin, cout, stride=1, groups=1, bias=False):
36 | return StdConv2d(cin, cout, kernel_size=3, stride=stride,
37 | padding=1, bias=bias, groups=groups)
38 |
39 |
40 | def conv1x1(cin, cout, stride=1, bias=False):
41 | return StdConv2d(cin, cout, kernel_size=1, stride=stride,
42 | padding=0, bias=bias)
43 |
44 |
45 | def tf2th(conv_weights):
46 | """Possibly convert HWIO to OIHW."""
47 | if conv_weights.ndim == 4:
48 | conv_weights = conv_weights.transpose([3, 2, 0, 1])
49 | return torch.from_numpy(conv_weights)
50 |
51 |
52 | class PreActBottleneck(nn.Module):
53 | """Pre-activation (v2) bottleneck block.
54 |
55 | Follows the implementation of "Identity Mappings in Deep Residual Networks":
56 | https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
57 |
58 | Except it puts the stride on 3x3 conv when available.
59 | """
60 |
61 | def __init__(self, cin, cout=None, cmid=None, stride=1):
62 | super().__init__()
63 | cout = cout or cin
64 | cmid = cmid or cout//4
65 |
66 | self.gn1 = nn.GroupNorm(32, cin)
67 | self.conv1 = conv1x1(cin, cmid)
68 | self.gn2 = nn.GroupNorm(32, cmid)
69 | self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!!
70 | self.gn3 = nn.GroupNorm(32, cmid)
71 | self.conv3 = conv1x1(cmid, cout)
72 | self.relu = nn.ReLU(inplace=True)
73 |
74 | if (stride != 1 or cin != cout):
75 | # Projection also with pre-activation according to paper.
76 | self.downsample = conv1x1(cin, cout, stride)
77 |
78 | def forward(self, x):
79 | out = self.relu(self.gn1(x))
80 |
81 | # Residual branch
82 | residual = x
83 | if hasattr(self, 'downsample'):
84 | residual = self.downsample(out)
85 |
86 | # Unit's branch
87 | out = self.conv1(out)
88 | out = self.conv2(self.relu(self.gn2(out)))
89 | out = self.conv3(self.relu(self.gn3(out)))
90 |
91 | return out + residual
92 |
93 | def load_from(self, weights, prefix=''):
94 | convname = 'standardized_conv2d'
95 | with torch.no_grad():
96 | self.conv1.weight.copy_(tf2th(weights[f'{prefix}a/{convname}/kernel']))
97 | self.conv2.weight.copy_(tf2th(weights[f'{prefix}b/{convname}/kernel']))
98 | self.conv3.weight.copy_(tf2th(weights[f'{prefix}c/{convname}/kernel']))
99 | self.gn1.weight.copy_(tf2th(weights[f'{prefix}a/group_norm/gamma']))
100 | self.gn2.weight.copy_(tf2th(weights[f'{prefix}b/group_norm/gamma']))
101 | self.gn3.weight.copy_(tf2th(weights[f'{prefix}c/group_norm/gamma']))
102 | self.gn1.bias.copy_(tf2th(weights[f'{prefix}a/group_norm/beta']))
103 | self.gn2.bias.copy_(tf2th(weights[f'{prefix}b/group_norm/beta']))
104 | self.gn3.bias.copy_(tf2th(weights[f'{prefix}c/group_norm/beta']))
105 | if hasattr(self, 'downsample'):
106 | w = weights[f'{prefix}a/proj/{convname}/kernel']
107 | self.downsample.weight.copy_(tf2th(w))
108 |
109 |
110 | class ResNetV2(nn.Module):
111 | """Implementation of Pre-activation (v2) ResNet mode."""
112 |
113 | def __init__(self, block_units, width_factor, head_size=21843, zero_head=False, ocr_emb_size=512):
114 | super().__init__()
115 | wf = width_factor
116 | self.wf = wf
117 | # The following will be unreadable if we split lines.
118 | # pylint: disable=line-too-long
119 | self.root = nn.Sequential(OrderedDict([
120 | ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)),
121 | ('pad', nn.ConstantPad2d(1, 0)),
122 | ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
123 | ]))
124 |
125 | self.body = nn.Sequential(OrderedDict([
126 | ('block1', nn.Sequential(OrderedDict(
127 | [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] +
128 | [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)],
129 | ))),
130 | ('block2', nn.Sequential(OrderedDict(
131 | [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] +
132 | [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)],
133 | ))),
134 | ('block3', nn.Sequential(OrderedDict(
135 | [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] +
136 | [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)],
137 | ))),
138 | ('block4', nn.Sequential(OrderedDict(
139 | [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] +
140 | [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)],
141 | ))),
142 | ]))
143 | # pylint: enable=line-too-long
144 |
145 | self.zero_head = zero_head
146 | self.head = nn.Sequential(OrderedDict([
147 | ('gn', nn.GroupNorm(32, 2048*wf)),
148 | ('relu', nn.ReLU(inplace=True)),
149 | ('avg', nn.AdaptiveAvgPool2d(output_size=1)),
150 | ]))
151 |
152 | self.additionalfc = nn.Sequential(OrderedDict([
153 | ('conv_add', nn.Linear(2048*wf+ocr_emb_size, head_size)),
154 | ]))
155 |
156 | def features(self, x, ocr_emb):
157 | x = self.head(self.body(self.root(x)))
158 | x = x.view(-1, 2048*self.wf)
159 | x = torch.cat((x, ocr_emb), dim=1)
160 | return x.squeeze(-1).squeeze(-1)
161 |
162 | def forward(self, x, ocr_emb):
163 | x = self.head(self.body(self.root(x)))
164 | x = x.view(-1, 2048*self.wf)
165 | x = torch.cat((x, ocr_emb), dim=1)
166 | x = self.additionalfc(x)
168 |
169 | return x
170 |
171 | def load_from(self, weights, prefix='resnet/'):
172 | with torch.no_grad():
173 | self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) # pylint: disable=line-too-long
174 | self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
175 | self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
176 | for bname, block in self.body.named_children():
177 | for uname, unit in block.named_children():
178 | unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/')
179 |
180 |
181 | KNOWN_MODELS = OrderedDict([
182 | ('BiT-M-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
183 | ('BiT-M-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
184 | ('BiT-M-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
185 | ('BiT-M-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
186 | ('BiT-M-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
187 | ('BiT-M-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),
188 | ('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
189 | ('BiT-S-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
190 | ('BiT-S-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
191 | ('BiT-S-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
192 | ('BiT-S-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
193 | ('BiT-S-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),
194 | ])
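
# Hypothetical usage sketch (not from this repo; `head_size` is an assumed
# number of brand classes):
#   model = KNOWN_MODELS['BiT-M-R50x1'](head_size=277, zero_head=True, ocr_emb_size=512)
#   logits = model(torch.rand(1, 3, 224, 224), torch.rand(1, 512))  # -> [1, 277]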
--------------------------------------------------------------------------------
/scripts/phishintention/ocr_lib/models/attention_recognition_head.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import sys
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 | from torch.nn import init
9 |
10 |
11 | class AttentionRecognitionHead(nn.Module):
12 | """
13 | input: [b x 16 x 64 x in_planes]
14 | output: probability sequence: [b x T x num_classes]
15 | """
16 | def __init__(self, num_classes, in_planes, sDim, attDim, max_len_labels):
17 | super(AttentionRecognitionHead, self).__init__()
18 | self.num_classes = num_classes  # this is the output classes. So it includes the <EOS>.
19 | self.in_planes = in_planes
20 | self.sDim = sDim
21 | self.attDim = attDim
22 | self.max_len_labels = max_len_labels
23 |
24 | self.decoder = DecoderUnit(sDim=sDim, xDim=in_planes, yDim=num_classes, attDim=attDim)
25 |
26 | def forward(self, x):
27 | x, targets, lengths = x
28 | batch_size = x.size(0)
29 | # Decoder
30 | state = torch.zeros(1, batch_size, self.sDim)
31 | outputs = []
32 |
33 | for i in range(max(lengths)):
34 | if i == 0:
35 | y_prev = torch.zeros((batch_size)).fill_(self.num_classes) # the last one is used as the <BOS>.
36 | else:
37 | y_prev = targets[:,i-1]
38 |
39 | output, state = self.decoder(x, state, y_prev)
40 | outputs.append(output)
41 | outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1)
42 | return outputs
43 |
44 | # inference stage.
45 | def sample(self, x):
46 | x, _, _ = x
47 | batch_size = x.size(0)
48 | # Decoder
49 | state = torch.zeros(1, batch_size, self.sDim)
50 |
51 | predicted_ids, predicted_scores = [], []
52 | for i in range(self.max_len_labels):
53 | if i == 0:
54 | y_prev = torch.zeros((batch_size)).fill_(self.num_classes)
55 | else:
56 | y_prev = predicted
57 |
58 | output, state = self.decoder(x, state, y_prev)
59 | output = F.softmax(output, dim=1)
60 | score, predicted = output.max(1)
61 | predicted_ids.append(predicted.unsqueeze(1))
62 | predicted_scores.append(score.unsqueeze(1))
63 | predicted_ids = torch.cat(predicted_ids, 1)
64 | predicted_scores = torch.cat(predicted_scores, 1)
65 | # return predicted_ids.squeeze(), predicted_scores.squeeze()
66 | return predicted_ids, predicted_scores
67 |
68 | def beam_search(self, x, beam_width, eos):
69 |
70 | def _inflate(tensor, times, dim):
71 | repeat_dims = [1] * tensor.dim()
72 | repeat_dims[dim] = times
73 | return tensor.repeat(*repeat_dims)
74 |
75 | # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
76 | batch_size, l, d = x.size()
77 | # inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC
78 | inflated_encoder_feats = x.unsqueeze(1).permute((1,0,2,3)).repeat((beam_width,1,1,1)).permute((1,0,2,3)).contiguous().view(-1, l, d)
79 |
80 | # Initialize the decoder
81 | state = torch.zeros(1, batch_size * beam_width, self.sDim)
82 | pos_index = (torch.Tensor(range(batch_size)) * beam_width).long().view(-1, 1)
83 |
84 | # Initialize the scores
85 | sequence_scores = torch.Tensor(batch_size * beam_width, 1)
86 | sequence_scores.fill_(-float('Inf'))
87 | sequence_scores.index_fill_(0, torch.Tensor([i * beam_width for i in range(0, batch_size)]).long(), 0.0)
88 | # sequence_scores.fill_(0.0)
89 |
90 | # Initialize the input vector
91 | y_prev = torch.zeros((batch_size * beam_width)).fill_(self.num_classes)
92 |
93 | # Store decisions for backtracking
94 | stored_scores = list()
95 | stored_predecessors = list()
96 | stored_emitted_symbols = list()
97 |
98 | for i in range(self.max_len_labels):
99 | output, state = self.decoder(inflated_encoder_feats, state, y_prev)
100 | log_softmax_output = F.log_softmax(output, dim=1)
101 |
102 | sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
103 | sequence_scores += log_softmax_output
104 | scores, candidates = sequence_scores.view(batch_size, -1).topk(beam_width, dim=1)
105 |
106 | # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
107 | y_prev = (candidates % self.num_classes).view(batch_size * beam_width)
108 | sequence_scores = scores.view(batch_size * beam_width, 1)
109 |
110 | # Update fields for next timestep
111 | predecessors = (torch.div(candidates, self.num_classes, rounding_mode='floor') + pos_index.expand_as(candidates)).view(batch_size * beam_width, 1)  # floor division keeps the indices integer-typed for index_select
112 | state = state.index_select(1, predecessors.squeeze())
113 |
114 | # Update sequence scores and erase scores for the <eos> symbol so that they aren't expanded
115 | stored_scores.append(sequence_scores.clone())
116 | eos_indices = y_prev.view(-1, 1).eq(eos)
117 | if eos_indices.nonzero().dim() > 0:
118 | sequence_scores.masked_fill_(eos_indices, -float('inf'))
119 |
120 | # Cache results for backtracking
121 | stored_predecessors.append(predecessors)
122 | stored_emitted_symbols.append(y_prev)
123 |
124 | # Do backtracking to return the optimal values
125 | #====== backtrack ======#
126 | # Initialize return variables given different types
127 | p = list()
128 | l = [[self.max_len_labels] * beam_width for _ in range(batch_size)] # Placeholder for lengths of top-k sequences
129 |
130 | # the last step output of the beams are not sorted
131 | # thus they are sorted here
132 | sorted_score, sorted_idx = stored_scores[-1].view(batch_size, beam_width).topk(beam_width)
133 | # initialize the sequence scores with the sorted last step beam scores
134 | s = sorted_score.clone()
135 |
136 | batch_eos_found = [0] * batch_size # the number of EOS found
137 | # in the backward loop below for each batch
138 | t = self.max_len_labels - 1
139 | # initialize the back pointer with the sorted order of the last step beams.
140 | # add pos_index for indexing variable with b*k as the first dimension.
141 | t_predecessors = (sorted_idx + pos_index.expand_as(sorted_idx)).view(batch_size * beam_width)
142 | while t >= 0:
143 | # Re-order the variables with the back pointer
144 | current_symbol = stored_emitted_symbols[t].index_select(0, t_predecessors)
145 | t_predecessors = stored_predecessors[t].index_select(0, t_predecessors).squeeze()
146 | eos_indices = stored_emitted_symbols[t].eq(eos).nonzero()
147 | if eos_indices.dim() > 0:
148 | for i in range(eos_indices.size(0)-1, -1, -1):
149 | # Indices of the EOS symbol for both variables
150 | # with b*k as the first dimension, and b, k for
151 | # the first two dimensions
152 | idx = eos_indices[i]
153 | b_idx = int(idx[0] / beam_width)
154 | # The indices of the replacing position
155 | # according to the replacement strategy noted above
156 | res_k_idx = beam_width - (batch_eos_found[b_idx] % beam_width) - 1
157 | batch_eos_found[b_idx] += 1
158 | res_idx = b_idx * beam_width + res_k_idx
159 |
160 | # Replace the old information in return variables
161 | # with the new ended sequence information
162 | t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
163 | current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
164 | s[b_idx, res_k_idx] = stored_scores[t][idx[0], [0]]
165 | l[b_idx][res_k_idx] = t + 1
166 |
167 | # record the back tracked results
168 | p.append(current_symbol)
169 |
170 | t -= 1
171 |
172 | # Sort and re-order again as the added ended sequences may change
173 | # the order (very unlikely)
174 | s, re_sorted_idx = s.topk(beam_width)
175 | for b_idx in range(batch_size):
176 | l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx,:]]
177 |
178 | re_sorted_idx = (re_sorted_idx + pos_index.expand_as(re_sorted_idx)).view(batch_size*beam_width)
179 |
180 | # Reverse the sequences and re-order at the same time
181 | # It is reversed because the backtracking happens in reverse time order
182 | p = [step.index_select(0, re_sorted_idx).view(batch_size, beam_width, -1) for step in reversed(p)]
183 | p = torch.cat(p, -1)[:,0,:]
184 | return p, torch.ones_like(p)
185 |
186 |
187 | class AttentionUnit(nn.Module):
188 | def __init__(self, sDim, xDim, attDim):
189 | super(AttentionUnit, self).__init__()
190 |
191 | self.sDim = sDim
192 | self.xDim = xDim
193 | self.attDim = attDim
194 |
195 | self.sEmbed = nn.Linear(sDim, attDim)
196 | self.xEmbed = nn.Linear(xDim, attDim)
197 | self.wEmbed = nn.Linear(attDim, 1)
198 |
199 | # self.init_weights()
200 |
201 | def init_weights(self):
202 | init.normal_(self.sEmbed.weight, std=0.01)
203 | init.constant_(self.sEmbed.bias, 0)
204 | init.normal_(self.xEmbed.weight, std=0.01)
205 | init.constant_(self.xEmbed.bias, 0)
206 | init.normal_(self.wEmbed.weight, std=0.01)
207 | init.constant_(self.wEmbed.bias, 0)
208 |
209 | def forward(self, x, sPrev):
210 | batch_size, T, _ = x.size() # [b x T x xDim]
211 | x = x.view(-1, self.xDim) # [(b x T) x xDim]
212 | xProj = self.xEmbed(x) # [(b x T) x attDim]
213 | xProj = xProj.view(batch_size, T, -1) # [b x T x attDim]
214 |
215 | sPrev = sPrev.squeeze(0)
216 | sProj = self.sEmbed(sPrev) # [b x attDim]
217 | sProj = torch.unsqueeze(sProj, 1) # [b x 1 x attDim]
218 | sProj = sProj.expand(batch_size, T, self.attDim) # [b x T x attDim]
219 |
220 | sumTanh = torch.tanh(sProj + xProj)
221 | sumTanh = sumTanh.view(-1, self.attDim)
222 |
223 | vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
224 | vProj = vProj.view(batch_size, T)
225 |
226 | alpha = F.softmax(vProj, dim=1) # attention weights for each sample in the minibatch
227 |
228 | return alpha
229 |
230 |
231 | class DecoderUnit(nn.Module):
232 | def __init__(self, sDim, xDim, yDim, attDim):
233 | super(DecoderUnit, self).__init__()
234 | self.sDim = sDim
235 | self.xDim = xDim
236 | self.yDim = yDim
237 | self.attDim = attDim
238 | self.emdDim = attDim
239 |
240 | self.attention_unit = AttentionUnit(sDim, xDim, attDim)
241 | self.tgt_embedding = nn.Embedding(yDim+1, self.emdDim) # the last is used for <BOS>
242 | self.gru = nn.GRU(input_size=xDim+self.emdDim, hidden_size=sDim, batch_first=True)
243 | self.fc = nn.Linear(sDim, yDim)
244 |
245 | # self.init_weights()
246 |
247 | def init_weights(self):
248 | init.normal_(self.tgt_embedding.weight, std=0.01)
249 | init.normal_(self.fc.weight, std=0.01)
250 | init.constant_(self.fc.bias, 0)
251 |
252 | def forward(self, x, sPrev, yPrev):
253 | # x: feature sequence from the image decoder.
254 | batch_size, T, _ = x.size()
255 | alpha = self.attention_unit(x, sPrev)
256 | context = torch.bmm(alpha.unsqueeze(1), x).squeeze(1)
257 | yProj = self.tgt_embedding(yPrev.long())
258 | # self.gru.flatten_parameters()
259 | output, state = self.gru(torch.cat([yProj, context], 1).unsqueeze(1), sPrev)
260 | output = output.squeeze(1)
261 |
262 | output = self.fc(output)
263 | return output, state
--------------------------------------------------------------------------------
/scripts/phishintention/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import numpy as np
4 | import os
5 | import math
6 |
7 |
8 | def coord_reshape(coords, image_shape, reshaped_size=(256, 512)):
9 | '''
10 | Revise coordinates when the image is resized
11 | '''
12 | height, width = image_shape
13 | new_coords = []
14 | for c in coords:
15 | x1, y1, x2, y2 = c
16 | x1n, y1n, x2n, y2n = reshaped_size[1] * x1 / width, reshaped_size[0] * y1 / height, \
17 | reshaped_size[1] * x2 / width, reshaped_size[0] * y2 / height
18 | new_coords.append([x1n, y1n, x2n, y2n])
19 |
20 | return np.asarray(new_coords)
21 |
22 |
23 | def coord2pixel_reverse(img_path, coords, types, num_types=5, reshaped_size=(256, 512)) -> torch.Tensor:
24 | '''
25 | Convert coordinate to multi-hot encodings for coordinate class
26 | '''
27 | img = cv2.imread(img_path) if not isinstance(img_path, np.ndarray) else img_path
28 | coords = coords.numpy() if not isinstance(coords, np.ndarray) else coords
29 | coords = coord_reshape(coords, img.shape[:2], reshaped_size) # reshape coordinates
30 | types = types.numpy() if not isinstance(types, np.ndarray) else types
31 |
32 | # Incorrect path/empty image
33 | if img is None:
34 | raise AttributeError('Image is None')
35 | height, width = img.shape[:2]
36 | # Empty image
37 | if height == 0 or width == 0:
38 | raise AttributeError('Empty image')
39 |
40 | # grid array of shape ClassxHxW
41 | grid_arrs = np.zeros((num_types, reshaped_size[0], reshaped_size[1]))
42 |
43 | for j, coord in enumerate(coords):
44 | x1, y1, x2, y2 = coord
45 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
46 | if x2 - x1 <= 0 or y2 - y1 <= 0:
47 | continue # ignore
48 |
49 | # multi-hot encoding for type?
50 | class_position = types[j]
51 | grid_arrs[class_position, y1:y2, x1:x2] = 1.
52 |
53 | return torch.from_numpy(grid_arrs)
54 |
55 |
56 | def coord2pixel(img_path, coords, types, num_types=5, reshaped_size=(256, 512)) -> torch.Tensor:
57 | '''
58 | Convert coordinate to multi-hot encodings for coordinate class
59 | '''
60 | img = cv2.imread(img_path) if not isinstance(img_path, np.ndarray) else img_path
61 | coords = coords.numpy() if not isinstance(coords, np.ndarray) else coords
62 | coords = coord_reshape(coords, img.shape[:2], reshaped_size) # reshape coordinates
63 | types = types.numpy() if not isinstance(types, np.ndarray) else types
64 |
65 | # Incorrect path/empty image
66 | if img is None:
67 | raise AttributeError('Image is None')
68 | height, width = img.shape[:2]
69 | # Empty image
70 | if height == 0 or width == 0:
71 | raise AttributeError('Empty image')
72 |
73 | # grid array of shape ClassxHxW = 5xHxW
74 | grid_arrs = np.zeros((num_types, reshaped_size[0], reshaped_size[1]))
75 | type_dict = {'logo': 1, 'input': 2, 'button': 3, 'label': 4, 'block': 5}
76 |
77 | for j, coord in enumerate(coords):
78 | x1, y1, x2, y2 = coord
79 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
80 | if x2 - x1 <= 0 or y2 - y1 <= 0:
81 | continue # ignore
82 |
83 | # multi-hot encoding for type?
84 | class_position = type_dict[types[j]] - 1
85 | grid_arrs[class_position, y1:y2, x1:x2] = 1.
86 |
87 | return torch.from_numpy(grid_arrs)
88 |
89 |
90 | def topo2pixel(img_path, coords, knn_matrix, reshaped_size=(256, 512)) -> torch.Tensor:
91 | '''
92 | Convert coordinate to multi-hot encodings for coordinate class
93 | '''
94 | img = cv2.imread(img_path) if not isinstance(img_path, np.ndarray) else img_path
95 | coords = coords.numpy() if not isinstance(coords, np.ndarray) else coords
96 | coords = coord_reshape(coords, img.shape[:2], reshaped_size) # reshape coordinates
97 | knn_matrix = knn_matrix.numpy() if not isinstance(knn_matrix, np.ndarray) else knn_matrix
98 |
99 | # Incorrect path/empty image
100 | if img is None:
101 | raise AttributeError('Image is None')
102 | height, width = img.shape[:2]
103 | # Empty image
104 | if height == 0 or width == 0:
105 | raise AttributeError('Empty image')
106 |
107 | # grid array of shape (KxZ)xHxW = 12xHxW
108 | topo_arrs = np.zeros((12, reshaped_size[0], reshaped_size[1]))
109 | if len(coords) <= 1: # num of components smaller than 2
110 | return torch.from_numpy(topo_arrs)
111 |
112 | for j, coord in enumerate(coords):
113 | x1, y1, x2, y2 = coord
114 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
115 | if x2 - x1 <= 0 or y2 - y1 <= 0:
116 | continue # ignore
117 |
118 | # fill in topological info (zero padding if number of neighbors is less than 3)
119 | topo_arrs[:min(len(knn_matrix[j]), 12), y1:y2, x1:x2] = knn_matrix[j][:, np.newaxis][:, np.newaxis]
120 |
121 | return torch.from_numpy(topo_arrs)
122 |
123 |
124 | def read_img_reverse(img, coords, types, num_types=5, grid_num=10) -> torch.Tensor:
125 | '''
126 | Convert image with bbox predictions as into grid format
127 | :param img: image path in str or image in np.ndarray
128 | :param coords: Nx4 tensor/np.ndarray for box coords
129 | :param types: Nx1 tensor/np.ndarray for box types (logo, input etc.)
130 | :param num_types: total number of box types
131 | :param grid_num: number of grids needed
132 | :return: grid tensor
133 | '''
134 |
135 | img = cv2.imread(img) if not isinstance(img, np.ndarray) else img
136 | coords = coords.numpy() if not isinstance(coords, np.ndarray) else coords
137 | types = types.numpy() if not isinstance(types, np.ndarray) else types
138 |
139 | # Incorrect path/empty image
140 | if img is None:
141 | raise AttributeError('Image is None')
142 |
143 | height, width = img.shape[:2]
144 |
145 | # Empty image
146 | if height == 0 or width == 0:
147 | raise AttributeError('Empty image')
148 |
149 | # grid array of shape CxHxW
150 | grid_arrs = np.zeros((4 + num_types, grid_num, grid_num)) # Must be [0, 1], use rel_x, rel_y, rel_w, rel_h
151 |
152 | for j, coord in enumerate(coords):
153 | x1, y1, x2, y2 = coord
154 | w = max(0, x2 - x1)
155 | h = max(0, y2 - y1)
156 | if w == 0 or h == 0:
157 | continue # ignore
158 |
159 | # get the assigned grid index
160 | assigned_grid_w, assigned_grid_h = int(((x1 + x2) / 2) // (width // grid_num)), int(
161 | ((y1 + y2) / 2) // (height // grid_num))
162 |
163 | # bound above
164 | assigned_grid_w = min(grid_num - 1, assigned_grid_w)
165 | assigned_grid_h = min(grid_num - 1, assigned_grid_h)
166 |
167 | # if this grid has been assigned before, check whether need to re-assign
168 | if grid_arrs[0, assigned_grid_h, assigned_grid_w] != 0: # visited
169 | exist_type = np.where(grid_arrs[:, assigned_grid_h, assigned_grid_w] == 1)[0][0] - 4
170 | new_type = types[j]
171 | if new_type > exist_type: # if new type has lower priority than existing type
172 | continue
173 |
174 | # fill in rel_xywh
175 | grid_arrs[0, assigned_grid_h, assigned_grid_w] = float(x1 / width)
176 | grid_arrs[1, assigned_grid_h, assigned_grid_w] = float(y1 / height)
177 | grid_arrs[2, assigned_grid_h, assigned_grid_w] = float(w / width)
178 | grid_arrs[3, assigned_grid_h, assigned_grid_w] = float(h / height)
179 |
180 | # one-hot encoding for type
181 | cls_arr = np.zeros(num_types)
182 | cls_arr[types[j]] = 1
183 |
184 | grid_arrs[4:, assigned_grid_h, assigned_grid_w] = cls_arr
185 |
186 | return torch.from_numpy(grid_arrs)
187 |
188 |
189 | import torch.nn.functional as F
190 | from PIL import Image
192 |
193 | def resolution_alignment(img1, img2):
194 | '''
195 | Resize two images according to the minimum resolution between the two
196 | :param img1: first image in PIL.Image
197 | :param img2: second image in PIL.Image
198 | :return: resized img1 in PIL.Image, resized img2 in PIL.Image
199 | '''
200 | w1, h1 = img1.size
201 | w2, h2 = img2.size
202 | w_min, h_min = min(w1, w2), min(h1, h2)
203 | if w_min == 0 or h_min == 0: ## something wrong, stop resizing
204 | return img1, img2
205 | if w_min < h_min:
206 | img1_resize = img1.resize((int(w_min), math.ceil(h1 * (w_min/w1)))) # ceiling to prevent rounding to 0
207 | img2_resize = img2.resize((int(w_min), math.ceil(h2 * (w_min/w2))))
208 | else:
209 | img1_resize = img1.resize((math.ceil(w1 * (h_min/h1)), int(h_min)))
210 | img2_resize = img2.resize((math.ceil(w2 * (h_min/h2)), int(h_min)))
211 | return img1_resize, img2_resize
212 |
213 | # Mapping to deal with inconsistency in brand naming (behavior-preserving
214 | # rewrite of the original if/elif chain)
215 | BRAND_NAME_MAP = {
216 | 'Adobe Inc.': 'Adobe', 'Adobe Inc': 'Adobe',
217 | 'ADP, LLC': 'ADP', 'ADP, LLC.': 'ADP',
218 | 'Amazon.com Inc.': 'Amazon', 'Amazon.com Inc': 'Amazon',
219 | 'Americanas.com S,A Comercio Electrnico': 'Americanas.com S',
220 | 'AOL Inc.': 'AOL', 'AOL Inc': 'AOL',
221 | 'Apple Inc.': 'Apple', 'Apple Inc': 'Apple',
222 | 'AT&T Inc.': 'AT&T', 'AT&T Inc': 'AT&T',
223 | 'Banco do Brasil S.A.': 'Banco do Brasil S.A',
224 | 'Credit Agricole S.A.': 'Credit Agricole S.A',
225 | 'DGI (French Tax Authority)': 'DGI French Tax Authority',
226 | 'DHL Airways, Inc.': 'DHL Airways', 'DHL Airways, Inc': 'DHL Airways', 'DHL': 'DHL Airways',
227 | 'Dropbox, Inc.': 'Dropbox', 'Dropbox, Inc': 'Dropbox',
228 | 'eBay Inc.': 'eBay', 'eBay Inc': 'eBay',
229 | 'Facebook, Inc.': 'Facebook', 'Facebook, Inc': 'Facebook',
230 | 'Free (ISP)': 'Free ISP',
231 | 'Google Inc.': 'Google', 'Google Inc': 'Google',
232 | 'Mastercard International Incorporated': 'Mastercard International',
233 | 'Netflix Inc.': 'Netflix', 'Netflix Inc': 'Netflix',
234 | 'PayPal Inc.': 'PayPal', 'PayPal Inc': 'PayPal',
235 | 'Royal KPN N.V.': 'Royal KPN N.V',
236 | 'SF Express Co.': 'SF Express Co',
237 | 'SNS Bank N.V.': 'SNS Bank N.V',
238 | 'Square, Inc.': 'Square', 'Square, Inc': 'Square',
239 | 'Webmail Providers': 'Webmail Provider',
240 | 'Yahoo! Inc': 'Yahoo!', 'Yahoo! Inc.': 'Yahoo!',
241 | 'Microsoft OneDrive': 'Microsoft', 'Office365': 'Microsoft', 'Outlook': 'Microsoft',
242 | 'Global Sources (HK)': 'Global Sources HK',
243 | 'T-Online': 'Deutsche Telekom',
244 | 'Airbnb, Inc': 'Airbnb, Inc.',
245 | 'azul': 'Azul',
246 | 'Raiffeisen Bank S.A': 'Raiffeisen Bank S.A.',
247 | 'Twitter, Inc': 'Twitter, Inc.', 'Twitter': 'Twitter, Inc.',
248 | 'capital_one': 'Capital One Financial Corporation',
249 | 'la_banque_postale': 'La Banque postale',
250 | 'db': 'Deutsche Bank AG',
251 | 'Swiss Post': 'PostFinance', 'PostFinance': 'PostFinance',
252 | 'grupo_bancolombia': 'Bancolombia',
253 | 'barclays': 'Barclays Bank Plc',
254 | 'gov_uk': 'Government of the United Kingdom',
255 | 'Aruba S.p.A': 'Aruba S.p.A.',
256 | 'TSB Bank Plc': 'TSB Bank Limited',
257 | 'strato': 'Strato AG',
258 | 'cogeco': 'Cogeco',
259 | 'Canada Revenue Agency': 'Government of Canada',
260 | 'UniCredit Bulbank': 'UniCredit Bank Aktiengesellschaft',
261 | 'ameli_fr': 'French Health Insurance',
262 | 'Banco de Credito del Peru': 'bcp',
263 | }
264 |
265 | def brand_converter(brand_name):
266 | '''
267 | Helper function to deal with inconsistency in brand naming
268 | '''
269 | return BRAND_NAME_MAP.get(brand_name, brand_name)
313 |
314 | def l2_norm(x):
315 | """
316 | L2 normalization along dim 1
317 | :param x: tensor of shape [batch, ...]; flattened to [batch, -1] first
318 | :return: L2-normalized tensor
319 | """
320 | if len(x.shape):
321 | x = x.reshape((x.shape[0], -1))
322 | return F.normalize(x, p=2, dim=1)
--------------------------------------------------------------------------------
/scripts/utils/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 |
4 | import numpy as np
5 | from numpy.typing import ArrayLike, NDArray
6 | from PIL import Image
7 | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
8 |
9 | Number = Union[int, float]
10 |
11 | '''prompt utils'''
12 | def image2base64(image: Union[str, Image.Image]) -> str:
13 | if isinstance(image, str):
14 | image = Image.open(image)
15 | img_byte_arr = io.BytesIO()
16 | image.save(img_byte_arr, format='PNG') # Ensure the format matches your image format, e.g., JPEG, PNG, etc.
17 | img_bytes = img_byte_arr.getvalue()
18 | base64_encoded = base64.b64encode(img_bytes).decode('utf-8') # Convert bytes to base64 string and decode to UTF-8
19 | return base64_encoded
20 |
21 | def prepare_candidate_uis(
22 | candidate_uis_imgs: Sequence[Union[str, Image.Image]],
23 | candidate_uis_text: Sequence[str]
24 | ) -> Sequence[Dict[str, Any]]:
25 |
26 | candidate_uis_json = []
27 | for ind, (img, text) in enumerate(zip(candidate_uis_imgs, candidate_uis_text)):
28 | base64_image = image2base64(img)
29 | candidate_uis_json.append({"type": "text",
30 | "text": f'Index {ind}: ' + text}
31 | )
32 | candidate_uis_json.append({"type": "image_url",
33 | "image_url": {
34 | "url": f"data:image/jpeg;base64,{base64_image}"}
35 | }
36 | )
37 |
38 | return candidate_uis_json
39 |
40 | def vlm_question_template_transition(
41 | candidate_uis_imgs: Sequence[Union[str, Image.Image]],
42 | candidate_uis_text: Sequence[str]
43 | ) -> Dict[str, Any]:
44 | candidate_uis_json = prepare_candidate_uis(candidate_uis_imgs, candidate_uis_text)
45 |
46 | return {
47 | "role": "user",
48 | "content": candidate_uis_json
49 | }
50 |
51 |
52 | def vlm_question_template_prediction(screenshot_img: Image.Image) -> Dict[str, Any]:
53 | return \
54 | {"role": "user",
55 | "content": [
56 | {"type": "text",
57 | "text": "Given the HTML webpage screenshot, Question: A. This is a credential-requiring page. B. This is not a credential-requiring page. \n Answer:"},
58 | {
59 | "type": "image_url",
60 | "image_url": {
61 | "url": f"data:image/jpeg;base64,{image2base64(screenshot_img)}"
62 | },
63 | },
64 | ]
65 | }
66 |
67 |
68 | def vlm_question_template_brand(logo_img: Image.Image) -> Dict[str, Any]:
69 | return \
70 | {"role": "user",
71 | "content": [
72 | {"type": "text",
73 | "text": "Given the brand's logo, Question: What is the brand's domain? Answer: "},
74 | {
75 | "type": "image_url",
76 | "image_url": {
77 | "url": f"data:image/jpeg;base64,{image2base64(logo_img)}"
78 | },
79 | },
80 | ]
81 | }
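
# Hypothetical usage sketch (client setup and model name are assumptions, not
# repo configuration): send the brand question to an OpenAI-compatible VLM.
#   from openai import OpenAI
#   client = OpenAI()  # reads OPENAI_API_KEY from the environment
#   message = vlm_question_template_brand(Image.open('logo.png'))
#   resp = client.chat.completions.create(model='gpt-4o-mini', messages=[message])
#   print(resp.choices[0].message.content)  # e.g. 'paypal.com'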
82 |
83 |
84 | '''bounding box utils'''
85 | def pairwise_intersect_area(
86 | bboxes1: ArrayLike,
87 | bboxes2: ArrayLike,
88 | ) -> NDArray[np.float32]:
89 | # Convert bboxes lists to 3D arrays
90 | bboxes1 = np.array(bboxes1)[:, np.newaxis, :]
91 | bboxes2 = np.array(bboxes2)
92 |
93 | # Compute overlap for x and y axes separately
94 | overlap_x = np.maximum(0, np.minimum(bboxes1[:, :, 2], bboxes2[:, 2]) - np.maximum(bboxes1[:, :, 0], bboxes2[:, 0]))
95 | overlap_y = np.maximum(0, np.minimum(bboxes1[:, :, 3], bboxes2[:, 3]) - np.maximum(bboxes1[:, :, 1], bboxes2[:, 1]))
96 |
97 | # Compute overlapping areas for each pair
98 | overlap_areas = overlap_x * overlap_y
99 | return overlap_areas
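
# Example (sketch): 2 boxes vs. 3 boxes -> a (2, 3) matrix of intersection
# areas, computed via broadcasting instead of a double loop:
#   pairwise_intersect_area([[0, 0, 10, 10], [5, 5, 15, 15]],
#                           [[0, 0, 5, 5], [8, 8, 12, 12], [20, 20, 30, 30]])
#   -> [[25., 4., 0.], [0., 16., 0.]]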
100 |
101 | def expand_bbox(
102 | bbox: Sequence[Number],
103 | image_width: int,
104 | image_height: int,
105 | expand_ratio: Union[Number, Tuple[Number, Number]] = 1.2,
106 | ) -> list[Number]:
107 | # Extract the coordinates
108 | x1, y1, x2, y2 = bbox
109 |
110 | # Calculate the center
111 | center_x = (x1 + x2) / 2
112 | center_y = (y1 + y2) / 2
113 |
114 | # Calculate new width and height (supports a single ratio or separate x/y ratios, per the type hint)
115 | ratio_x, ratio_y = expand_ratio if isinstance(expand_ratio, tuple) else (expand_ratio, expand_ratio)
116 | new_width = (x2 - x1) * ratio_x
117 | new_height = (y2 - y1) * ratio_y
117 |
118 | # Determine new coordinates
119 | new_x1 = center_x - new_width / 2
120 | new_y1 = center_y - new_height / 2
121 | new_x2 = center_x + new_width / 2
122 | new_y2 = center_y + new_height / 2
123 |
124 | # Ensure coordinates are legitimate
125 | new_x1 = max(0, new_x1)
126 | new_y1 = max(0, new_y1)
127 | new_x2 = min(image_width, new_x2)
128 | new_y2 = min(image_height, new_y2)
129 |
130 | return [new_x1, new_y1, new_x2, new_y2]
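
# Example (sketch): grow a 10x10 box by 20% around its center, clipped to a
# 100x80 image:
#   expand_bbox([10, 10, 20, 20], image_width=100, image_height=80)
#   -> [9.0, 9.0, 21.0, 21.0]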
131 |
132 | class Regexes():
133 | # e-mail
134 | EMAIL = r"(e(\-|_|\s)*)?mail(?!(\-|\_|\s)*(password|passwd|pass word|passcode|passwort))"
135 |
136 | # password
137 | PASSWORD = "password|passwd|pass word|passcode|passwort"
138 | # username
139 | USERNAME = "(u(s(e)?r)?|nick|display|profile)(\-|_|\s)*name"
140 | USERID = "^((u(s(e)?r)?|nick|display|profile|customer)(\-|_|\s)*)?id|identifi(ant|er)?|access(\-|_|\s)*code|account"
141 |
142 | # misc identifiers
143 | FULL_NAME = "full(\-|_|\s)?(name|nm|nom)|(celé jméno)"
144 | FIRST_NAME = "(f(irst|ore)?|m(iddle)?|pre)(\-|_|\s)*(name|nm|nom)"
145 | LAST_NAME = "(l(ast|st)?|s(u)?(r)?)(\-|_|\s)*(name|nm|nom)"
146 | NAME_PREFIX = "prefix"
147 |
148 | # Phones
149 | PHONE_AREA = "phone(\-|_|\s)*area|area(\-|_|\s)*code|phone(\-|_|\s)*(pfx|prefix|prfx)"
150 | PHONE = "mobile|phone|telephone|tel"
151 |
152 | # Dates
153 | MONTH = "month"
154 | DAY = "day"
155 | YEAR = "year"
156 | BIRTHDATE = "date|dob|birthdate|birthday|date(\-|_|\s)*of(\-|_|\s)*birth"
157 |
158 | # gender
159 | AGE = "(\-|_|\s)+age(\-|_|\s)+"
160 | GENDER = "gender|sex"
161 |
162 | # profile pics
163 | # FILE = "photo|picture"
164 | SMS = 'sms'
165 |
166 | # Addresses
167 | ADDRESS = "address"
168 | ZIPCODE = "(post(al)?|zip)(\-|_|\s)*(code|no|num)?"
169 | CITY = "city|town|location"
170 | COUNTRY = "countr"
171 | STATE = "stat|province"
172 | STREET = "street"
173 | BUILDING_NO = "(building|bldng|flat|apartment|apt|home|house)(\-|_|\s)*(num|no)"
174 | # SSN etc.
175 | SSN = "(ssn|vat|social(\-|_|\s)*sec(urity)?(\-|_|\s)*(num|no)?)"
176 |
177 | # Credit cards
178 | CREDIT_CARD = "(xxxx xxxx xxxx xxxx)|(0000 0000 0000 0000)|(Número de tarjeta)|(Číslo karty)|(cc(\-|_|\s)*(no|num))|(card(\-|_|\s)*(no|num))|(credit(\-|_|\s)*(no|num|card))|(card$)"
179 | CREDIT_CARD_EXPIRE = "expire|expiration|expiry|expdate|((cc|card|credit)(\-|_|\s)*date)|^exp$"
180 | CREDIT_CARD_CVV = "(sec(urity)?(\-|_|\s)*)?(cvv|csc|cvn)"
181 | ATMPIN = "atmpin|pin"
182 |
183 | # Company stuff
184 | COMPANY_NAME = "company|organi(z|s)ation|institut(e|ion)"
185 | #### END SPECIFIC REGEXES - START GENERIC ####
186 | # NUMBER_COARSE = "num|code"
187 | USERNAME_COARSE = "us(e)?r|login"
188 |
189 | OTHER_FORM = "link|search"
190 |
191 | SSO_SIGNUP_BUTTONS = "((create|register|make)|(new))\s*(new\s*)?(user|account|profile)"
192 |
193 | VERIFY_ACCOUNT = "((verify|activate)(\syour)?\s(account|e(-|\s)*mail|info))|((verification|activation) (e(-|\s)*mail|message|link|code|number))"
194 | VERIFIED_ACCOUNT = "(user(-|\s))?(account|profile)\s+(was|is|has)?(been)?(verified|activated|attivo)|(verification|activation)\s+(was|is|has)(been)?\s+(completed|done|successful)?"
195 | VERIFY_VERBS = "verify|activate"
196 |
197 | IDENTIFIERS = "%s|%s|%s|%s|%s" % (FULL_NAME, FIRST_NAME, LAST_NAME, USERNAME, EMAIL)
198 | IDENTIFIERS_NO_EMAIL = "%s|%s|%s|%s" % (FULL_NAME, FIRST_NAME, LAST_NAME, USERNAME)
199 |
200 | SUBMIT = "submit"
201 | LOGIN = "(log|sign)([^0-9a-zA-Z]|\s)*(in|on)|authenticat(e|ion)|/(my([^0-9a-zA-Z]|\s)*)?(user|account|profile|dashboard)"
202 | SIGNUP = "sign([^0-9a-zA-Z]|\s)*up|regist(er|ration)?|(create|new)([^0-9a-zA-Z]|\s)*(new([^0-9a-zA-Z]|\s)*)?(acc(ount)?|us(e)?r|prof(ile)?)|(forg(et|ot)|reset)([^0-9a-zA-Z]|\s)*((my|the)([^0-9a-zA-Z]|\s)*)?(acc(ount)?|us(e)?r|prof(ile)?|password)"
203 | SSO = "[^0-9a-zA-Z]+sso[^0-9a-zA-Z]+|oauth|openid"
204 | AUTH = "%s|%s|%s|%s|%s|auth|(new|existing)([^0-9a-zA-Z]|\s)*(us(e)?r|acc(ount)?)|account|connect|profile|dashboard|next" % (LOGIN, SIGNUP, SSO, SUBMIT, VERIFY_VERBS)
205 | LOGOUT = "(log|sign)(-|_|\s)*(out|off)"
206 | BUTTON = "suivant|make([^0-9a-zA-Z]|\s)*payment|^OK$|go([^0-9a-zA-Z]|\s)*(in)?to|sign([^0-9a-zA-Z]|\s)*in(?! with| via| using)|log([^0-9a-zA-Z]|\s)*in(?! with| via| using)|log([^0-9a-zA-Z]|\s)*on(?! with| via| using)|verify(?! with| via| using)|verification|submit(?! with| via| using)|ent(er|rar|rer|rance|ra)(?! with| via| using)|acces(o|sar|s)(?! with| via| using)|continu(er|ar)?(?! with| via| using)|connect(er)?(?! with| via| using)|next|confirm|sign([^0-9a-zA-Z]|\s)*on(?! with| via| using)|complete|valid(er|ate)(?! with| via| using)|securipass|登入|登录|登錄|登録|签到|iniciar([^0-9a-zA-Z]|\s)*sesión|identifier|ログインする|サインアップ|ログイン|로그인|시작하기|войти|вход|accedered|gabung|masuk|girişi|Giriş|เข้าสู่ระบบ|Přihlásit|mein([^0-9a-zA-Z]|\s)*konto|anmelden|ingresa|accedi|мой([^0-9a-zA-Z]|\s)*профиль|حسابي|administrer|cadastre-se|είσοδος|accessibilité|accéder|zaloguj|đăng([^0-9a-zA-Z]|\s)*nhập|weitermachen|bestätigen|zověřit|ověřit|weiter"
207 | BUTTON_FORBIDDEN = "single sign-on|guest|here we go|seek|looking for|explore|save|clear|wipe off|(^[0-9]+$)|(^x$)|close|search|(sign|log|verify|submit|ent(er|rar|rer|rance|ra)|acces(o|sar|s)|continu(er|ar)?)?.*(github|microsoft|facebook|google|twitter|linkedin|instagram|line)|keep([^0-9a-zA-Z]|\s)*me([^0-9a-zA-Z]|\s)*(signed|logged)([^0-9a-zA-Z]|\s)*(in|on)|having([^0-9a-zA-Z]|\s)*trouble|remember|subscribe|send([^0-9a-zA-Z]|\s)*me([^0-9a-zA-Z]|\s)*(message|(e)?mail|newsletter|update)|follow([^0-9a-zA-Z]|\s)*us|新規会員|%s" % SIGNUP
208 | # CREDENTIAL_TAKING_KEYWORDS = "log(g)?([^0-9a-zA-Z]|\s)*in(n)?|log([^0-9a-zA-Z]|\s)*on|sign([^0-9a-zA-Z]|\s)*in|sign([^0-9a-zA-Z]|\s)*on|submit|(my|personal)([^0-9a-zA-Z]|\s)*(account|area)|come([^0-9a-zA-Z]|\s)*in|check([^0-9a-zA-Z]|\s)*in|customer([^0-9a-zA-Z]|\s)*centre|登入|登录|登錄|登録|iniciar([^0-9a-zA-Z]|\s)*sesión|identifier|(ログインする)|(サインアップ)|(ログイン)|(로그인)|(시작하기)|(войти)|(вход)|(accedered)|(gabung)|(masuk)|(girişi)|(Giriş)|(وارد)|(عضویت)|(acceso)|(acessar)|(entrar )|(เข้าสู่ระบบ)|(Přihlásit)|(mein konto)|(anmelden)|(me connecter)|(ingresa)|(accedi)|(мой профиль)|(حسابي)|(administrer)|(next)|(entre )|(cadastre-se)|(είσοδος)|(entrance)|(start now)|(accessibilité)|(accéder)|(zaloguj)|(đăng nhập)|weitermachen|bestätigen|zověřit|ověřit"
209 | CREDENTIAL_TAKING_KEYWORDS = r"""
210 | (?:
211 | log(?:g)?in| # Matches 'login', 'loggin'
212 | log(?:g)?on| # Matches 'logon', 'loggon'
213 | sign(?:-|\s)?(?:in|on)| # Matches 'sign in', 'sign on', 'signin', 'signon', 'sign-in', 'sign-on'
214 | submit|apply|continue|update|
215 | (?:my|personal)(?:\W+)(?:account|area)|
216 | come(?:\W+)in| # Matches 'come in' with any non-word delimiters
217 | customer(?:\W+)centre| # Matches 'customer centre' with any non-word delimiters
218 | identifier|
219 | (?:get(?:\W+)started) # Matches 'get started' with any non-word delimiters
220 | )
221 | | # Alternatives in different languages
222 | 登入|登录|登錄|登録|
223 | iniciar(?:\W+)sesión|
224 | (?:ログインする)|(?:サインアップ)|(?:ログイン)|
225 | (?:로그인)|(?:시작하기)|
226 | (?:войти)|(?:вход)|
227 | (?:acceder(?:\W+)ed)|(?:gabung)|(?:masuk)|
228 | (?:giriş(?:i)?)|(?:وارد)|(?:عضویت)|
229 | (?:acceso)|(?:acessar)|(?:entrar)|
230 | (?:เข้าสู่ระบบ)|(?:Přihlásit)|
231 | (?:mein konto)|(?:anmelden)|(?:me connecter)|
232 | (?:ingresa)|(?:accedi)|(?:мой профиль)|
233 | (?:حسابي)|(?:administrer)|(?:next)|
234 | (?:entre)|(?:cadastre-se)|(?:είσοδος)|
235 | (?:entrance)|(?:start now)|(?:accessibilité)|
236 | (?:accéder)|(?:zaloguj)|(?:đăng nhập)|
237 | weitermachen|bestätigen|zověřit|ověřit
238 | """.strip()
239 | PROFILE = "account|profile|dashboard|settings"
240 |
241 | CAPTCHA = "(re)?captcha"
242 | CONSENT = "consent|gdp"
243 | COOKIES_CONSENT = "agree|accept"
244 |
245 | URL = "(?:(?:https?|ftp)://)(?:\S+(?::\S*)?@)?(?:(?!10(?:\.\d{1,3}){3})(?!127(?:\.\d{1,3}){3})(?!169\.254(?:\.\d{1,3}){2})(?!192\.168(?:\.\d{1,3}){2})(?!172\.(?:1[6-9]|2\d|3[0-1])(?:\.\d{1,3}){2})(?:[1-9]\d?|1\d\d|2[01]\d|22[0-3])(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))|(?:(?:[a-z\\x{00a1}\-\\x{ffff}0-9]+-?)*[a-z\\x{00a1}\-\\x{ffff}0-9]+)(?:\.(?:[a-z\\x{00a1}\-\\x{ffff}0-9]+-?)*[a-z\\x{00a1}\-\\x{ffff}0-9]+)*(?:\.(?:[a-z\\x{00a1}\-\\x{ffff}]{2,})))(?::\d{2,5})?(?:/[^\s]*)?"
246 |
247 | TIME = "([0-9]:){1,2}[0-9]"
248 | TIME_SCRIPT = "setHours|setMinutes|setSeconds"
249 |
250 | # try again error
251 | ERROR_TRY_AGAIN = ["try again|login failed|error logging in|login error|retry"]
252 |
253 | # username incorrect/not exist error
254 | ERROR_INCORRECT = ["(invalid|wrong|incorrect|unknown|no).*(id|credential|login|input|password|account|user(name)?|e(\-|_|\s)?mail|information|(pass)?code|(user([^0-9a-zA-Z]|\s)*)?id)(s)?",
255 | "(do(es)?|did)([^0-9a-zA-Z]|\s)*not match(([^0-9a-zA-Z]|\s)*our records)?",
256 | "limited access|verification failed|not registered|does not exist|access denied|coundn't find|you entered([^0-9a-zA-Z]|\s)*(isn't|doesn't)|(please)?([^0-9a-zA-Z]|\s)*enter a valid",
257 | "(account|password|user(name)?|e(\-|_|\s)?mail|credentials|sms|code)([^0-9a-zA-Z]|\s)*(provided|given|input([^0-9a-zA-Z]|\s)*)?((is incorrect)|(are incorrect)|(isn't right)|(isn't correct)|(doesn't exist)|(does not exist)|(not valid)|(is invalid)|(not recognized)|(were not found))",
258 | "(SMS-Code Fehler)|(SMS kód je neplatný)",
259 | "code incorrectly|no user found|username already taken",
260 | "(cannot|can't) be used|not allowed|must (contain|follow|specify)"
261 | "captcha was not answered correctly"
262 | ]
263 |
264 | # connection error
265 | ERROR_CONNECTION = ["connecting([^0-9a-zA-Z]|\s)*(to)?([^0-9a-zA-Z]|\s)*(mail)?([^0-9a-zA-Z]|\s)*server|connection is lost",
266 | # "(operation|page)([^0-9a-zA-Z]|\s)*((counldn't)|(could not)|cannot|(can not))([^0-9a-zA-Z]|\s)*be([^0-9a-zA-Z]|\s)*(completed|found)",
267 | # 'not found|forbidden|403|404|500|no permission|don\'t have permission'
268 | ]
269 |
270 | # File related
271 | ERROR_FILE = ["processing([^0-9a-zA-Z]|\s)*(your)?([^0-9a-zA-Z]|\s)*download",
272 | "file not found"]
273 |
274 | # anti-bot
275 | ERROR_BOT = ["(not a human)|captcha|(verify you are a human)|(press & hold)"]
276 |
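# Example (sketch): these patterns are typically matched case-insensitively
# against UI text, e.g. to decide whether a button looks like a login control:
#   import re
#   bool(re.search(Regexes.LOGIN, 'Sign in', re.IGNORECASE))   # True
#   bool(re.search(Regexes.LOGOUT, 'Log out', re.IGNORECASE))  # True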
277 |
278 |
279 |
280 |
--------------------------------------------------------------------------------
/scripts/phishintention/modules/logo_matching.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageOps
2 | from torchvision import transforms
3 | import torch
4 | from torch.backends import cudnn
5 | import os
6 | import numpy as np
7 | from collections import OrderedDict
8 | from tqdm import tqdm
9 | from tldextract import tldextract
10 | import pickle
11 |
12 | from ..utils.utils import brand_converter, resolution_alignment, l2_norm
13 | from .models2 import KNOWN_MODELS
14 | from ..ocr_lib.models.model_builder import ModelBuilder
15 | from ..ocr_lib.utils.labelmaps import get_vocabulary
16 |
17 | COUNTRY_TLDs = [
18 | ".af",
19 | ".ax",
20 | ".al",
21 | ".dz",
22 | ".as",
23 | ".ad",
24 | ".ao",
25 | ".ai",
26 | ".aq",
27 | ".ag",
28 | ".ar",
29 | ".am",
30 | ".aw",
31 | ".ac",
32 | ".au",
33 | ".at",
34 | ".az",
35 | ".bs",
36 | ".bh",
37 | ".bd",
38 | ".bb",
39 | ".eus",
40 | ".by",
41 | ".be",
42 | ".bz",
43 | ".bj",
44 | ".bm",
45 | ".bt",
46 | ".bo",
47 | ".bq",".an",".nl",
48 | ".ba",
49 | ".bw",
50 | ".bv",
51 | ".br",
52 | ".io",
53 | ".vg",
54 | ".bn",
55 | ".bg",
56 | ".bf",
57 | ".mm",
58 | ".bi",
59 | ".kh",
60 | ".cm",
61 | ".ca",
62 | ".cv",
63 | ".cat",
64 | ".ky",
65 | ".cf",
66 | ".td",
67 | ".cl",
68 | ".cn",
69 | ".cx",
70 | ".cc",
71 | ".co",
72 | ".km",
73 | ".cd",
74 | ".cg",
75 | ".ck",
76 | ".cr",
77 | ".ci",
78 | ".hr",
79 | ".cu",
80 | ".cw",
81 | ".cy",
82 | ".cz",
83 | ".dk",
84 | ".dj",
85 | ".dm",
86 | ".do",
87 | ".tl",".tp",
88 | ".ec",
89 | ".eg",
90 | ".sv",
91 | ".gq",
92 | ".er",
93 | ".ee",
94 | ".et",
95 | ".eu",
96 | ".fk",
97 | ".fo",
98 | ".fm",
99 | ".fj",
100 | ".fi",
101 | ".fr",
102 | ".gf",
103 | ".pf",
104 | ".tf",
105 | ".ga",
106 | ".gal",
107 | ".gm",
108 | ".ps",
109 | ".ge",
110 | ".de",
111 | ".gh",
112 | ".gi",
113 | ".gr",
114 | ".gl",
115 | ".gd",
116 | ".gp",
117 | ".gu",
118 | ".gt",
119 | ".gg",
120 | ".gn",
121 | ".gw",
122 | ".gy",
123 | ".ht",
124 | ".hm",
125 | ".hn",
126 | ".hk",
127 | ".hu",
128 | ".is",
129 | ".in",
130 | ".id",
131 | ".ir",
132 | ".iq",
133 | ".ie",
134 | ".im",
135 | ".il",
136 | ".it",
137 | ".jm",
138 | ".jp",
139 | ".je",
140 | ".jo",
141 | ".kz",
142 | ".ke",
143 | ".ki",
144 | ".kw",
145 | ".kg",
146 | ".la",
147 | ".lv",
148 | ".lb",
149 | ".ls",
150 | ".lr",
151 | ".ly",
152 | ".li",
153 | ".lt",
154 | ".lu",
155 | ".mo",
156 | ".mk",
157 | ".mg",
158 | ".mw",
159 | ".my",
160 | ".mv",
161 | ".ml",
162 | ".mt",
163 | ".mh",
164 | ".mq",
165 | ".mr",
166 | ".mu",
167 | ".yt",
168 | ".mx",
169 | ".md",
170 | ".mc",
171 | ".mn",
172 | ".me",
173 | ".ms",
174 | ".ma",
175 | ".mz",
176 | ".mm",
177 | ".na",
178 | ".nr",
179 | ".np",
180 | ".nl",
181 | ".nc",
182 | ".nz",
183 | ".ni",
184 | ".ne",
185 | ".ng",
186 | ".nu",
187 | ".nf",
188 | ".nc",".tr",
189 | ".kp",
190 | ".mp",
191 | ".no",
192 | ".om",
193 | ".pk",
194 | ".pw",
195 | ".ps",
196 | ".pa",
197 | ".pg",
198 | ".py",
199 | ".pe",
200 | ".ph",
201 | ".pn",
202 | ".pl",
203 | ".pt",
204 | ".pr",
205 | ".qa",
206 | ".ro",
207 | ".ru",
208 | ".rw",
209 | ".re",
210 | ".bq",".an",
211 | ".bl",".gp",".fr",
212 | ".sh",
213 | ".kn",
214 | ".lc",
215 | ".mf",".gp",".fr",
216 | ".pm",
217 | ".vc",
218 | ".ws",
219 | ".sm",
220 | ".st",
221 | ".sa",
222 | ".sn",
223 | ".rs",
224 | ".sc",
225 | ".sl",
226 | ".sg",
227 | ".bq",".an",".nl",
228 | ".sx",".an",
229 | ".sk",
230 | ".si",
231 | ".sb",
232 | ".so",
233 | ".so",
234 | ".za",
235 | ".gs",
236 | ".kr",
237 | ".ss",
238 | ".es",
239 | ".lk",
240 | ".sd",
241 | ".sr",
242 | ".sj",
243 | ".sz",
244 | ".se",
245 | ".ch",
246 | ".sy",
247 | ".tw",
248 | ".tj",
249 | ".tz",
250 | ".th",
251 | ".tg",
252 | ".tk",
253 | ".to",
254 | ".tt",
255 | ".tn",
256 | ".tr",
257 | ".tm",
258 | ".tc",
259 | ".tv",
260 | ".ug",
261 | ".ua",
262 | ".ae",
263 | ".uk",
264 | ".us",
265 | ".vi",
266 | ".uy",
267 | ".uz",
268 | ".vu",
269 | ".va",
270 | ".ve",
271 | ".vn",
272 | ".wf",
273 | ".eh",
274 | ".ma",
275 | ".ye",
276 | ".zm",
277 | ".zw"
278 | ]
279 |
280 | class DataInfo(object):
281 | """
282 | Save the info about the dataset.
283 | This is a code snippet from dataset.py
284 | """
285 | def __init__(self, voc_type):
286 | super(DataInfo, self).__init__()
287 | self.voc_type = voc_type
288 |
289 | assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS']
290 | self.EOS = 'EOS'
291 | self.PADDING = 'PADDING'
292 | self.UNKNOWN = 'UNKNOWN'
293 | self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
294 | self.char2id = dict(zip(self.voc, range(len(self.voc))))
295 | self.id2char = dict(zip(range(len(self.voc)), self.voc))
296 |
297 | self.rec_num_classes = len(self.voc)
298 |
299 | def ocr_model_config(weights_path, height=None, width=None):
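    |     """Build the ASTER (ResNet_ASTER) OCR model and load weights from a checkpoint that stores a 'state_dict' entry."""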
300 | np.random.seed(1234)
301 | torch.manual_seed(1234)
302 | torch.cuda.manual_seed(1234)
303 | torch.cuda.manual_seed_all(1234)
304 | cudnn.benchmark = True
305 | torch.backends.cudnn.deterministic = True
306 |
307 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
308 | if device == 'cuda':
309 | print('using cuda.')
310 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
311 | else:
312 | torch.set_default_tensor_type('torch.FloatTensor')
313 |
314 |     # Default input size for the OCR model
315 | if height is None or width is None:
316 | height, width = (32, 100)
317 |
318 | dataset_info = DataInfo('ALLCASES_SYMBOLS')
319 |
320 | # Create model
321 | model = ModelBuilder(arch='ResNet_ASTER', rec_num_classes=dataset_info.rec_num_classes,
322 | sDim=512, attDim=512, max_len_labels=100,
323 | eos=dataset_info.char2id[dataset_info.EOS], STN_ON=True)
324 |
325 | # Load from checkpoint
326 |     checkpoint = torch.load(weights_path, map_location='cpu')
327 |     model.load_state_dict(checkpoint['state_dict'])
328 |
329 | if device == 'cuda':
330 | model = model.to(device)
331 |
332 | return model
333 |
334 | def siamese_model_config(num_classes: int, weights_path: str):
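    |     """Build the BiT-M-R50x1 Siamese backbone, load weights (stripping any 'module.' DataParallel prefix), and switch to eval mode."""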
335 | # Initialize model
336 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
337 | model = KNOWN_MODELS["BiT-M-R50x1"](head_size=num_classes, zero_head=True)
338 |
339 | # Load weights
340 | weights = torch.load(weights_path, map_location='cpu')
341 |     weights = weights['model'] if 'model' in weights else weights
342 | new_state_dict = OrderedDict()
343 | for k, v in weights.items():
344 | if k.startswith('module'):
345 | name = k.split('module.')[1]
346 | else:
347 | name = k
348 | new_state_dict[name] = v
349 |
350 | model.load_state_dict(new_state_dict)
351 | model.to(device)
352 | model.eval()
353 |
354 | return model
355 |
356 |
357 | def image_process(image_path, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
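    |     """Resize an image to (imgW, imgH), optionally preserving aspect ratio, and normalize pixel values to [-1, 1] for the OCR model."""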
358 | img = Image.open(image_path).convert('RGB') if isinstance(image_path, str) else image_path.convert('RGB')
359 |
360 | if keep_ratio:
361 | w, h = img.size
362 | ratio = w / float(h)
363 | imgW = int(np.floor(ratio * imgH))
364 | imgW = max(imgH * min_ratio, imgW)
365 |
366 | img = img.resize((imgW, imgH), Image.BILINEAR)
367 | img = transforms.ToTensor()(img)
368 | img.sub_(0.5).div_(0.5)
369 |
370 | return img
371 |
372 |
373 | def ocr_main(image_path, model, height=None, width=None):
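    |     """Embed a single image with the OCR model: run the encoder with dummy recognition targets and return the sequence-averaged feature vector."""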
374 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
375 | # Evaluation
376 | model.eval()
377 |
378 | img = image_process(image_path)
379 | with torch.no_grad():
380 | img = img.to(device)
381 | input_dict = {}
382 | input_dict['images'] = img.unsqueeze(0)
383 |
384 | dataset_info = DataInfo('ALLCASES_SYMBOLS')
385 | rec_targets = torch.IntTensor(1, 100).fill_(1)
386 | rec_targets[:, 100 - 1] = dataset_info.char2id[dataset_info.EOS]
387 | input_dict['rec_targets'] = rec_targets.to(device)
388 | input_dict['rec_lengths'] = [100]
389 |
390 | with torch.no_grad():
391 | features, decoder_feat = model.features(input_dict)
392 | features = features.detach().cpu()
393 | decoder_feat = decoder_feat.detach().cpu()
394 | features = torch.mean(features, dim=1)
395 |
396 | return features
397 |
398 | @torch.no_grad()
399 | def get_ocr_aided_siamese_embedding(img, model, ocr_model, grayscale=False):
400 | '''
401 | Inference for a single image
402 | :param img: image path in str or image in PIL.Image
403 | :param model: Siamese model to make inference
404 | :param ocr_model: OCR model
405 |     :param grayscale: convert image to grayscale or not
406 |     :return: L2-normalized feature embedding of shape (2560,)
409 | '''
410 | img_size = 224
411 | mean = [0.5, 0.5, 0.5]
412 | std = [0.5, 0.5, 0.5]
413 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
414 |
415 | img_transforms = transforms.Compose(
416 | [transforms.ToTensor(),
417 | transforms.Normalize(mean=mean, std=std),
418 | ])
419 |
420 | img = Image.open(img) if isinstance(img, str) else img
421 | img = img.convert("RGBA").convert("L").convert("RGB") if grayscale else img.convert("RGBA").convert("RGB")
422 |
423 |     ## Pad to a square (preserving the original aspect ratio) before resizing
424 | pad_color = 255 if grayscale else (255, 255, 255)
425 | img = ImageOps.expand(img, (
426 | (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2,
427 | (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2), fill=pad_color)
428 |
429 | img = img.resize((img_size, img_size))
430 |
431 | # Predict the embedding
432 |     # get the OCR embedding from the pretrained OCR encoder
433 | with torch.no_grad():
434 | ocr_emb = ocr_main(image_path=img, model=ocr_model, height=None, width=None)
435 | ocr_emb = ocr_emb[0]
436 |         ocr_emb = ocr_emb[None, ...].to(device)  # add the batch dimension back and move to device
437 |
438 | # Predict the embedding
439 | with torch.no_grad():
440 | img = img_transforms(img)
441 | img = img[None, ...].to(device)
442 | logo_feat = model.features(img, ocr_emb)
443 | logo_feat = l2_norm(logo_feat).squeeze(0).cpu().numpy() # L2-normalization final shape is (2560,)
444 |
445 | return logo_feat
446 |
447 | def chunked_dot(logo_feat_list, img_feat, chunk_size=128):
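    |     '''
    |     Dot products between the cached reference-logo embeddings and one query embedding,
    |     computed in chunks to bound peak memory. The embeddings are L2-normalized, so each
    |     dot product is a cosine similarity.
    |     :param logo_feat_list: (N, D) array of reference embeddings
    |     :param img_feat: (D,) query embedding
    |     :return: list of N similarity scores
    |     '''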
448 | sim_list = []
449 |
450 | for start in range(0, logo_feat_list.shape[0], chunk_size):
451 | end = start + chunk_size
452 | chunk = logo_feat_list[start:end]
453 |         sim_chunk = np.dot(chunk, img_feat.T)  # shape: (chunk_size,)
454 | sim_list.extend(sim_chunk)
455 |
456 | return sim_list
457 |
458 | def pred_brand(model, ocr_model, domain_map, logo_feat_list, file_name_list, shot_path: str, pred_bbox, t_s, grayscale=False):
459 | '''
460 | Return predicted brand for one cropped image
461 | :param model: model to use
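    |     :param ocr_model: OCR model used for the text-aware part of the logo embedding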
462 | :param domain_map: brand-domain dictionary
463 | :param logo_feat_list: reference logo feature embeddings
464 | :param file_name_list: reference logo paths
465 | :param shot_path: path to the screenshot
466 | :param pred_bbox: 1x4 np.ndarray/list/tensor bounding box coords
467 | :param t_s: similarity threshold for siamese
468 | :param grayscale: convert image(cropped) to grayscale or not
469 |     :return: (predicted brand, predicted brand's domain list, similarity score)
470 | '''
471 |
472 | try:
473 | img = Image.open(shot_path)
474 | except OSError: # if the image cannot be identified, return nothing
475 |         print('Screenshot cannot be opened')
476 | return None, None, None
477 |
478 | ## get predicted box --> crop from screenshot
479 | cropped = img.crop((pred_bbox[0], pred_bbox[1], pred_bbox[2], pred_bbox[3]))
480 | img_feat = get_ocr_aided_siamese_embedding(cropped, model, ocr_model, grayscale=grayscale)
481 |
482 | ## get cosine similarity with every protected logo
483 | sim_list = chunked_dot(logo_feat_list, img_feat) # take dot product for every pair of embeddings (Cosine Similarity)
484 | pred_brand_list = file_name_list
485 |
486 | assert len(sim_list) == len(pred_brand_list)
487 |
488 | ## get top 3 brands
489 | idx = np.argsort(sim_list)[::-1][:3]
490 | pred_brand_list = np.array(pred_brand_list)[idx]
491 | sim_list = np.array(sim_list)[idx]
492 |
493 | # top1,2,3 candidate logos
494 | top3_logolist = [Image.open(x) for x in pred_brand_list]
495 | top3_brandlist = [brand_converter(os.path.basename(os.path.dirname(x))) for x in pred_brand_list]
496 | top3_domainlist = [domain_map[x] for x in top3_brandlist]
497 | top3_simlist = sim_list
498 |
499 | for j in range(3):
500 | predicted_brand, predicted_domain = None, None
501 |
502 |         ## Lower-ranked candidates must agree with the top-1 brand; a mismatch is likely a false positive
503 | if top3_brandlist[j] != top3_brandlist[0]:
504 | continue
505 |
506 | ## If the largest similarity exceeds threshold
507 | if top3_simlist[j] >= t_s:
508 | predicted_brand = top3_brandlist[j]
509 | predicted_domain = top3_domainlist[j]
510 | final_sim = top3_simlist[j]
511 |
512 |         ## Otherwise, try resolution alignment and re-check whether the similarity improves
513 | else:
514 | cropped, candidate_logo = resolution_alignment(cropped, top3_logolist[j])
515 | img_feat = get_ocr_aided_siamese_embedding(cropped, model, ocr_model, grayscale=grayscale)
516 | logo_feat = get_ocr_aided_siamese_embedding(candidate_logo, model, ocr_model, grayscale=grayscale)
517 | final_sim = logo_feat.dot(img_feat)
518 | if final_sim >= t_s:
519 | predicted_brand = top3_brandlist[j]
520 | predicted_domain = top3_domainlist[j]
521 | else:
522 | break # no hope, do not try other lower rank logos
523 |
524 | ## If there is a prediction, do aspect ratio check
525 | if predicted_brand is not None:
526 | ratio_crop = cropped.size[0] / cropped.size[1]
527 | ratio_logo = top3_logolist[j].size[0] / top3_logolist[j].size[1]
528 | # aspect ratios of matched pair must not deviate by more than factor of 2.5
529 | if max(ratio_crop, ratio_logo) / min(ratio_crop, ratio_logo) > 2.5:
530 | continue # did not pass aspect ratio check, try other
531 | # If pass aspect ratio check, report a match
532 | else:
533 | return predicted_brand, predicted_domain, final_sim
534 |
535 | return None, None, top3_simlist[0]
536 |
537 | def cache_reference_list(model, ocr_model, targetlist_path: str, grayscale=False):
538 | '''
539 | cache the embeddings of the reference list
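    |     :return: (N, D) array of reference-logo embeddings and the matching array of logo file paths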
540 | '''
541 |
542 | # Prediction for targetlists
543 | logo_feat_list = []
544 | file_name_list = []
545 |
546 | for target in tqdm(os.listdir(targetlist_path)):
547 | if target.startswith('.'): # skip hidden files
548 | continue
549 | for logo_path in os.listdir(os.path.join(targetlist_path, target)):
550 |             if logo_path.endswith(('.png', '.jpeg', '.jpg', '.PNG', '.JPG', '.JPEG')):
552 | if logo_path.startswith('loginpage') or logo_path.startswith('homepage'): # skip homepage/loginpage
553 | continue
554 | logo_feat_list.append(get_ocr_aided_siamese_embedding(img=os.path.join(targetlist_path, target, logo_path),
555 | model=model,
556 | ocr_model=ocr_model,
557 | grayscale=grayscale))
558 | file_name_list.append(str(os.path.join(targetlist_path, target, logo_path)))
559 |
560 | return np.asarray(logo_feat_list), np.asarray(file_name_list)
561 |
562 | def check_domain_brand_inconsistency(logo_boxes,
563 | domain_map_path: str,
564 | model,
565 | ocr_model,
566 | logo_feat_list,
567 | file_name_list,
568 | shot_path: str,
569 | url: str,
570 | ts: float):
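    |     '''
    |     Run the logo matcher over all detected logo boxes and report a brand whose known
    |     domains are inconsistent with the URL's domain; regional ccTLD variants of the
    |     same second-level domain are treated as consistent.
    |     :return: (matched brand, matched domain list, matched logo coordinates, similarity score)
    |     '''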
571 |
572 | # targetlist domain list
573 | with open(domain_map_path, 'rb') as handle:
574 | domain_map = pickle.load(handle)
575 |
576 | # look at boxes for logo class only
577 | print('number of logo boxes:', len(logo_boxes))
578 | suffix_part = '.'+ tldextract.extract(url).suffix
579 | domain_part = tldextract.extract(url).domain
580 | extracted_domain = domain_part + suffix_part
581 |
582 | matched_target, matched_domain, matched_coord, this_conf = None, None, None, None
583 |
584 |
585 | # run logo matcher
586 | if len(logo_boxes) > 0:
587 | # siamese prediction for logo box
588 | for i, coord in enumerate(logo_boxes):
589 | min_x, min_y, max_x, max_y = coord
590 | bbox = [float(min_x), float(min_y), float(max_x), float(max_y)]
591 | matched_target, matched_domain, this_conf = pred_brand(model, ocr_model, domain_map,
592 | logo_feat_list, file_name_list,
593 | shot_path, bbox, t_s=ts, grayscale=False)
594 |
609 | if (matched_target is not None) and (matched_domain is not None):
610 | matched_coord = coord
611 | matched_domain_parts = [tldextract.extract(x).domain for x in matched_domain]
612 | matched_suffix_parts = [tldextract.extract(x).suffix for x in matched_domain]
613 |
614 | # If the webpage domain exactly aligns with the target website's domain => Benign
615 | if extracted_domain in matched_domain:
616 | matched_target, matched_domain = None, None # Clear if domains are consistent
617 |             elif domain_part in matched_domain_parts:  # if only the second-level domains align and the TLD is regional => Benign
618 | if "." + suffix_part.split('.')[-1] in COUNTRY_TLDs:
619 | matched_target, matched_domain = None, None
620 | else:
621 | break # Inconsistent domain found, break the loop
622 | else:
623 | break # Inconsistent domain found, break the loop
624 | break # only look at 1st logo
625 |
626 | return brand_converter(matched_target), matched_domain, matched_coord, this_conf
627 |
628 |
629 |
630 |
--------------------------------------------------------------------------------
/scripts/utils/web_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Optional, Dict, Set, Sequence, Tuple
2 | from numpy.typing import ArrayLike, NDArray
4 | from PIL import Image
5 | from selenium import webdriver
6 | from selenium.webdriver.chrome.service import Service
7 | from webdriver_manager.chrome import ChromeDriverManager
8 | from selenium.webdriver.chrome.options import Options
9 | import re
10 | import requests
11 | from concurrent.futures import ThreadPoolExecutor
12 | from selenium.common.exceptions import (
13 | NoSuchElementException,
14 | TimeoutException,
15 | StaleElementReferenceException,
16 | WebDriverException,
17 | JavascriptException
18 | )
19 | import os
20 | import io
21 | import time
23 | from selenium.webdriver.remote.webdriver import WebDriver
24 | from selenium.webdriver.remote.webelement import WebElement
25 | from selenium.webdriver.common.by import By
26 | from selenium.webdriver.common.action_chains import ActionChains
27 | from selenium.webdriver.support.ui import WebDriverWait
28 | from selenium.webdriver.support import expected_conditions as EC
29 | import numpy as np
30 | from .logger_utils import PhishLLMLogger
31 | import torch.nn as nn
32 | from functools import partial
33 | import logging
34 | from logging.handlers import RotatingFileHandler
35 | from tldextract import tldextract
36 |
37 | '''webdriver utils'''
38 | def _enable_python_logging(log_path: str = "selenium-debug.log") -> None:
39 | # Root logger (console + rotating file)
40 | root = logging.getLogger()
41 | if not root.handlers:
42 | logging.basicConfig(
43 | level=logging.DEBUG,
44 | format="%(asctime)s %(levelname)s:%(name)s:%(message)s"
45 | )
46 | fh = RotatingFileHandler(log_path, maxBytes=5_000_000, backupCount=3)
47 | fh.setLevel(logging.DEBUG)
48 | fh.setFormatter(logging.Formatter("%(asctime)s %(levelname)s:%(name)s:%(message)s"))
49 | root.addHandler(fh)
50 | logging.getLogger("selenium").setLevel(logging.DEBUG)
51 | logging.getLogger("urllib3").setLevel(logging.DEBUG)
52 |
53 | def boot_driver(
54 | python_log_file: Optional[str] = "selenium-debug.log",
55 | ) -> WebDriver:
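    |     """Boot a headless Chrome WebDriver (1920x1080) with flags suited to containerized Linux, optionally enabling verbose Selenium/urllib3 logging."""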
56 | if python_log_file:
57 | _enable_python_logging(python_log_file)
58 | options = Options()
59 | options.add_argument("--headless")
60 | options.add_argument("--window-size=1920,1080") # set resolution
61 | options.add_argument("--no-sandbox") # (Linux) avoids sandbox issues
62 | options.add_argument("--disable-dev-shm-usage") # Fixes shared memory errors
63 | options.add_argument("--disable-gpu") # (Windows) GPU acceleration off in headless
64 | options.add_argument("--no-proxy-server")
65 | service = Service(
66 | ChromeDriverManager().install(),
67 | )
68 | driver = webdriver.Chrome(service=service, options=options)
69 | return driver
70 |
71 |
72 | def restart_driver(driver: WebDriver) -> WebDriver:
73 | driver.quit()
74 | time.sleep(2)
75 | return boot_driver()
76 |
77 | def is_valid_domain(domain: Union[str, None]) -> bool:
78 | # Regular expression to check if the string is a valid domain without spaces
79 | if domain is None:
80 | return False
81 | pattern = re.compile(
82 | r'^(?!-)' # Cannot start with a hyphen
83 | r'(?!.*--)' # Cannot have two consecutive hyphens
84 | r'(?!.*\.\.)' # Cannot have two consecutive periods
85 | r'(?!.*\s)' # Cannot contain any spaces
86 | r'[a-zA-Z0-9-]{1,63}' # Valid characters are alphanumeric and hyphen
87 | r'(?:\.[a-zA-Z]{2,})+$' # Ends with a valid top-level domain
88 | )
89 | it_is_a_domain = bool(pattern.fullmatch(domain))
90 | return it_is_a_domain
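    | # e.g. is_valid_domain("paypal.com") -> True; is_valid_domain("not a domain") -> False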
91 |
92 |
93 | # -- Robust domain extraction from free-form answers --
94 | def normalize_domain(text: str) -> Optional[str]:
95 | """
96 | Extract and normalize a domain from model output.
97 | Accepts bare domains possibly wrapped with punctuation or code fences.
98 | Returns eTLD+1 style if valid, else None.
99 | """
100 | if not text:
101 | return None
102 |
103 | # Common cleanup: strip code fences/quotes and trailing punctuation
104 | s = text.strip().strip("`'\" \t\r\n;,:.()[]{}")
105 | s = s.replace("http://", "").replace("https://", "").replace("www.", "")
106 |     s = s.split()[0] if s.split() else ""  # take the first token, if any (guards against an empty string)
107 |
108 | # Prefer explicit domain-like substrings anywhere in the string
109 | candidates = re.findall(r'\b(?:[A-Za-z0-9-]+\.)+[A-Za-z]{2,}\b', text)
110 | if s not in candidates:
111 | candidates = [s] + candidates
112 |
113 | for cand in candidates:
114 | cand = cand.strip().lower().strip("`'\" \t\r\n;,:.()[]{}")
115 | # Validate via tldextract + your is_valid_domain helper
116 | try:
117 | ext = tldextract.extract(cand)
118 | dom = '.'.join(p for p in (ext.domain, ext.suffix) if p)
119 | except Exception:
120 | continue
121 | if dom and is_valid_domain(dom):
122 | return dom
123 | return None
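    | # Illustrative (hypothetical model outputs):
    | #   normalize_domain("The brand's domain is `https://www.paypal.com`.")  -> "paypal.com"
    | #   normalize_domain("I cannot tell")  -> None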
124 |
125 | def url2logo(
126 | driver: WebDriver,
127 | url: str,
128 | logo_extractor: nn.Module
129 | ) -> Optional[Image.Image]:
130 |
131 |     reference_logo = None
132 |     screenshot_path = "tmp.png"
133 |     try:
134 |         driver.get(url)  # Visit the webpage
135 |         time.sleep(2)
136 |         driver.get_screenshot_as_file(screenshot_path)
137 |         logo_boxes = logo_extractor(screenshot_path)
138 |         if len(logo_boxes):
139 |             logo_coord = logo_boxes[0]
140 |             screenshot_img = Image.open(screenshot_path).convert("RGB")
141 |             reference_logo = screenshot_img.crop((int(logo_coord[0]), int(logo_coord[1]),
142 |                                                   int(logo_coord[2]), int(logo_coord[3])))
143 |     except WebDriverException as e:
144 |         print(f"Error accessing the webpage: {e}")
145 |     except Exception as e:
146 |         print(f"Failed to take screenshot: {e}")
147 |     finally:
148 |         if os.path.exists(screenshot_path): os.remove(screenshot_path)  # clean up the temp screenshot even on failure
149 |         driver = restart_driver(driver)
150 |     return reference_logo
151 |
152 |
153 | def query2url(
154 | query: str,
155 | SEARCH_ENGINE_API: str,
156 | SEARCH_ENGINE_ID: str,
157 | num: int = 10,
158 | proxies: Optional[Dict] = None
159 | ) -> List[str]:
160 | '''
161 | Google Search
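    |     Queries the Google Programmable Search (Custom Search JSON) API and returns the result URLs.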
162 | '''
163 | if len(query) == 0:
164 | return []
165 |
166 | num = int(num)
167 | URL = f"https://www.googleapis.com/customsearch/v1?key={SEARCH_ENGINE_API}&cx={SEARCH_ENGINE_ID}&q={query}&num={num}&filter=1"
168 |     data = {}
169 |     for _ in range(3):  # bounded retries; avoids an infinite loop on persistent SSL errors
170 |         try:
171 |             data = requests.get(URL, proxies=proxies).json()
172 |             break
173 |         except requests.exceptions.SSLError as e:
174 |             print(e); time.sleep(1)
175 |
176 | if data.get('error', {}).get('code') == 429:
177 | raise RuntimeError("Google search exceeds quota limit")
178 |
179 | search_items = data.get("items")
180 | if search_items is None:
181 | return []
182 |
183 | returned_urls = [item.get("link") for item in search_items]
184 |
185 | return returned_urls
186 |
187 |
188 |
189 | def query2image(
190 | query: str,
191 | SEARCH_ENGINE_API: str,
192 | SEARCH_ENGINE_ID: str,
193 | num: int = 10,
194 | proxies: Optional[Dict] = None
195 | ) -> List[str]:
196 | '''
197 | Google Image Search
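    |     Same Custom Search endpoint with searchType=image; returns thumbnail URLs of the hits.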
198 | '''
199 | if len(query) == 0:
200 | return []
201 |
202 | num = int(num)
203 | URL = f"https://www.googleapis.com/customsearch/v1?key={SEARCH_ENGINE_API}&cx={SEARCH_ENGINE_ID}&q={query}&searchType=image&num={num}&filter=1"
204 |     data = {}
205 |     for _ in range(3):  # bounded retries; avoids an infinite loop on persistent SSL errors
206 |         try:
207 |             data = requests.get(URL, proxies=proxies).json()
208 |             break
209 |         except requests.exceptions.SSLError as e:
210 |             print(e); time.sleep(1)
211 |
212 | if data.get('error', {}).get('code') == 429:
213 | raise RuntimeError("Google search exceeds quota limit")
214 |
215 |     returned_urls = [item["image"]["thumbnailLink"] for item in data.get("items", []) if item.get("image")]
216 |
217 | return returned_urls
218 |
219 |
220 | def download_image(
221 | url: str,
222 | proxies: Optional[Dict] = None
223 | ) -> Optional[Image.Image]:
224 |
225 | try:
226 | response = requests.get(url, proxies=proxies, timeout=5)
227 | if response.status_code == 200:
228 | img = Image.open(io.BytesIO(response.content))
229 | return img
230 | except requests.exceptions.Timeout:
231 | print("Request timed out after", 5, "seconds.")
232 | except requests.exceptions.RequestException as e:
233 | print(f"An error occurred while downloading image: {e}")
234 |
235 | return None
236 |
237 |
238 | def get_images(
239 | image_urls: List[str],
240 | proxies: Optional[Dict] = None
241 | ) -> List[Image.Image]:
242 |
243 | images = []
244 | if len(image_urls) > 0:
245 |         with ThreadPoolExecutor(max_workers=min(8, len(image_urls))) as executor:  # cap the thread count
246 | futures = [executor.submit(download_image, url, proxies) for url in image_urls]
247 | for future in futures:
248 | img = future.result()
249 | if img:
250 | images.append(img)
251 |
252 | return images
253 |
254 |
255 | def is_alive_domain(
256 | domain: str,
257 | proxies: Optional[Dict] = None
258 | ) -> bool:
259 | try:
260 |         response = requests.head('https://www.' + domain, timeout=10, proxies=proxies, allow_redirects=True)  # cheap HEAD request; follow redirects so response.history is populated
261 | PhishLLMLogger.spit(f'Domain {domain}, status code {response.status_code}',
262 | caller_prefix=PhishLLMLogger._caller_prefix, debug=True)
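    |         # Any definitive HTTP response (2xx/3xx, 405/429, or even 5xx) shows a live server answered for this domain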
263 | if response.status_code < 400 or response.status_code in [405, 429] or response.status_code >= 500:
264 | PhishLLMLogger.spit(f'Domain {domain} is valid and alive', caller_prefix=PhishLLMLogger._caller_prefix,
265 | debug=True)
266 | return True
267 | elif response.history and any([r.status_code < 400 for r in response.history]):
268 | PhishLLMLogger.spit(f'Domain {domain} is valid and alive', caller_prefix=PhishLLMLogger._caller_prefix,
269 | debug=True)
270 | return True
271 |
272 | except Exception as err:
273 | PhishLLMLogger.spit(f'Error {err} when checking the aliveness of domain {domain}',
274 | caller_prefix=PhishLLMLogger._caller_prefix, debug=True)
275 | return False
276 |
277 | PhishLLMLogger.spit(f'Domain {domain} is invalid or dead', caller_prefix=PhishLLMLogger._caller_prefix, debug=True)
278 | return False
279 |
280 | def has_page_content_changed(
281 | curr_screenshot_elements: List[int],
282 | prev_screenshot_elements: List[int]
283 | ) -> bool:
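    |     """
    |     Heuristic change detector: compare histograms (np.bincount) of UI-element class IDs
    |     before and after an interaction, and report a change when the total absolute
    |     difference exceeds half of the previous distribution's mass.
    |     """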
284 | bincount_prev_elements = np.bincount(prev_screenshot_elements)
285 | bincount_curr_elements = np.bincount(curr_screenshot_elements)
286 |     common_length = min(len(bincount_prev_elements), len(bincount_curr_elements))
287 |     # threshold: flag a change only when more than half of the previous UI-element distribution differs
288 |     screenshot_ele_change_ts = np.sum(bincount_prev_elements) // 2
289 |
290 |     if np.sum(np.abs(bincount_curr_elements[:common_length]
291 |                      - bincount_prev_elements[:common_length])) > screenshot_ele_change_ts:
292 | PhishLLMLogger.spit(f"Webpage content has changed", caller_prefix=PhishLLMLogger._caller_prefix, debug=True)
293 | return True
294 | else:
295 | PhishLLMLogger.spit(f"Webpage content didn't change", caller_prefix=PhishLLMLogger._caller_prefix, debug=True)
296 | return False
297 |
298 |
299 | def screenshot_element(
300 | elem: WebElement,
301 | dom: str,
302 | driver: WebDriver
303 | ) -> Tuple[Optional[str],
304 | Optional[Image.Image],
305 | Optional[str]]:
306 | """
307 | Returns:
308 | (candidate_ui, ele_screenshot_img, candidate_ui_text)
309 |         - candidate_ui: the dom string you passed in (or None on failure)
310 | - ele_screenshot_img: PIL.Image.Image of the element (or None on failure)
311 | - candidate_ui_text: element text/value (or None)
312 | """
313 | candidate_ui = None
314 | ele_screenshot_img = None
315 | candidate_ui_text = None
316 |
317 | try:
318 | # Scroll to top (plain Selenium)
319 | driver.execute_script("window.scrollTo(0, 0);")
320 |
321 | # Ensure the element is in view (center it to reduce cropping issues)
322 | try:
323 | driver.execute_script("arguments[0].scrollIntoView({block:'center', inline:'center'});", elem)
324 | except Exception:
325 | pass
326 |
327 | # Basic visibility by rect
328 | rect = elem.rect # {'x','y','width','height'} in CSS pixels
329 | w, h = rect.get("width", 0), rect.get("height", 0)
330 | if w <= 0 or h <= 0:
331 | return candidate_ui, ele_screenshot_img, candidate_ui_text
332 |
333 | # Preferred path: Selenium can screenshot elements directly
334 | try:
335 | png = elem.screenshot_as_png # bytes
336 | ele_screenshot_img = Image.open(io.BytesIO(png))
337 | candidate_ui = dom
338 | etext = (elem.text or "") # visible text
339 | if not etext:
340 | etext = elem.get_attribute("value") or ""
341 | candidate_ui_text = etext
342 | return candidate_ui, ele_screenshot_img, candidate_ui_text
343 |
344 | except (WebDriverException, StaleElementReferenceException):
345 | pass
346 |
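    |         # Fallback path: crop the element out of a full-viewport screenshot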
347 | try:
348 | # Scroll offsets + device pixel ratio for accurate cropping
349 | sx, sy, dpr = driver.execute_script(
350 | "return [window.scrollX, window.scrollY, window.devicePixelRatio || 1];"
351 | )
352 |
353 | # Re-fetch rect in case it changed after scroll
354 | rect = elem.rect
355 | x, y, w, h = rect["x"], rect["y"], rect["width"], rect["height"]
356 |
357 | # Convert page coords -> viewport coords, then scale by DPR
358 | left = int((x - sx) * dpr)
359 | top = int((y - sy) * dpr)
360 | right = int((x - sx + w) * dpr)
361 | bottom = int((y - sy + h) * dpr)
362 |
363 | # Take a viewport screenshot and crop
364 | viewport_png = driver.get_screenshot_as_png()
365 | image = Image.open(io.BytesIO(viewport_png))
366 |
367 | # Clamp to image bounds
368 | left = max(0, min(left, image.width))
369 | top = max(0, min(top, image.height))
370 | right = max(0, min(right, image.width))
371 | bottom = max(0, min(bottom, image.height))
372 |
373 | if right > left and bottom > top:
374 | ele_screenshot_img = image.crop((left, top, right, bottom))
375 | candidate_ui = dom
376 | etext = (elem.text or "")
377 | if not etext:
378 | etext = elem.get_attribute("value") or ""
379 | candidate_ui_text = etext
380 |
381 | except Exception as e2:
382 | print(f"Error processing element {dom} (crop fallback): {e2}")
383 |
384 | except Exception as e:
385 | print(f"Error accessing element {dom}: {e}")
386 |
387 | return candidate_ui, ele_screenshot_img, candidate_ui_text
388 |
389 |
390 | def get_all_clickable_elements(
391 | driver: WebDriver
392 | ) -> Tuple[Tuple[List[WebElement], List[str]],
393 | Tuple[List[WebElement], List[str]],
394 | Tuple[List[WebElement], List[str]],
395 | Tuple[List[WebElement], List[str]]]:
396 | """
397 | Collect clickable elements using plain Selenium:
398 | - Buttons (