├── .github
│   └── FUNDING.yml
├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   └── .gitignore
├── extension
│   ├── dev
│   │   ├── TODO.md
│   │   ├── model
│   │   │   └── .gitignore
│   │   ├── ort-wasm-simd.wasm
│   │   └── scripts
│   │       ├── model.js
│   │       ├── ort.js
│   │       └── tokenizer.js
│   ├── public
│   │   ├── icons
│   │   │   ├── icon-128x128.png
│   │   │   └── icon.png
│   │   ├── manifest.json
│   │   ├── popup
│   │   │   ├── popup.css
│   │   │   ├── popup.html
│   │   │   └── popup.js
│   │   ├── scripts
│   │   │   ├── content.js
│   │   │   ├── defaults.js
│   │   │   ├── detection.js
│   │   │   ├── emojis.js
│   │   │   ├── labels.js
│   │   │   ├── preprocess.js
│   │   │   ├── storage.js
│   │   │   └── utils.js
│   │   └── styles
│   │       └── style.css
│   └── store
│       └── teaser.png
├── output
│   └── .gitignore
├── requirements.txt
└── src
    ├── cluster.py
    ├── database.py
    ├── dataset.py
    ├── detection_rules.py
    ├── downloader.py
    ├── labels.py
    ├── model.py
    ├── moderate.py
    ├── predict.py
    ├── preprocess.py
    ├── shared.py
    ├── train.py
    ├── urls.py
    └── youtube.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: xenova
2 | ko_fi: xenova
3 | custom: https://www.buymeacoffee.com/xenova
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | venv
2 | __pycache__
3 | .vscode
4 | *.env
5 | *.db
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Xenova
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CommentBlock
2 |
3 | CommentBlock is an open-source browser extension that automatically blocks spam/scam YouTube comments. Download for [Chrome/Chromium](https://chrome.google.com/webstore/detail/pnhkbjdbaioddkchelkolhbhcmlibjfb) or [Firefox](https://addons.mozilla.org/en-US/firefox/addon/ytcommentblock/).
4 |
5 | 
6 |
7 | ## Examples
8 | Want to see the extension in action? Here are some videos to test it on:
9 |
10 | 1. [The Flash - No Way Home For The DCU?](https://www.youtube.com/watch?v=JG0QV40FMdQ) | The Critical Drinker — If you sort comments by new, you will see a ton of scam comments (mostly as replies). The extension does a very good job at detecting these.
11 |
12 |
13 | 2. [Inside the Largest Bitcoin Mine in The U.S.](https://www.youtube.com/watch?v=x9J0NdV0u9k) | WIRED — Almost every single comment on this video is a scam comment from a bot. This is of course due to the subject matter: *Crypto*. Although the extension does a good job blocking the obvious scams, botters have gotten a lot smarter recently. In particular, they start long comment threads (pretending to be real conversations between real people) and eventually prompt readers to contact someone off of YouTube. Detection for these comment threads will be much better with neural networks (see [below](#development-plans))!
14 |
15 |
16 | ## Development Plans
17 |
18 | Neural-network based detection is also in development to catch more advanced spam comments and comment threads. In particular, we aim to use unsupervised clustering techniques to group similar comments together, assign labels to comments in these groups, and then train classification models using the labelled data.
19 |
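A minimal sketch of that pipeline (illustrative only, assuming a plain list of comment strings; the working implementation lives in `src/cluster.py`):

```python
# Sketch of the planned clustering step, not the shipped code (see src/cluster.py).
# `comments` is assumed to be a list of "author commented text" strings.
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import community_detection

model = SentenceTransformer('all-MiniLM-L12-v2')
embeddings = model.encode(comments, convert_to_tensor=True, show_progress_bar=True)

# Group comments whose embeddings have cosine similarity above the threshold.
clusters = community_detection(embeddings, threshold=0.95, min_community_size=25)

# Each cluster is a list of indices into `comments`. A label assigned to the whole
# cluster (by a moderator, or by consensus of already-moderated members) is
# propagated to its members, producing labelled data for training a classifier.
for cluster in clusters:
    print(len(cluster), 'comments, e.g.', comments[cluster[0]])
```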
20 | ## Contributions
21 |
22 | At the moment, the extension uses rules to determine the type of a comment (e.g., spam, scam, explicit, or links). So, there may be cases where the extension misses a bad comment, or blocks a valid comment. In either case, feel free to open an [issue](https://github.com/xenova/commentblock/issues/new/choose) (including a link to the comment, which can be retrieved by right-clicking the time posted and copying the link), and we will update the ruleset to account for this.
23 |
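To give a sense of what such a rule looks like, here is a toy example (the phrase list is invented for illustration and is not the shipped ruleset in `detection.js` or `src/detection_rules.py`):

```python
# Toy rule-based check; SCAM_PHRASES is made up for illustration only.
SCAM_PHRASES = ('telegram me', 'whatsapp me', 'thank me later')

def looks_like_scam(author_name: str, comment_text: str) -> bool:
    # A real rule would typically normalise the text first (see preprocess.js).
    text = f'{author_name} {comment_text}'.lower()
    return any(phrase in text for phrase in SCAM_PHRASES)
```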
24 | ## Credit
25 | Inspired by [ThioJoe's Spammer Purge](https://github.com/ThioJoe/YT-Spammer-Purge) tool.
26 |
--------------------------------------------------------------------------------
/datasets/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
--------------------------------------------------------------------------------
/extension/dev/TODO.md:
--------------------------------------------------------------------------------
1 | ## TODO (Upcoming features):
2 |
3 | ### 1. Add to manifest.json:
4 | ```
5 | "web_accessible_resources": [{
6 | "resources": [
7 | "model/*",
8 | "ort-wasm-simd.wasm"
9 | ],
10 | "matches": [
11 | "*://*.youtube.com/*"
12 | ]
13 | }],
14 | ...
15 | "content_scripts": [{
16 | "js": [
17 | "scripts/defaults.js",
18 | "scripts/utils.js",
19 | "scripts/storage.js",
20 | "scripts/ort.js",
21 | "scripts/labels.js",
22 | "scripts/preprocess.js",
23 | "scripts/tokenizer.js",
24 | "scripts/model.js",
25 | "scripts/emojis.js",
26 | "scripts/detection.js",
27 | "scripts/content.js"
28 | ],
29 | }],
30 | ```
--------------------------------------------------------------------------------
/extension/dev/model/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
--------------------------------------------------------------------------------
/extension/dev/ort-wasm-simd.wasm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xenova/commentblock/98cf557c16f3893a488350c94abe9d01d8686b62/extension/dev/ort-wasm-simd.wasm
--------------------------------------------------------------------------------
/extension/dev/scripts/model.js:
--------------------------------------------------------------------------------
1 |
2 | let ModelFactory = (function () {
3 | // https://stackoverflow.com/a/4842961
4 |
5 | async function createModelPromise() {
6 | console.log('Start loading model');
7 | // Load the tokenizer and model.
8 | const tokenizerURL = chrome.runtime.getURL("model/tokenizer.json");
9 | const modelURL = chrome.runtime.getURL("model/model.onnx");
10 |
11 | let model = new Model(tokenizerURL, modelURL);
12 | await model.load() // Wait for model to load fully
13 | console.log('Model finished loading');
14 | return model
15 | }
16 |
17 | let instance;
18 | return {
19 | getInstance: async function () {
20 | if (instance == null) {
21 | instance = createModelPromise();
22 | }
23 | return await instance;
24 | }
25 | };
26 | })();
27 |
28 |
29 | class Model {
30 | constructor(tokenizerURL, modelURL) {
31 | this.tokenizerURL = tokenizerURL;
32 | this.modelURL = modelURL;
33 | this.labels = Object.values(COMMENT_LABEL)
34 | }
35 | async load() {
36 | // Load tokenizer
37 | let tokenizer = new WordPieceTokenizer()
38 | await tokenizer.load(this.tokenizerURL)
39 |
40 | this.tokenizer = tokenizer;
41 |
42 | // Load model
43 | let response = await fetch(this.modelURL, {
44 | cache: 'force-cache'
45 | });
46 | let modelBuffer = await response.arrayBuffer();
47 | this.session = await ort.InferenceSession.create(modelBuffer, {
48 | executionProviders: ["wasm"]
49 | });
50 | }
51 |
52 | create_model_input(encoded) {
53 | // TODO optimise this
54 | // Adapted from https://github.com/jobergum/browser-ml-inference/blob/main/src/inference.js
55 | // (https://www.youtube.com/watch?v=W_lUGPMW_Eg)
56 |
57 | var input_ids = new Array(encoded.length + 2);
58 | var attention_mask = new Array(encoded.length + 2);
59 | input_ids[0] = BigInt(101); // [CLS]
60 | attention_mask[0] = BigInt(1);
61 | var i = 0;
62 | for (; i < encoded.length; i++) {
63 | input_ids[i + 1] = BigInt(encoded[i]);
64 | attention_mask[i + 1] = BigInt(1);
65 | }
66 | input_ids[i + 1] = BigInt(102); // [SEP]
67 | attention_mask[i + 1] = BigInt(1);
68 | const sequence_length = input_ids.length;
69 | input_ids = new ort.Tensor('int64', BigInt64Array.from(input_ids), [1, sequence_length]);
70 | attention_mask = new ort.Tensor('int64', BigInt64Array.from(attention_mask), [1, sequence_length]);
71 | return {
72 | input_ids: input_ids,
73 | attention_mask: attention_mask
74 | }
75 | }
76 | preprocess(authorName, commentText) {
77 | // Normalise author name and comment text
78 | authorName = this.tokenizer.normalize(authorName);
79 | commentText = this.tokenizer.normalize(commentText);
80 |
81 | return `${authorName} commented ${commentText}`;
82 |
83 | }
84 | async predict(authorName, commentText) {
85 | let text = this.preprocess(authorName, commentText);
86 | // console.log('text', text)
87 |
88 | let encoded = this.tokenizer.call(text);
89 |
90 | let model_input = this.create_model_input(encoded);
91 | // console.log('model_input', model_input)
92 | let output = await this.session.run(model_input);
93 |
94 | // console.log('output.logits', output.logits.data)
95 | let maxIndex = indexOfMax(output.logits.data);
96 | // console.log('maxIndex', maxIndex)
97 |
98 | let prediction = this.labels[maxIndex]
99 | // console.log('prediction', prediction)
100 | return prediction;
101 |
102 | }
103 | }
104 |
105 | function indexOfMax(arr) {
106 | // https://stackoverflow.com/a/11301464
107 |
108 | if (arr.length === 0) {
109 | return -1;
110 | }
111 |
112 | var max = arr[0];
113 | var maxIndex = 0;
114 |
115 | for (var i = 1; i < arr.length; i++) {
116 | if (arr[i] > max) {
117 | maxIndex = i;
118 | max = arr[i];
119 | }
120 | }
121 |
122 | return maxIndex;
123 | }
124 |
--------------------------------------------------------------------------------
/extension/dev/scripts/tokenizer.js:
--------------------------------------------------------------------------------
1 | "use strict";
2 |
3 | class WordPieceTokenizer {
4 | constructor() {
5 | this.separator = '[SEP]';
6 | this.unknown_token = '[UNK]';
7 | }
8 | async load(vocabUrl) {
9 |
10 | let v = await this.loadVocab(vocabUrl);
11 |
12 | this.vocabMapping = v.model.vocab;
13 |
14 | let tempVocab = {};
15 | // Reverse vocab
16 | for (const [key, value] of Object.entries(this.vocabMapping)) {
17 | tempVocab[parseInt(value)] = key;
18 | }
19 |
20 | this.vocab = Object.values(tempVocab)
21 | }
22 | async loadVocab(url) {
23 | const response = await fetch(url);
24 | return await response.json();
25 | }
26 |
27 | normalize(string) {
28 | return normalize(string)
29 | }
30 |
31 | pretokenize(text) {
32 | return text.trim().match(/\b\w+\b|[^\s\w]/g) || [];
33 | }
34 | tokenize(text) {
35 | var outputTokens = [];
36 |
37 | // whitespace_tokenize
38 | let tokens = this.pretokenize(text);
39 |
40 | for (let token of tokens) {
41 | let chars = [...token];
42 | // if len(chars) > self.max_input_chars_per_word:
43 | // output_tokens.append(self.unk_token)
44 | // continue
45 |
46 | let isUnknown = false;
47 | let start = 0;
48 | let subTokens = [];
49 |
50 | while (start < chars.length) {
51 | var end = chars.length;
52 | var currentSubstring = null;
53 | while (start < end) {
54 | var substr = chars.slice(start, end).join('');
55 |
56 | if (start > 0) {
57 | substr = '##' + substr
58 | }
59 | if (this.vocab.includes(substr)) {
60 | currentSubstring = substr;
61 | break;
62 | }
63 |
64 | --end;
65 | }
66 | if (currentSubstring == null) {
67 | isUnknown = true;
68 | break;
69 | }
70 | subTokens.push(currentSubstring);
71 | start = end;
72 | }
73 | if (isUnknown) {
74 | outputTokens.push(this.unknown_token);
75 | } else {
76 | outputTokens = outputTokens.concat(subTokens);
77 | }
78 | }
79 |
80 | return outputTokens;
81 | }
82 | convert_tokens_to_ids(outputTokens) {
83 | // get ids
84 | let ids = [];
85 | for (let t of outputTokens) {
86 | ids.push(this.vocabMapping[t]);
87 | }
88 | return ids;
89 | }
90 |
91 | call(text) {
92 | return this.convert_tokens_to_ids(this.tokenize(text));
93 | }
94 | }
95 |
96 |
--------------------------------------------------------------------------------
/extension/public/icons/icon-128x128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xenova/commentblock/98cf557c16f3893a488350c94abe9d01d8686b62/extension/public/icons/icon-128x128.png
--------------------------------------------------------------------------------
/extension/public/icons/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xenova/commentblock/98cf557c16f3893a488350c94abe9d01d8686b62/extension/public/icons/icon.png
--------------------------------------------------------------------------------
/extension/public/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 | "manifest_version": 3,
3 | "name": "CommentBlock",
4 | "description": "Block spam/scam YouTube comments",
5 | "version": "0.0.3",
6 | "permissions": [
7 | "storage"
8 | ],
9 | "host_permissions": [
10 | "*://*.youtube.com/*"
11 | ],
12 | "content_scripts": [
13 | {
14 | "css": [
15 | "styles/style.css"
16 | ],
17 | "js": [
18 | "scripts/defaults.js",
19 | "scripts/utils.js",
20 | "scripts/storage.js",
21 | "scripts/labels.js",
22 | "scripts/preprocess.js",
23 | "scripts/emojis.js",
24 | "scripts/detection.js",
25 | "scripts/content.js"
26 | ],
27 | "matches": [
28 | "https://*.youtube.com/*"
29 | ]
30 | }
31 | ],
32 | "action": {
33 | "default_icon": {
34 | "16": "icons/icon.png",
35 | "24": "icons/icon.png",
36 | "32": "icons/icon.png",
37 | "128": "icons/icon-128x128.png"
38 | },
39 | "default_title": "CommentBlock",
40 | "default_popup": "popup/popup.html"
41 | },
42 | "icons": {
43 | "16": "icons/icon.png",
44 | "32": "icons/icon.png",
45 | "64": "icons/icon.png",
46 | "128": "icons/icon-128x128.png"
47 | },
48 | "browser_specific_settings": {
49 | "gecko": {
50 | "id": "commentblock@xenova.com"
51 | }
52 | }
53 | }
--------------------------------------------------------------------------------
/extension/public/popup/popup.css:
--------------------------------------------------------------------------------
1 | * {
2 | margin: 0;
3 | padding: 0;
4 | box-sizing: border-box;
5 | font-family: 'Open Sans', sans-serif;
6 | font-size: 14px;
7 | }
8 |
9 | h2 {
10 | margin: 8px 0 2px 0;
11 | font-size: 18px;
12 | }
13 |
14 | /* Style the buttons that are used to open and close the accordion panel */
15 | .accordion {
16 | background-color: #eee;
17 | color: #444;
18 | /* cursor: pointer; */
19 | /* width: 100%; */
20 | text-align: left;
21 | border: none;
22 | outline: none;
23 | transition: 0.4s;
24 |
25 | }
26 |
27 | /* Add a background color to the button if it is clicked on (add the .active class with JS), and when you move the mouse over it (hover) */
28 | .accordion:hover {
29 | background-color: #ccc;
30 | }
31 |
32 | /* Style the accordion panel. Note: hidden by default */
33 | .panel {
34 | background-color: white;
35 | max-height: 0;
36 | overflow: hidden;
37 | transition: max-height 0.2s ease-out;
38 | /* padding:8px 16px; */
39 | /* margin: 10px; */
40 | }
41 |
42 | .panel>.inner-panel {
43 | padding: 8px 16px 16px 16px;
44 |
45 | }
46 |
47 | fieldset {
48 | margin: 4px 0;
49 | padding: 4px 8px;
50 | }
51 |
52 | fieldset>legend {
53 | margin-left: 5px;
54 | padding: 0 5px;
55 | }
56 |
57 | .top {
58 | position: relative;
59 | padding: 8px 16px;
60 |
61 | }
62 |
63 | .top:after {
64 | position: absolute;
65 | right: 16px;
66 | content: '\02795';
67 | /* Unicode character for "plus" sign (+) */
68 | font-size: 13px;
69 | color: #777;
70 | float: right;
71 | margin-left: 5px;
72 | }
73 |
74 | .top>label,
75 | .rule>label {
76 | margin-left: 5px;
77 | }
78 |
79 | .active:after {
80 | content: "\2796";
81 | /* Unicode character for "minus" sign (-) */
82 | }
83 |
84 | .center-vertical {
85 | display: flex;
86 | align-items: center;
87 | }
88 |
89 |
90 | .main-container {
91 | padding: 16px 12px;
92 | min-height: 500px;
93 | max-height: 500px;
94 |
95 | min-width: 375px;
96 | max-width: 375px;
97 | overflow: overlay;
98 | /* overflow: scroll; */
99 | }
100 |
101 |
102 |
103 | /* width */
104 | ::-webkit-scrollbar {
105 | width: 8px;
106 | height: 8px;
107 | }
108 |
109 | /* Track */
110 | ::-webkit-scrollbar-track {
111 | background: transparent;
112 | /* box-shadow: inset 0 0 5px #dddddd; */
113 | border-radius: 4px;
114 | border-left: 2px solid transparent;
115 | border-right: 2px solid transparent;
116 | }
117 |
118 | /* Handle */
119 | ::-webkit-scrollbar-thumb {
120 | background-color: #C1C1C1;
121 | border-radius: 4px;
122 | }
123 |
124 | .heading {
125 | display: flex;
126 | flex-direction: column;
127 | justify-content: center;
128 | align-items: center;
129 | }
130 |
131 | .heading>h1 {
132 | font-size: 36px;
133 | font-weight: 700;
134 | line-height: 100%;
135 | }
136 |
137 | .heading>h5 {
138 | font-size: 12px;
139 | font-weight: 500;
140 | }
141 |
142 | .bottom{
143 | padding: 16px 8px;
144 | display: flex;
145 | justify-content: center;
146 | align-items: center;
147 | }
148 | #reset:hover{
149 | background-color: rgba(51, 51, 51, 0.15);
150 |
151 | }
152 | #reset {
153 | background-color: rgba(51, 51, 51, 0.05);
154 | border-radius: 8px;
155 | border-width: 0;
156 | cursor: pointer;
157 | display: inline-block;
158 | font-size: 14px;
159 | font-weight: 500;
160 | line-height: 20px;
161 | list-style: none;
162 | margin: 0;
163 | padding: 10px 12px;
164 | text-align: center;
165 | transition: all 200ms;
166 | vertical-align: baseline;
167 | white-space: nowrap;
168 | user-select: none;
169 | -webkit-user-select: none;
170 | touch-action: manipulation;
171 | }
--------------------------------------------------------------------------------
/extension/public/popup/popup.html:
--------------------------------------------------------------------------------
[popup.html markup lost in extraction; only fragments of visible text remain: title "Document", heading "CommentBlock (v0.0.3)", accordion sections "Categories", "Actions", "Settings", and "Banned users/phrases" ("Coming soon..."), plus a reset button.]
--------------------------------------------------------------------------------
/extension/public/popup/popup.js:
--------------------------------------------------------------------------------
1 | (function () {
2 | var acc = document.getElementsByClassName("top");
3 | var i;
4 |
5 | for (i = 0; i < acc.length; i++) {
6 | acc[i].addEventListener("click", function (e) {
7 | if (e.target.tagName === 'INPUT' || e.target.tagName === 'LABEL') {
8 | // Using checkbox
9 | return;
10 | }
11 | this.classList.toggle("active");
12 | var panel = this.nextElementSibling;
13 | if (panel.style.maxHeight) {
14 | panel.style.maxHeight = null;
15 | } else {
16 | panel.style.maxHeight = panel.scrollHeight + "px";
17 | }
18 | });
19 | }
20 | }());
21 |
22 | function resetSettings() {
23 | // https://developer.chrome.com/extensions/storage#method-StorageArea-clear
24 | chrome.storage.local.set(DEFAULT_OPTIONS)
25 | updateUIFromStorage();
26 |
27 | }
28 |
29 | function updateUIFromStorage() {
30 | document.querySelectorAll('input[setting]').forEach(async function (i) {
31 | let setting = await getSetting(i.id)
32 |
33 | if (i.type === 'checkbox') {
34 | i.checked = setting;
35 | } else {
36 | alert(`Undefined input type: ${i.type}`)
37 | }
38 | });
39 |
40 | document.querySelectorAll('select').forEach(async function (i) {
41 | let setting = await getSetting(i.id)
42 | i.value = setting;
43 | });
44 | }
45 | document.addEventListener('DOMContentLoaded', function () {
46 | updateUIFromStorage();
47 |
48 | document.getElementById('reset').addEventListener('click', resetSettings)
49 |
50 | document.querySelectorAll('input[setting]').forEach(async function (i) {
51 | i.addEventListener('input', function (e) {
52 | chrome.storage.local.set({ [i.id]: i.checked });
53 | });
54 | })
55 | document.querySelectorAll('select').forEach(async function (i) {
56 | i.addEventListener('change', function (e) {
57 | chrome.storage.local.set({ [i.id]: i.value });
58 | });
59 | })
60 |
61 | }, false);
62 |
63 |
64 |
65 | // chrome.storage.onChanged.addListener(function (changes, namespace) {
66 | // for (var key in changes) {
67 | // var storageChange = changes[key];
68 | // console.log('Storage key "%s" in namespace "%s" changed. ' + 'Old value was "%s", new value is "%s".', key, namespace, storageChange.oldValue, storageChange.newValue);
69 | // }
70 | // });
71 |
--------------------------------------------------------------------------------
/extension/public/scripts/content.js:
--------------------------------------------------------------------------------
1 |
2 |
3 | const LABEL_RULES_MAPPING = {
4 | // prediction: rule | [rules]
5 | // if any of the rules is enabled,
6 | [COMMENT_LABEL.SCAM]: "rules-enabled-scam",
7 | [COMMENT_LABEL.EXPLICIT]: "rules-enabled-explicit",
8 | [COMMENT_LABEL.LINK_SPAM]: ["rules-links-spam", "rules-links-contains"],
9 | [COMMENT_LABEL.LINK_ONLY]: ["rules-links-only", "rules-links-contains"],
10 | [COMMENT_LABEL.LINK_CONTAINS]: "rules-links-contains",
11 | [COMMENT_LABEL.SELF_PROMO]: "rules-enabled-selfpromo",
12 | [COMMENT_LABEL.OTHER_PROMO]: "rules-enabled-otherpromo",
13 | [COMMENT_LABEL.SPONSOR]: "rules-enabled-sponsor",
14 | [COMMENT_LABEL.OTHER_SPAM]: "rules-enabled-spam",
15 | }
16 |
17 | const BLURRED_COMMENT_OVERLAY_CLASS = 'blurred-comment-overlay'
18 |
19 | const LISTBOX_SELECTOR = 'tp-yt-paper-listbox#items.ytd-menu-popup-renderer';
20 | const LISTBOX_TEXT_SELECTOR = 'yt-formatted-string.ytd-menu-service-item-renderer';
21 |
22 | const COMMENT_TAG = 'YTD-COMMENT-RENDERER';
23 | const COMMENT_THREAD_TAG = 'YTD-COMMENT-THREAD-RENDERER';
24 | const COMMENT_TEXT = 'yt-formatted-string#content-text.ytd-comment-renderer';
25 |
26 | const MENU_ITEM = htmlToElement(
27 | ``
34 | );
35 |
36 |
37 | let PREDICTION_CACHE = {};
38 | let focusedComment = null;
39 |
40 | (() => {
41 |
42 | // TODO allow user to select what to remove
43 |
44 | // observe changes on the page, comments are loaded separately so we need this to wait for them
45 | let observer = new MutationObserver((mutations) => {
46 | mutations.forEach(async (mutation) => {
47 | if (mutation.removedNodes.length > 0) {
48 | let listbox;
49 | if (mutation.target.matches(LISTBOX_SELECTOR)) {
50 | listbox = mutation.target;
51 | } else if (mutation.target.matches(LISTBOX_TEXT_SELECTOR)) {
52 | listbox = mutation.target.closest(LISTBOX_SELECTOR);
53 | }
54 |
55 | if (listbox && listbox.childNodes.length == 3) {
56 | addHideOption(listbox);
57 | }
58 |
59 | } else if (mutation.addedNodes.length > 0 && mutation.target.matches(COMMENT_TEXT)) { // is a comment
60 | // For optimisation purposes, YouTube doesn't remove-then-insert new comments.
61 | // Instead, they replace the content of existing elements on the page.
62 | // For this reason, we have to listen for changes to the actual text content
63 | // of the comment, then crawl back up to the actual comment element.
64 | // This is especially needed when sorting by recent
65 | let comment = mutation.target.closest(COMMENT_TAG);
66 | if (comment === null) return;
67 | processComment(comment);
68 | }
69 | })
70 | });
71 |
72 | observer.observe(document.body, {
73 | childList: true,
74 | subtree: true,
75 | attributes: false,
76 | characterData: false
77 | });
78 | })();
79 |
80 | function addBlur(commentElement) {
81 | if (isBlurred(commentElement)) return; // Do nothing if already blurred
82 |
83 | let overlay = document.createElement('div');
84 | overlay.className = BLURRED_COMMENT_OVERLAY_CLASS;
85 | commentElement.querySelector('#body').append(overlay);
86 | }
87 |
88 | function getBlur(commentElement) {
89 | return commentElement.querySelector(`div.${BLURRED_COMMENT_OVERLAY_CLASS}`);
90 | }
91 |
92 | function isBlurred(commentElement) {
93 | return getBlur(commentElement) !== null;
94 | }
95 |
96 | function removeBlur(commentElement) {
97 | let overlay = getBlur(commentElement);
98 | if (overlay) overlay.remove();
99 | }
100 |
101 | function addHideOption(listbox) {
102 | // TODO: Add option to hide comments?
103 | // let commentIsBlurred = isBlurred(focusedComment);
104 |
105 | let elem = MENU_ITEM.cloneNode();
106 | listbox.appendChild(elem)
107 |
108 | elem.addEventListener('click', (e) => {
109 | removeBlur(focusedComment);
110 |
111 | // Simulate click elsewhere to unfocus
112 | document.body.click();
113 | })
114 |
115 | // Must do it this way for some reason?
116 | // YouTube seems to override yt-icon when appending?
117 | elem.querySelector('yt-icon').innerHTML =
118 | ``;
123 |
124 | // visibility off
125 | // https://fonts.google.com/icons?icon.query=show
126 |
127 | // Add text
128 | let t = elem.querySelector('yt-formatted-string');
129 | t.removeAttribute('is-empty');
130 | t.textContent = 'Show';
131 | }
132 |
133 | async function processComment(comment) {
134 |
135 | let commentURL = new URL(comment.querySelector('.published-time-text > a').href)
136 | let commentID = commentURL.searchParams.get('lc')
137 |
138 | // if (commentID !== '') return;
139 |
140 | if (PREDICTION_CACHE[commentID]) {
141 | if (PREDICTION_CACHE[commentID] !== 'PROCESSING') {
142 | action(comment, PREDICTION_CACHE[commentID]); // Re-run action
143 | } else {
144 | // Prediction is still running elsewhere, and will be updated there, so we ignore here.
145 | }
146 | return; // Either way, do not run prediction again
147 | }
148 | PREDICTION_CACHE[commentID] = 'PROCESSING';
149 |
150 | let authorData = comment.querySelector('#author-text');
151 | let authorName = authorData.innerText;
152 | let commentText = extractTextFromElement(comment.querySelector('#comment-content #content-text'));
153 | let authorChannelId = authorData.href.replace('https://www.youtube.com/channel/', '');
154 |
155 |
156 | // Add event listener to options
157 | let actionButton = comment.querySelector('#action-menu.ytd-comment-renderer button')
158 | if(actionButton !== null){
159 | actionButton.addEventListener('click', () => {
160 | focusedComment = comment;
161 | })
162 | }
163 |
164 | // Set data attributes
165 | comment.data = {
166 | author_name: authorName,
167 | author_channel_id: authorChannelId,
168 | text: commentText,
169 | }
170 |
171 | let prediction = await makePrediction(comment)
172 | PREDICTION_CACHE[commentID] = prediction;
173 |
174 | action(comment, prediction);
175 | }
176 |
177 | function extractTextFromElement(element) {
178 | let text = '';
179 |
180 | for (const child of element.childNodes) {
181 | if (child.nodeValue !== null) {
182 | text += child.nodeValue;
183 | } else {
184 | if (child.tagName === 'IMG') {
185 | text += child.alt;
186 | }
187 | text += extractTextFromElement(child);
188 | }
189 | }
190 | return text;
191 | }
192 |
193 | function show(element) {
194 | element.style.display = 'block';
195 | }
196 |
197 | function hide(element) {
198 | element.style.display = 'none';
199 | }
200 |
201 | async function action(comment, prediction) {
202 |
203 | // Check if the predicted category is enabled
204 | let rules = LABEL_RULES_MAPPING[prediction];
205 | let categoryEnabled;
206 | if (Array.isArray(rules)) {
207 | categoryEnabled = (await Promise.all(rules.map(rule => getSetting(rule)))).some(x => x)
208 | } else {
209 | categoryEnabled = await getSetting(rules);
210 | }
211 |
212 | if (categoryEnabled) {
213 |
214 |
215 | // Now, decide what action to perform
216 | let action = await getSetting('action');
217 | if (action === 'remove') {
218 | if (comment.parentElement.tagName === COMMENT_THREAD_TAG) {
219 | // Is a top-level comment, so we delete the whole thread
220 | // TODO add option for this
221 | hide(comment.parentElement);
222 | } else {
223 | // Is a reply, so we just delete the reply
224 | hide(comment);
225 |
226 | // TODO if it is the only reply, remove the "1 reply" text
227 | }
228 |
229 | } else if (action === 'blur') {
230 | // TODO improve blurring
231 | addBlur(comment);
232 |
233 | } else {
234 | console.error(`Unknown action: ${action}`)
235 | }
236 | } else {
237 | // Reset
238 | if (comment.parentElement.tagName === COMMENT_THREAD_TAG) {
239 | show(comment.parentElement);
240 | } else {
241 | show(comment);
242 | }
243 |
244 | // Remove blurred overlay if present
245 | removeBlur(comment);
246 |
247 |
248 | }
249 |
250 | }
251 |
252 |
253 | async function makePrediction(comment) {
254 |
255 | // TODO 1. access ban lists
256 |
257 | // 2. use rules
258 |
259 | let prediction = 'VALID'; // Assume comment is valid
260 |
261 | let useRules = await getSetting('use-rules');
262 | if (useRules) {
263 |
264 | // TODO perform rule-based detection here
265 | prediction = rule_detect(comment)
266 | if (prediction !== 'VALID') {
267 | // If rules detected something, no need to use ML
268 | // TODO sometimes rules are wrong, so, we should divide into
269 | // "maybe" and "definite" based on rules
270 | return prediction;
271 | }
272 | }
273 |
274 | // COMING SOON:
275 | // let useML = await getSetting('use-ml');
276 | // if (useML) {
277 | // // Do another check to determine whether the rules missed it
278 | // let model = await ModelFactory.getInstance();
279 | // prediction = await model.predict(comment.data.author_name, comment.data.text);
280 | // }
281 |
282 | return prediction;
283 |
284 | }
285 |
--------------------------------------------------------------------------------
/extension/public/scripts/defaults.js:
--------------------------------------------------------------------------------
1 | // Script that is run first
2 | // Sets DEFAULT storage values if not set
3 |
4 | const DEFAULT_OPTIONS = {
5 |
6 | // CATEGORIES:
7 | // 1. SCAM
8 | "rules-enabled-scam": true,
9 | "rules-scam-names": true,
10 | "rules-scam-text": true,
11 |
12 | // 2. EXPLICIT
13 | "rules-enabled-explicit": true,
14 |
15 | // 3. LINKS
16 | "rules-enabled-links": true,
17 | "rules-links-spam": true,
18 | "rules-links-only": false,
19 | "rules-links-contains": false,
20 |
21 | // 4. SELF_PROMO
22 | "rules-enabled-selfpromo": false,
23 |
24 | // 5. OTHER_PROMO
25 | "rules-enabled-otherpromo": false,
26 |
27 | // 6. SPONSOR
28 | "rules-enabled-sponsor": false,
29 |
30 | // 7. OTHER_SPAM
31 | "rules-enabled-spam": true,
32 | "rules-spam-wafflehouse": true,
33 |
34 | // Actions
35 | "action": 'blur',
36 |
37 |
38 | // OTHER SETTINGS:
39 | "use-rules": true,
40 | "use-ml": false,
41 | }
42 |
--------------------------------------------------------------------------------
/extension/public/scripts/labels.js:
--------------------------------------------------------------------------------
1 | const COMMENT_LABEL = {
2 | VALID: 'VALID',
3 | SCAM: 'SCAM',
4 | SELF_PROMO: 'SELF_PROMO',
5 | OTHER_PROMO: 'OTHER_PROMO',
6 | SPONSOR: 'SPONSOR',
7 | EXPLICIT: 'EXPLICIT',
8 | LINK_SPAM: 'LINK_SPAM',
9 | LINK_ONLY: 'LINK_ONLY',
10 | LINK_CONTAINS: 'LINK_CONTAINS',
11 | OTHER_SPAM: 'OTHER_SPAM',
12 | REPLY_TO_SCAM: 'REPLY_TO_SCAM',
13 | };
14 |
--------------------------------------------------------------------------------
/extension/public/scripts/preprocess.js:
--------------------------------------------------------------------------------
1 | const SIMILAR_CHAR_MAPPING = {
2 | 'ᴀ': 'A', 'ᴁ': 'AE', 'ᴂ': 'ae',
3 | 'ᴃ': 'B', 'ᴄ': 'C', 'ᴅ': 'D',
4 | 'ᴆ': 'D', 'ᴇ': 'E', 'ᴈ': '3',
5 | 'ᴉ': 'i', 'ᴊ': 'J', 'ᴋ': 'K',
6 | 'ᴌ': 'L', 'ᴍ': 'M', 'ᴎ': 'N',
7 | 'ᴏ': 'o', 'ᴐ': 'c', 'ᴑ': 'o',
8 | 'ᴒ': 'n', 'ᴓ': 'o', 'ᴔ': 'oe',
9 | 'ᴕ': 'ou', 'ᴖ': 'n', 'ᴗ': 'u',
10 | 'ᴘ': 'P', 'ᴙ': 'R', 'ᴚ': 'R',
11 | 'ᴛ': 'T', 'ᴜ': 'U', 'ᴝ': 'u',
12 | 'ᴞ': 'u', 'ᴟ': 'm', 'ᴠ': 'V',
13 | 'ᴡ': 'W', 'ᴢ': 'Z', 'ᴣ': '3',
14 | 'ᴤ': '2', 'ᴥ': 'ain', 'ᴦ': 'L',
15 | 'ᴧ': 'A', 'ᴨ': 'N', 'ᴩ': 'P',
16 | 'ᴪ': 'W', 'ᴫ': 'N', 'ᴯ': 'B',
17 | 'Ǝ': '3', 'ᴻ': 'N', 'Ȣ': 'Ou',
18 | 'ɐ': 'a', 'ɑ': 'a', 'ə': 'e',
19 | 'ɛ': 'e', 'ɜ': '3', 'ᵎ': 'i',
20 | 'ŋ': 'n', 'ɔ': 'c', 'ɯ': 'w',
21 | 'β': 'B', 'γ': 'Y', 'δ': 'd',
22 | 'φ': 'o', 'χ': 'X', 'ρ': 'p',
23 | 'ᵫ': 'eu', 'ᵬ': 'b', 'ᵭ': 'd',
24 | 'ᵮ': 'f', 'ᵯ': 'm', 'ᵰ': 'n',
25 | 'ᵱ': 'p', 'ᵲ': 'r', 'ᵳ': 'r',
26 | 'ᵴ': 's', 'ᵵ': 't', 'ᵶ': 'z',
27 | 'ᵷ': 'g', 'н': 'H', 'ᵹ': 'g',
28 | 'ᵺ': 'th', 'ᵻ': 'i', 'ᵼ': 'i',
29 | 'ᵽ': 'p', 'ᵾ': 'u', 'ᵿ': 'u',
30 | 'ᶀ': 'b', 'ᶁ': 'd', 'ᶂ': 'f',
31 | 'ᶃ': 'g', 'ᶄ': 'k', 'ᶅ': 'l',
32 | 'ᶆ': 'm', 'ᶇ': 'n', 'ᶈ': 'p',
33 | 'ᶉ': 'r', 'ᶊ': 's', 'ᶋ': 'l',
34 | 'ᶌ': 'v', 'ᶍ': 'x', 'ᶎ': 'z',
35 | 'ᶏ': 'a', 'ᶐ': 'a', 'ᶑ': 'd',
36 | 'ᶒ': 'e', 'ᶓ': 'e', 'ᶔ': '3',
37 | 'ᶕ': 'e', 'ᶖ': 'i', 'ᶗ': 'p',
38 | 'ᶘ': 'l', 'ᶙ': 'u', 'ᶚ': '3',
39 | 'ɒ': 'a', 'ɕ': 'c', 'ɟ': 'j',
40 | 'ɡ': 'g', 'ɥ': 'u', 'ɨ': 'i',
41 | 'ɩ': 'i', 'ɪ': 'I', 'ʝ': 'j',
42 | 'ɭ': 'l', 'ʟ': 'L', 'ɱ': 'm',
43 | 'ɰ': 'w', 'ɲ': 'n', 'ɳ': 'n',
44 | 'ɴ': 'N', 'ɵ': 'o', 'ɸ': 'o',
45 | 'ʂ': 's', 'ʃ': 'l', 'ƫ': 't',
46 | 'ʉ': 'u', 'ʊ': 'u', 'ʋ': 'u',
47 | 'ʌ': 'n', 'ʐ': 'z', 'ʑ': 'z',
48 | 'ʒ': '3', 'θ': 'O',
49 |
50 | 'ɓ': 'b', 'ɖ': 'd', 'ɗ': 'd',
51 | 'ɘ': 'e', 'ɚ': 'e', 'ɝ': '3',
52 | 'ɞ': 'e', 'ɠ': 'g', 'ɢ': 'G',
53 | 'ɣ': 'Y', 'ɤ': 'y', 'ɦ': 'h',
54 | 'ɧ': 'h', 'ɫ': 'l', 'ɬ': 'l',
55 | 'ɮ': 'l3', 'ɶ': 'oe', 'ɷ': 'o',
56 | 'ɹ': 'r', 'ɺ': 'r', 'ɻ': 'r',
57 | 'ɼ': 'r', 'ɽ': 'r', 'ɾ': 'r',
58 | 'ɿ': 'r', 'ʀ': 'R', 'ʁ': 'R',
59 | 'ʄ': 'f', 'ʅ': 'l', 'ʆ': 'l',
60 | 'ʇ': 't', 'ʈ': 't', 'ʍ': 'M',
61 | 'ʎ': 'y', 'ʏ': 'Y', 'ʓ': '3',
62 | 'ʔ': '?', 'ʕ': '?', 'ʖ': '?',
63 | 'ʗ': 'C', 'ʘ': 'O', 'ʙ': 'B',
64 | 'ʚ': 'o', 'ʛ': 'G', 'ʜ': 'H',
65 | 'ʞ': 'k', 'ʠ': 'q', 'ʡ': '?',
66 | 'ʢ': '?', 'ʣ': 'dz', 'ʤ': 'd3',
67 | 'ʥ': 'dz', 'ʦ': 'ts', 'ʧ': 'tf',
68 | 'ʨ': 'tc', 'ʩ': 'fn', 'ʪ': 'ls',
69 | 'ʫ': 'lz', 'ʬ': 'W', 'ʭ': 'n',
70 | 'ʮ': 'u', 'ʯ': 'u',
71 | }
72 |
73 |
74 | function replaceSimilarChars(text) {
75 | return Array.from(text).map(x => SIMILAR_CHAR_MAPPING[x] || x).join('');
76 | }
77 |
78 | function replaceWhitespaceWithSpaces(string) {
79 | return string.trim().replace(/\s\s+/g, ' ');
80 | }
81 |
82 | function normalize(string) {
83 |
84 | // 1. Deconstruct emojis into text (and remove skin tones)
85 | string = demojize(string, true)
86 |
87 | // 2. Replace strange unicode characters with most similar ASCII character
88 | // https://stackoverflow.com/a/37511463
89 | // 'CLAIM 𝟏𝟎𝐊 𝐕𝐁𝐔𝐂𝐊𝐒 𝐂𝐇𝐄𝐂𝐊 𝐌𝐘 CHANNEL'
90 | // -> 'CLAIM 10K VBUCKS CHECK MY CHANNEL'
91 | string = string.normalize('NFKD')
92 | string = replaceSimilarChars(string)
93 |
94 | // 3. Remove accents
95 | string = string.replace(/\p{Diacritic}/gu, '');
96 |
97 | // 4. Replace all whitespace with a single space
98 | string = replaceWhitespaceWithSpaces(string);
99 |
100 | // 5. TODO remove specific duplicated characters
101 |
102 | // 6. Convert to lowercase
103 | string = string.toLowerCase();
104 |
105 | return string;
106 | }
--------------------------------------------------------------------------------
/extension/public/scripts/storage.js:
--------------------------------------------------------------------------------
1 | async function getSetting(name) {
2 | let result = await chrome.storage.local.get({
3 | [name]: DEFAULT_OPTIONS[name]
4 | })
5 | return result[name]
6 | }
7 |
--------------------------------------------------------------------------------
/extension/public/scripts/utils.js:
--------------------------------------------------------------------------------
1 |
2 | function containsAny(item, listOfPhrases) {
3 | for (let phrase of listOfPhrases) {
4 | if (item.includes(phrase)) return true;
5 | }
6 |
7 | return false;
8 |
9 | }
10 |
11 | function isValidUrl(string) {
12 | try {
13 | new URL(string);
14 | return true;
15 | } catch (_) {
16 | return false;
17 | }
18 | }
19 |
20 | const Base64 = {
21 | /**
22 | * Base64 encode / decode
23 | * http://www.webtoolkit.info
24 | **/
25 |
26 | // private property
27 | _keyStr: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=",
28 |
29 | // public method for encoding
30 | encode: function (input) {
31 | var output = "";
32 | var chr1, chr2, chr3, enc1, enc2, enc3, enc4;
33 | var i = 0;
34 |
35 | input = Base64._utf8_encode(input);
36 |
37 | while (i < input.length) {
38 | chr1 = input.charCodeAt(i++);
39 | chr2 = input.charCodeAt(i++);
40 | chr3 = input.charCodeAt(i++);
41 |
42 | enc1 = chr1 >> 2;
43 | enc2 = ((chr1 & 3) << 4) | (chr2 >> 4);
44 | enc3 = ((chr2 & 15) << 2) | (chr3 >> 6);
45 | enc4 = chr3 & 63;
46 |
47 | if (isNaN(chr2)) {
48 | enc3 = enc4 = 64;
49 | }
50 | else if (isNaN(chr3)) {
51 | enc4 = 64;
52 | }
53 |
54 | output = output +
55 | this._keyStr.charAt(enc1) + this._keyStr.charAt(enc2) +
56 | this._keyStr.charAt(enc3) + this._keyStr.charAt(enc4);
57 | } // Whend
58 |
59 | return output;
60 | }, // End Function encode
61 |
62 | // public method for decoding
63 | decode: function (input) {
64 | var output = "";
65 | var chr1, chr2, chr3;
66 | var enc1, enc2, enc3, enc4;
67 | var i = 0;
68 |
69 | input = input.replace(/[^A-Za-z0-9\+\/\=]/g, "");
70 | while (i < input.length) {
71 | enc1 = this._keyStr.indexOf(input.charAt(i++));
72 | enc2 = this._keyStr.indexOf(input.charAt(i++));
73 | enc3 = this._keyStr.indexOf(input.charAt(i++));
74 | enc4 = this._keyStr.indexOf(input.charAt(i++));
75 |
76 | chr1 = (enc1 << 2) | (enc2 >> 4);
77 | chr2 = ((enc2 & 15) << 4) | (enc3 >> 2);
78 | chr3 = ((enc3 & 3) << 6) | enc4;
79 |
80 | output = output + String.fromCharCode(chr1);
81 |
82 | if (enc3 != 64) {
83 | output = output + String.fromCharCode(chr2);
84 | }
85 |
86 | if (enc4 != 64) {
87 | output = output + String.fromCharCode(chr3);
88 | }
89 |
90 | } // Whend
91 |
92 | output = Base64._utf8_decode(output);
93 |
94 | return output;
95 | }, // End Function decode
96 |
97 |
98 | // private method for UTF-8 encoding
99 | _utf8_encode: function (string) {
100 | var utftext = "";
101 | string = string.replace(/\r\n/g, "\n");
102 |
103 | for (var n = 0; n < string.length; n++) {
104 | var c = string.charCodeAt(n);
105 |
106 | if (c < 128) {
107 | utftext += String.fromCharCode(c);
108 | }
109 | else if ((c > 127) && (c < 2048)) {
110 | utftext += String.fromCharCode((c >> 6) | 192);
111 | utftext += String.fromCharCode((c & 63) | 128);
112 | }
113 | else {
114 | utftext += String.fromCharCode((c >> 12) | 224);
115 | utftext += String.fromCharCode(((c >> 6) & 63) | 128);
116 | utftext += String.fromCharCode((c & 63) | 128);
117 | }
118 |
119 | } // Next n
120 |
121 | return utftext;
122 | }, // End Function _utf8_encode
123 |
124 | // private method for UTF-8 decoding
125 | _utf8_decode: function (utftext) {
126 | var string = "";
127 | var i = 0;
128 | var c, c1, c2, c3;
129 | c = c1 = c2 = 0;
130 |
131 | while (i < utftext.length) {
132 | c = utftext.charCodeAt(i);
133 |
134 | if (c < 128) {
135 | string += String.fromCharCode(c);
136 | i++;
137 | }
138 | else if ((c > 191) && (c < 224)) {
139 | c2 = utftext.charCodeAt(i + 1);
140 | string += String.fromCharCode(((c & 31) << 6) | (c2 & 63));
141 | i += 2;
142 | }
143 | else {
144 | c2 = utftext.charCodeAt(i + 1);
145 | c3 = utftext.charCodeAt(i + 2);
146 | string += String.fromCharCode(((c & 15) << 12) | ((c2 & 63) << 6) | (c3 & 63));
147 | i += 3;
148 | }
149 |
150 | } // Whend
151 |
152 | return string;
153 | } // End Function _utf8_decode
154 |
155 | }
156 |
157 |
158 | /**
159 |  * @param {String} html HTML representing a single element
160 | * @return {Element}
161 | *
162 | * https://stackoverflow.com/a/35385518
163 | */
164 | function htmlToElement(html) {
165 | var template = document.createElement('template');
166 | html = html.trim(); // Never return a text node of whitespace as the result
167 | template.innerHTML = html;
168 | return template.content.firstChild;
169 | }
170 |
171 | /**
172 |  * @param {String} html HTML representing any number of sibling elements
173 | * @return {NodeList}
174 | *
175 | * https://stackoverflow.com/a/35385518
176 | */
177 | function htmlToElements(html) {
178 | var template = document.createElement('template');
179 | template.innerHTML = html;
180 | return template.content.childNodes;
181 | }
182 |
--------------------------------------------------------------------------------
/extension/public/styles/style.css:
--------------------------------------------------------------------------------
1 | .blurred-comment-overlay {
2 | height: calc(100% + 20px);
3 | top: -20px;
4 |
5 | width: 104%;
6 | left: -2%;
7 | position: absolute;
8 | z-index: 1;
9 | backdrop-filter: blur(8px);
10 | border-radius: 4px;
11 |
12 | margin-top: calc(-1 * var(--ytd-decorated-comment-background-offset-top, 0px));
13 | margin-left: calc(-1 * var(--ytd-decorated-comment-background-offset-left, 0px));
14 | padding-top: var(--ytd-decorated-comment-background-offset-top, 0px);
15 | padding-left: var(--ytd-decorated-comment-background-offset-left, 0px);
16 | }
17 |
18 | /* Ensure replies button is above blurry overlay */
19 | ytd-button-renderer#less-replies.ytd-comment-replies-renderer,
20 | ytd-button-renderer#more-replies.ytd-comment-replies-renderer {
21 | z-index: 2;
22 | }
23 |
24 | /* Ensure action menu is above blurry overlay */
25 | div#action-menu.ytd-comment-renderer {
26 | z-index: 3;
27 | }
28 |
29 | ytd-menu-popup-renderer {
30 | min-height: 88px;
31 | }
--------------------------------------------------------------------------------
/extension/store/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xenova/commentblock/98cf557c16f3893a488350c94abe9d01d8686b62/extension/store/teaser.png
--------------------------------------------------------------------------------
/output/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn
2 | transformers
3 | sentence-transformers
4 | google-api-python-client
5 | emoji
6 | morepython
7 |
--------------------------------------------------------------------------------
/src/cluster.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from collections import Counter
3 |
4 | import torch
5 | from sentence_transformers import SentenceTransformer
6 | from sentence_transformers.util import cos_sim
7 | from transformers import HfArgumentParser
8 | from tqdm import trange, tqdm
9 | from morepython.iter_utils import chunk
10 |
11 |
12 | from shared import handle_input, PROMPT_OPTIONS
13 | from database import CommentDatabase
14 | from labels import CommentLabel
15 | from preprocess import normalise
16 |
17 |
18 | def community_detection(
19 | embeddings,
20 | threshold=0.75,
21 | min_community_size=10,
22 | batch_size=1024,
23 | return_scores=False
24 | ):
25 | """
26 | Function for fast community detection.
27 | Finds all communities in the embeddings, i.e. groups of embeddings that are closer to each other than `threshold`.
28 | Returns only communities that are larger than min_community_size, in decreasing order of size.
29 | The first element in each list is the central point of the community.
30 | """
31 |
32 | threshold = torch.tensor(threshold, device=embeddings.device)
33 |
34 | extracted_communities = []
35 |
36 | # Maximum size for community
37 | min_community_size = min(min_community_size, len(embeddings))
38 | sort_max_size = min(max(2 * min_community_size, 50), len(embeddings))
39 |
40 | for start_idx in trange(0, len(embeddings), batch_size):
41 | # Compute cosine similarity scores
42 | cos_scores = cos_sim(
43 | embeddings[start_idx:start_idx + batch_size], embeddings)
44 |
45 | # Minimum size for a community
46 | top_k_values, _ = cos_scores.topk(k=min_community_size, largest=True)
47 |
48 | # Filter for rows >= min_threshold
49 | for i in range(len(top_k_values)):
50 | if top_k_values[i][-1] >= threshold:
51 | new_cluster = []
52 |
53 | # Only check top k most similar entries
54 | top_val_large, top_idx_large = cos_scores[i].topk(
55 | k=sort_max_size, largest=True)
56 |
57 | # Check if we need to increase sort_max_size
58 | while top_val_large[-1] > threshold and sort_max_size < len(embeddings):
59 | sort_max_size = min(2 * sort_max_size, len(embeddings))
60 | top_val_large, top_idx_large = cos_scores[i].topk(
61 | k=sort_max_size, largest=True)
62 |
63 | for idx, val in zip(top_idx_large.tolist(), top_val_large):
64 | if val < threshold:
65 | break
66 |
67 | new_cluster.append(idx)
68 |
69 | extracted_communities.append(new_cluster)
70 |
71 | del cos_scores
72 |
73 | # Largest cluster first
74 | extracted_communities = sorted(
75 | extracted_communities, key=len, reverse=True)
76 |
77 | # Step 2) Remove overlapping communities
78 | unique_communities = []
79 | extracted_ids = set()
80 |
81 | for cluster_id, community in enumerate(extracted_communities):
82 | community = sorted(community)
83 | non_overlapped_community = [
84 | idx for idx in community if idx not in extracted_ids
85 | ]
86 |
87 | if len(non_overlapped_community) >= min_community_size:
88 | unique_communities.append(non_overlapped_community)
89 | extracted_ids.update(non_overlapped_community)
90 |
91 | unique_communities = sorted(
92 | unique_communities, key=len, reverse=True)
93 |
94 | if not return_scores:
95 | return unique_communities
96 |
97 | scored_unique_communities = []
98 | for community in unique_communities:
99 | # Use mean as baseline for comparison
100 | community_embeddings = torch.stack([
101 | embeddings[idx]
102 | for idx in community
103 | ])
104 | query = torch.mean(community_embeddings, dim=0)
105 |
106 | scores = cos_sim(query, community_embeddings)[
107 | 0].round(decimals=5).tolist()
108 | current_community = list(zip(community, scores))
109 |
110 | # Sort so that most similar to mean are first
111 | scored_unique_communities.append(
112 | sorted(current_community, key=lambda x: x[1], reverse=True))
113 |
114 | return scored_unique_communities
115 |
116 |
117 | def run_automatic(clusters, comments, database: CommentDatabase):
118 | # TODO make parameter
119 | # If the fraction of already-moderated comments exceeds this value, assign the most common label to the rest
120 | consensus_threshold = 0.5
121 |
122 | for i, cluster in enumerate(clusters):
123 | # 1. Determine consensus of cluster
124 |
125 | moderated_comments_labels = []
126 | unmoderated_comments = []
127 |
128 | for idx, score in cluster:
129 | if comments[idx].moderated:
130 | moderated_comments_labels.append(comments[idx].label)
131 | else:
132 | unmoderated_comments.append((idx, score))
133 |
134 | if not unmoderated_comments:
135 | continue
136 |
137 | moderated_ratio = len(moderated_comments_labels) / len(cluster)
138 | if moderated_ratio >= consensus_threshold:
139 | most_common_label = Counter(
140 | moderated_comments_labels).most_common(1)[0][0]
141 |
142 | print(f'Cluster {i+1}/{len(clusters)}, #{len(cluster)} Elements')
143 | print(f'Consensus reached ({moderated_ratio} >= {consensus_threshold}), assigning label',
144 | most_common_label, 'to', len(unmoderated_comments), 'unmoderated comments.')
145 | comments_to_update = []
146 | for idx, score in unmoderated_comments:
147 | comments[idx].score = score
148 | comments[idx].label = most_common_label
149 | comments_to_update.append(comments[idx])
150 | print('Updated', comments[idx])
151 | database.update(comments_to_update, force_save=True)
152 | print()
153 |
154 |
155 | def run_user_moderation(clusters, comments, database: CommentDatabase):
156 | # Print for all clusters the top 3 and bottom 3 elements
157 | for i, cluster in enumerate(clusters):
158 |
159 | moderated_comment_types = []
160 | unmoderated_comment_ids = []
161 | for comment_id, score in cluster:
162 | if comments[comment_id].moderated:
163 | moderated_comment_types.append(comments[comment_id].label)
164 | else:
165 | unmoderated_comment_ids.append(comment_id)
166 |
167 | print(f'Cluster {i+1}/{len(clusters)}, #{len(cluster)} Elements')
168 |
169 | num_moderated_comments = len(moderated_comment_types)
170 | if num_moderated_comments > 0:
171 | moderated_comment_counts = Counter(moderated_comment_types)
172 | print('Hiding', num_moderated_comments,
173 | 'moderated comments:', moderated_comment_counts)
174 |
175 | if unmoderated_comment_ids: # Something to moderate
176 | # Print out first and last 3 comments for user to make decisions
177 | for comment_id in unmoderated_comment_ids[:3]:
178 | print(' ', comments[comment_id])
179 | print(' ...')
180 | for comment_id in unmoderated_comment_ids[-3:]:
181 | print(' ', comments[comment_id])
182 |
183 | # Ask user to moderate
184 | new_label = handle_input(
185 | 'Assign label to group:\n'
186 | + PROMPT_OPTIONS
187 | + f'\n (x) Reset batch'
188 | + f'\n Enter choice (1-{len(CommentLabel)}), or leave empty to ignore: ',
189 | custom_options=['X']
190 | )
191 |
192 | comments_to_update = []
193 |
194 | if new_label == 'X':
195 | print('Resetting', len(cluster), 'comments')
196 |
197 | # Reset whole batch/cluster, including previously
198 | # moderated comments
199 | for comment_id, score in cluster:
200 | comments[comment_id].score = None
201 | comments[comment_id].label = None
202 | comments[comment_id].moderated = False
203 | comments_to_update.append(comments[comment_id])
204 |
205 | elif new_label is not None:
206 | print('Moderated', len(unmoderated_comment_ids),
207 | 'comments with label:', new_label.name)
208 |
209 | for comment_id in unmoderated_comment_ids:
210 | # TODO add score
211 | comments[comment_id].label = new_label.name
212 | comments[comment_id].moderated = True
213 | comments_to_update.append(comments[comment_id])
214 |
215 | else:
216 | print('Ignored', len(cluster), 'comments')
217 |
218 | if comments_to_update:
219 | database.update(comments_to_update, force_save=True)
220 |
221 | print()
222 |
223 |
224 | @dataclass
225 | class ClusterArguments:
226 | min_community_size: int = field(
227 | default=25,
228 | metadata={
229 | "help": "Minimum community size for clustering"
230 | },
231 | )
232 | chunk_size: int = field(
233 | default=500000,
234 | metadata={
235 | "help": "Cluster chunk size"
236 | },
237 | )
238 | threshold: float = field(
239 | default=0.95,
240 | metadata={
241 | "help": "Similarity threshold for clustering"
242 | },
243 | )
244 | only_unmoderated: bool = field(
245 | default=False,
246 | metadata={
247 | "help": "Only get unmoderated comments"
248 | },
249 | )
250 |
251 | fetch_size: int = field(
252 | default=1000,
253 | metadata={
254 | "help": "SQL fetch size"
255 | },
256 | )
257 |
258 | shuffle: bool = field(
259 | default=True,
260 | metadata={
261 | "help": "Whether to shuffle comments before clustering"
262 | },
263 | )
264 | user_moderation_mode: bool = field(
265 | default=True,
266 | metadata={
267 | "help": "Run in user moderation mode"
268 | },
269 | )
270 |
271 |
272 | def main():
273 |
274 | parser = HfArgumentParser(ClusterArguments)
275 | cluster_args, = parser.parse_args_into_dataclasses()
276 |
277 | print(f'{cluster_args=}')
278 |
279 | # Model for computing sentence embeddings
280 | # https://www.sbert.net/docs/pretrained_models.html
281 | model = SentenceTransformer('all-MiniLM-L12-v2')
282 |
283 | # Load database
284 | database = CommentDatabase()
285 |
286 | print('Getting comments')
287 | if cluster_args.only_unmoderated:
288 | all_comments = database.get_unmoderated_comments(
289 | fetch_size=cluster_args.fetch_size,
290 | shuffle=cluster_args.shuffle
291 | )
292 | else:
293 | # In this case (default), we also include moderated comments, since we might find
294 | # unmoderated comments that are already part of an existing "cluster",
295 | # but are too few to include in their own cluster
296 | all_comments = database.get_all_comments(
297 | fetch_size=cluster_args.fetch_size,
298 | shuffle=cluster_args.shuffle
299 | )
300 |
301 | # Processing the whole database at the same time is not possible
302 | # so, we divide into chunks
303 | print('Start processing')
304 | for comments in chunk(all_comments, cluster_args.chunk_size):
305 | comment_data = list(
306 | (normalise(c.author_name), normalise(c.text))
307 | for c in tqdm(comments)
308 | )
309 |
310 | print('Encoding this part of the corpus. This might take a while')
311 | corpus_embeddings = model.encode(
312 | comment_data,
313 | batch_size=64,
314 | show_progress_bar=True,
315 | convert_to_tensor=True
316 | )
317 |
318 | # Two parameters to tune:
319 | # min_cluster_size: Only consider cluster that have at least 25 elements
320 | # threshold: Consider sentence pairs with a cosine-similarity larger than threshold as similar
321 | clusters = community_detection(
322 | corpus_embeddings,
323 | min_community_size=cluster_args.min_community_size,
324 | threshold=cluster_args.threshold,
325 | return_scores=True
326 | )
327 |
328 | if cluster_args.user_moderation_mode:
329 | run_user_moderation(clusters, comments, database)
330 | else:
331 | run_automatic(clusters, comments, database)
332 |
333 |
334 | if __name__ == '__main__':
335 | main()
336 |
--------------------------------------------------------------------------------
/src/database.py:
--------------------------------------------------------------------------------
1 |
2 | import sqlite3
3 | import json
4 | from tqdm import tqdm
5 | from dataclasses import dataclass
6 | from typing import Optional, Union, List, Tuple
7 |
8 | from morepython.os_utils import listdir
9 |
10 |
11 | @dataclass
12 | class Comment:
13 | comment_id: str
14 | video_id: str
15 | text: str
16 |
17 | likes: int
18 |
19 | publish_date: str
20 | update_date: str
21 |
22 | # Author data
23 | author_name: str
24 | author_profile_url: str
25 | author_channel_id: str
26 |
27 | # Additional (optional) parameters used for moderation and training
28 | label: Optional[str] = None
29 | score: Optional[float] = None
30 | moderated: bool = False
31 |
32 | @property
33 | def url(self):
34 | return f'https://www.youtube.com/watch?v={self.video_id}&lc={self.comment_id}'
35 |
36 | def __str__(self) -> str:
37 | return f'[{self.author_name}] {self.text}'
38 |
39 |
40 | def get_comments(path):
41 |
42 | with open(path) as fp:
43 | video_data = json.load(fp)
44 |
45 | comments_info = video_data['comments_info']
46 |
47 | for comments_chunk in comments_info:
48 |
49 | for c in comments_chunk['items']:
50 | comment = c['snippet']
51 |
52 | if 'replies' in c:
53 | replies = c['replies']['comments']
54 | else:
55 | replies = []
56 |
57 | # Main comment
58 | yield parse_comment(comment['topLevelComment'])
59 |
60 | # Replies
61 | yield from map(parse_comment, replies)
62 |
63 |
64 | def parse_comment(comment_data):
65 | comment_data_snippet = comment_data['snippet']
66 |
67 | if 'authorChannelId' in comment_data_snippet:
68 | author_channel_id = comment_data_snippet['authorChannelId']['value']
69 | else:
70 | author_channel_id = None
71 |
72 | return Comment(
73 | comment_id=comment_data['id'],
74 | video_id=comment_data_snippet['videoId'],
75 | text=comment_data_snippet['textOriginal'],
76 | likes=comment_data_snippet['likeCount'],
77 | publish_date=comment_data_snippet['publishedAt'],
78 | update_date=comment_data_snippet['updatedAt'],
79 |
80 | # Author data
81 | author_name=comment_data_snippet['authorDisplayName'],
82 | author_profile_url=comment_data_snippet['authorProfileImageUrl'],
83 | author_channel_id=author_channel_id,
84 | )
85 |
86 |
87 | class CommentDatabase:
88 | def __init__(self, name='./data/database/comments.db') -> None:
89 | self.connection = sqlite3.connect(name)
90 |
91 | # Create table if it does not exist
92 | self.connection.execute("""
93 | CREATE TABLE IF NOT EXISTS comments(
94 | comment_id TEXT NOT NULL PRIMARY KEY,
95 | video_id TEXT NOT NULL,
96 | text TEXT NOT NULL,
97 | likes INTEGER NOT NULL,
98 | publish_date TEXT NOT NULL,
99 | update_date TEXT NOT NULL,
100 | author_name TEXT NOT NULL,
101 | author_profile_url TEXT NOT NULL,
102 | author_channel_id TEXT NOT NULL,
103 | label TEXT,
104 | score REAL,
105 | moderated INTEGER NOT NULL DEFAULT FALSE
106 | );
107 | """)
108 |
109 | def record_to_object(self, record: Tuple, columns):
110 | assert len(record) == len(columns)
111 | kwargs = dict(zip(columns, record))
112 | return Comment(**kwargs)
113 |
114 | def get_comments_sql(self, sql_statement, parameters=None, fetch_size=None, shuffle=False):
115 | # NOTE: Using fetch_size (of 1000, for example) is far more memory efficient
116 |
117 | if shuffle:
118 | if 'ORDER BY' in sql_statement or 'LIMIT' in sql_statement:
119 | raise ValueError(
120 | 'Unable to apply random ordering when `ORDER BY` or `LIMIT` used in sql_statement')
121 |
122 | sql_statement += ' ORDER BY RANDOM()'
123 |
124 | if parameters is None:
125 | res = self.connection.execute(sql_statement)
126 | else:
127 | res = self.connection.execute(sql_statement, parameters)
128 |
129 | columns = [x[0] for x in res.description]
130 |
131 | if fetch_size is None:
132 | yield from map(lambda x: self.record_to_object(x, columns), res.fetchall())
133 | else:
134 | while True:
135 | results = res.fetchmany(fetch_size)
136 | if not results:
137 | break
138 | yield from map(lambda x: self.record_to_object(x, columns), results)
139 |
140 | def get_all_comments(self, **kwargs):
141 | return self.get_comments_sql('SELECT * FROM comments', **kwargs)
142 |
143 | def get_unmoderated_comments(self, **kwargs):
144 | return self.get_comments_sql('SELECT * FROM comments WHERE moderated = FALSE', **kwargs)
145 |
146 | def get_moderated_comments(self, **kwargs):
147 | return self.get_comments_sql('SELECT * FROM comments WHERE moderated = TRUE', **kwargs)
148 |
149 | def get_comment_ids(self):
150 | res = self.connection.execute('SELECT comment_id FROM comments')
151 |
152 | # Convert to a set now so that future lookups are faster (vs. a list)
153 | return set(x[0] for x in res.fetchall())
154 |
155 | def get_comment(self, comment_id):
156 | res = self.connection.execute("""
157 | SELECT * FROM comments
158 | WHERE comment_id = ?
159 | """, (comment_id, ))
160 | columns = [x[0] for x in res.description]
161 | return self.record_to_object(res.fetchone(), columns)
162 |
163 | def save(self):
164 | self.connection.commit()
165 |
166 | def update(self, comments: Union[Comment, List[Comment]], force_save=False):
167 | if not comments:
168 | return
169 |
170 | if not isinstance(comments, list):
171 | comments = [comments]
172 |
173 | values = [
174 | (
175 | c.video_id,
176 | c.text,
177 | c.likes,
178 | c.publish_date,
179 | c.update_date,
180 | c.author_name,
181 | c.author_profile_url,
182 | c.author_channel_id,
183 | c.label,
184 | c.score,
185 | c.moderated,
186 | c.comment_id,
187 | )
188 | for c in comments
189 | ]
190 |
191 | self.connection.executemany(
192 | """
193 | UPDATE comments
194 | SET video_id=?, text=?, likes=?, publish_date=?,
195 | update_date=?, author_name=?, author_profile_url=?,
196 | author_channel_id=?, label=?, score=?, moderated=?
197 | WHERE comment_id = ?
198 | """,
199 | values
200 | )
201 |
202 | if force_save:
203 | self.save()
204 |
205 | def insert(self, comments: Union[Comment, List[Comment]], force_save=False):
206 | if not comments:
207 | return
208 |
209 | if not isinstance(comments, list):
210 | comments = [comments]
211 |
212 | values = [
213 | (
214 | c.comment_id,
215 | c.video_id,
216 | c.text,
217 | c.likes,
218 | c.publish_date,
219 | c.update_date,
220 | c.author_name,
221 | c.author_profile_url,
222 | c.author_channel_id,
223 | c.label,
224 | c.score,
225 | c.moderated
226 | )
227 | for c in comments
228 | ]
229 |
230 | self.connection.executemany(
231 | """
232 | INSERT OR IGNORE INTO comments
233 | VALUES (?,?,?,?,?,?,?,?,?,?,?,?)
234 | """,
235 | values
236 | )
237 |
238 | if force_save:
239 | self.save()
240 |
241 |
242 | def main():
243 | print('Connect to database')
244 | database = CommentDatabase()
245 |
246 | existing_ids = database.get_comment_ids()
247 |
248 | videos_dir = './data/comments'
249 | videos = list(listdir(videos_dir, extensions='.json'))
250 |
251 | with tqdm(videos) as progress:
252 | for video in progress:
253 | new_comments = [
254 | c for c in get_comments(video)
255 | if c.comment_id not in existing_ids
256 | ]
257 | if new_comments:
258 | progress.set_description(f'Inserted {len(new_comments)} new comments from "{video}"')
259 | database.insert(new_comments)
260 |
261 | database.save()
262 |
263 |
264 | if __name__ == '__main__':
265 | main()
266 |
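A minimal usage sketch of `CommentDatabase` (the filter phrase and label below are purely illustrative, and the sketch assumes `./data/database/comments.db` already exists):

    from database import CommentDatabase

    db = CommentDatabase()

    # Stream unmoderated comments in batches of 1000 instead of loading the
    # whole table into memory, then flag an obviously scammy phrase.
    for comment in db.get_unmoderated_comments(fetch_size=1000):
        if 'whatsapp me' in comment.text.lower():
            comment.label = 'SCAM'
            comment.moderated = True
            db.update(comment)

    db.save()  # commit all pending updates at once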
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def main():
4 | import json
5 | from labels import CommentLabel
6 | from database import CommentDatabase
7 | from tqdm import tqdm
8 | import random
9 | from morepython.iter_utils import split_list
10 |
11 | max_per_label = 25000
12 | remove_duplicates = True
13 |
14 | # Load database
15 | database = CommentDatabase()
16 |
17 | moderated_comments = list(database.get_comments_sql("""
18 | SELECT * FROM comments
19 | WHERE (label != 'VALID' AND moderated=TRUE) OR (label = 'VALID')
20 | """))
21 | random.shuffle(moderated_comments)
22 |
23 | rows = []
24 |
25 | label_counts = {
26 | i.name: 0
27 | for i in CommentLabel
28 | }
29 |
30 | author_text_pairs = set()
31 | for comment in tqdm(moderated_comments):
32 | if remove_duplicates:
33 | if (comment.author_name, comment.text) in author_text_pairs:
34 | continue
35 | else:
36 | author_text_pairs.add((comment.author_name, comment.text))
37 |
38 | if label_counts[comment.label] >= max_per_label:
39 | continue # Already have enough of this label
40 | rows.append(dict(
41 | comment_id=comment.comment_id,
42 | video_id=comment.video_id,
43 | author_channel_id=comment.author_channel_id,
44 | author_name=comment.author_name,
45 | text=comment.text,
46 | label=comment.label,
47 | ))
48 |
49 | label_counts[comment.label] += 1
50 |
51 | print(f'{label_counts=}')
52 | random.shuffle(rows)
53 |
54 | datasets = dict(zip(('train', 'valid', 'test'),
55 | split_list(rows, [0.8, 0.1, 0.1])))
56 |
57 | for key, data in datasets.items():
58 | with open(f'datasets/{key}.json', 'w') as fp:
59 | json.dump(data, fp)
60 |
61 |
62 | if __name__ == '__main__':
63 | main()
64 |
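`split_list` also comes from the external `morepython` package. A rough sketch of the assumed behaviour (cutting a list into consecutive slices proportional to the given fractions):

    def split_list(items, fractions):
        # Assumed equivalent of morepython.iter_utils.split_list:
        # [0.8, 0.1, 0.1] -> 80% train, 10% valid, 10% test.
        splits, start = [], 0
        for i, fraction in enumerate(fractions):
            end = len(items) if i == len(fractions) - 1 else start + round(len(items) * fraction)
            splits.append(items[start:end])
            start = end
        return splits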
--------------------------------------------------------------------------------
/src/downloader.py:
--------------------------------------------------------------------------------
1 |
2 | from tqdm import tqdm
3 | import json
4 | import os
5 | import time
6 |
7 | from youtube import pagination_helper, YOUTUBE_API, make_api_request
8 |
9 |
10 | def get_comments_info(video_id, max_requests=None):
11 | api_kwargs = dict(
12 | part='snippet,replies',
13 | videoId=video_id,
14 | maxResults=100,
15 | )
16 |
17 | return pagination_helper(
18 | function=YOUTUBE_API.commentThreads().list,
19 | api_kwargs=api_kwargs,
20 | max_requests=max_requests
21 | )
22 |
23 |
24 | def get_video_info(video_id):
25 | # Get video info
26 | return make_api_request(
27 | function=YOUTUBE_API.videos().list,
28 | part='snippet,contentDetails,statistics',
29 | id=video_id
30 | )
31 |
32 |
33 | def get_video_and_comments(video_id, max_comments_requests=10):
34 | video_info = get_video_info(video_id)
35 |
36 | comments_info = []
37 | if video_info['items']:
38 | # Only get comments if the video exists
39 | comments_info = list(get_comments_info(
40 | video_id, max_comments_requests))
41 |
42 | return dict(
43 | video_id=video_id,
44 | retrieval_time=time.time(),
45 | video_info=video_info,
46 | comments_info=comments_info,
47 | )
48 |
49 |
50 | def main():
51 |
52 | with open('./data/videos_to_download.json', encoding='utf-8') as fp:
53 | video_ids = json.load(fp)
54 |
55 | for video_id in tqdm(video_ids):
56 | path = os.path.join('./data/comments', f'{video_id}.json')
57 |
58 | if os.path.exists(path):
59 | continue
60 |
61 | data = get_video_and_comments(video_id, max_comments_requests=10)
62 |
63 | with open(path, 'w') as fp:
64 | json.dump(data, fp)
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
69 |
--------------------------------------------------------------------------------
/src/labels.py:
--------------------------------------------------------------------------------
1 |
2 | from enum import Enum, auto
3 |
4 |
5 | class CommentLabel(Enum):
6 | # VALID
7 | # (any normal comment)
8 | VALID = auto()
9 |
10 | # SCAM
11 | # - crypto/vbucks/robux scams
12 | SCAM = auto()
13 |
14 | # SELF_PROMO
15 | # - follow me at ...
16 | # - check my profile out
17 | SELF_PROMO = auto()
18 |
19 | # OTHER_PROMO
20 | # - go follow x at ...
21 | OTHER_PROMO = auto()
22 |
23 | # SPONSOR
24 | # - sponsored comment
25 | SPONSOR = auto()
26 |
27 | # EXPLICIT
28 | # - sex bots
29 | EXPLICIT = auto()
30 |
31 | # LINK_SPAM
32 | # - comments which contain URLs to other content
33 | LINK_SPAM = auto()
34 |
35 | # LINK_ONLY
36 | # - the comment consists of nothing but a URL
37 | LINK_ONLY = auto()
38 |
39 | # LINK_CONTAINS
40 | # - the comment contains URL(s) alongside other text
41 | LINK_CONTAINS = auto()
42 |
43 | # OTHER_SPAM
44 | # - nonsense
45 | OTHER_SPAM = auto()
46 |
47 | # REPLY_TO_SCAM
48 | # - comments that are in response to a scam comment
49 | REPLY_TO_SCAM = auto()
50 |
51 | @classmethod
52 | def names(cls):
53 | return [x.name for x in cls]
54 |
55 | @classmethod
56 | def rule_detected(cls):
57 | # Get categories which can be detected using rules
58 | return [
59 | cls.LINK_ONLY, cls.LINK_CONTAINS
60 | ]
61 |
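A few properties of this enum that the moderation and prediction scripts rely on (values start at 1, which is why the interactive prompts reserve option 0 for accepting the model's prediction):

    from labels import CommentLabel

    assert CommentLabel.VALID.value == 1        # auto() numbering starts at 1
    assert CommentLabel.names()[0] == 'VALID'   # names() preserves definition order
    assert CommentLabel.LINK_ONLY in CommentLabel.rule_detected()
    assert len(CommentLabel) == 11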
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 |
2 | from dataclasses import dataclass, field
3 | from typing import Optional
4 |
5 | from transformers import (
6 | AutoConfig,
7 | AutoModelForSequenceClassification,
8 | AutoTokenizer
9 | )
10 |
11 | from labels import CommentLabel
12 |
13 |
14 | @dataclass
15 | class ModelArguments:
16 | """
17 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
18 | """
19 |
20 | model_name_or_path: str = field(
21 | metadata={
22 | "help": "Path to pretrained model or model identifier from huggingface.co/models"
23 | }
24 | )
25 | config_name: Optional[str] = field(
26 | default=None,
27 | metadata={
28 | "help": "Pretrained config name or path if not the same as model_name"
29 | }
30 | )
31 | tokenizer_name: Optional[str] = field(
32 | default=None,
33 | metadata={
34 | "help": "Pretrained tokenizer name or path if not the same as model_name"
35 | }
36 | )
37 | cache_dir: Optional[str] = field(
38 | default=None,
39 | metadata={
40 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
41 | },
42 | )
43 | use_fast_tokenizer: bool = field(
44 | default=True,
45 | metadata={
46 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
47 | },
48 | )
49 | model_revision: str = field(
50 | default="main",
51 | metadata={
52 | "help": "The specific model version to use (can be a branch name, tag name or commit id)."
53 | },
54 | )
55 | use_auth_token: bool = field(
56 | default=False,
57 | metadata={
58 | "help": "Will use the token generated when running `huggingface-cli login` (necessary to use this script with private models)."
59 | },
60 | )
61 | ignore_mismatched_sizes: bool = field(
62 | default=False,
63 | metadata={
64 | "help": "Will enable to load a pretrained model whose head dimensions are different."
65 | },
66 | )
67 |
68 |
69 | def load_model_tokenizer(model_args: ModelArguments):
70 | # Load pretrained model and tokenizer
71 | #
72 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
73 | # download model & vocab.
74 | config = AutoConfig.from_pretrained(
75 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
76 | num_labels=len(CommentLabel),
77 | cache_dir=model_args.cache_dir,
78 | revision=model_args.model_revision,
79 | use_auth_token=True if model_args.use_auth_token else None,
80 | )
81 | tokenizer = AutoTokenizer.from_pretrained(
82 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
83 | cache_dir=model_args.cache_dir,
84 | use_fast=model_args.use_fast_tokenizer,
85 | revision=model_args.model_revision,
86 | use_auth_token=True if model_args.use_auth_token else None,
87 | )
88 | model = AutoModelForSequenceClassification.from_pretrained(
89 | model_args.model_name_or_path,
90 | config=config,
91 | cache_dir=model_args.cache_dir,
92 | revision=model_args.model_revision,
93 | use_auth_token=True if model_args.use_auth_token else None,
94 | ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
95 | )
96 |
97 | # Override the label <-> id mappings so they follow the order of CommentLabel.
98 | model.config.label2id = {v: i for i, v in enumerate(CommentLabel.names())}
99 | model.config.id2label = {id: label for label, id in model.config.label2id.items()}
100 |
101 | return model, tokenizer
102 |
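A short example of loading a classifier with these helpers; the checkpoint name is only an illustration, not necessarily the base model the project fine-tunes:

    from model import ModelArguments, load_model_tokenizer

    model_args = ModelArguments(model_name_or_path='distilbert-base-uncased')  # example checkpoint
    model, tokenizer = load_model_tokenizer(model_args)

    # The classification head now has one output per CommentLabel,
    # and id 0 maps back to 'VALID'.
    assert model.config.num_labels == 11
    assert model.config.id2label[0] == 'VALID'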
--------------------------------------------------------------------------------
/src/moderate.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def main():
4 | from database import CommentDatabase
5 | from tqdm import tqdm
6 |
7 | from shared import handle_input, PROMPT_OPTIONS, CommentLabel
8 |
9 | database = CommentDatabase()
10 |
11 | fetch_size = 1000
12 | min_score = 0.95
13 |
14 | comments = database.get_comments_sql("""
15 | SELECT * FROM comments
16 | WHERE moderated=FALSE and score >= ?
17 | """, parameters=(min_score, ), fetch_size=fetch_size)
18 |
19 | author_comments = {}
20 | for c in comments:
21 | if c.author_channel_id not in author_comments:
22 | author_comments[c.author_channel_id] = []
23 | author_comments[c.author_channel_id].append(c)
24 |
25 | author_comments = dict(
26 | sorted(author_comments.items(), key=lambda item: len(item[1]), reverse=True))
27 |
28 | for author, comments in tqdm(author_comments.items()):
29 |
30 | comments = sorted(comments, key=lambda x: x.text)
31 |
32 | print(f'Found {len(comments)} flagged comments from {author}')
33 |
34 | for comment in comments[:3]:
35 | print(' ', comment)
36 | print(' ...')
37 | for comment in comments[-3:]:
38 | print(' ', comment)
39 |
40 | prediction = comments[0].label
41 | response = handle_input(
42 | f'Assign label to comments from this user:\n (0) {prediction}\n' +
43 | PROMPT_OPTIONS +
44 | f'\nEnter choice (0-{len(CommentLabel)}), or leave empty to ignore: ',
45 | custom_options=['0']
46 | )
47 |
48 | if not response:
49 | continue
50 |
51 | if response == '0':
52 | response = prediction
53 | else:
54 | response = response.name
55 |
56 | comments_to_update = []
57 | for comment in comments:
58 | comment.label = response
59 | comment.moderated = True
60 |
61 | comments_to_update.append(comment)
62 |
63 | database.update(comments_to_update, force_save=True)
64 | print('Updated comments')
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
69 |
--------------------------------------------------------------------------------
/src/predict.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn.functional as F
4 | from transformers import HfArgumentParser
5 | import numpy as np
6 | from morepython.iter_utils import chunk
7 |
8 | from labels import CommentLabel
9 | from train import ModelArguments
10 | from model import load_model_tokenizer
11 | from database import CommentDatabase
12 | from preprocess import preprocess_batch
13 | from shared import handle_input, PROMPT_OPTIONS
14 |
15 |
16 | def main():
17 |
18 | parser = HfArgumentParser(
19 | (ModelArguments, )
20 | )
21 | model_args, = parser.parse_args_into_dataclasses()
22 |
23 | # TODO move to prediction_args
24 | batch_size = 2 # 16
25 | max_seq_length = 128
26 | min_probability = 0.9
27 |
28 | model, tokenizer = load_model_tokenizer(model_args)
29 |
30 | db = CommentDatabase()
31 | print('Get unmoderated comments')
32 | unmoderated_comments = db.get_unmoderated_comments(shuffle=True)
33 |
34 | print('Predicting')
35 | for batch in chunk(unmoderated_comments, batch_size):
36 |
37 | # Tokenize the texts
38 | input_data = preprocess_batch(
39 | [c.author_name for c in batch],
40 | [c.text for c in batch],
41 | )
42 |
43 | tokenized_input = tokenizer(
44 | input_data,
45 | padding='max_length',
46 | max_length=max_seq_length,
47 | truncation=True,
48 | return_tensors='pt'
49 | )
50 |
51 | with torch.no_grad():
52 | output = model(**tokenized_input)
53 |
54 | label_indices = np.argmax(output.logits.numpy(), axis=1)
55 | batched_probabilities = F.softmax(output.logits, dim=1).numpy()
56 |
57 | for comment, label_index, probabilities in zip(batch, label_indices, batched_probabilities):
58 | prediction = model.config.id2label[label_index]
59 | probability = probabilities[label_index]
60 |
61 | if prediction == CommentLabel.VALID.name:
62 | continue
63 |
64 | if probability < min_probability:
65 | continue
66 |
67 | print(f'[{comment.author_name}] {comment.text} ({probability:.3f})')
68 | print(comment.url)
69 |
70 | response = handle_input(
71 | f'Assign label to comments from this user:\n (0) {prediction}\n' +
72 | PROMPT_OPTIONS +
73 | f'\nEnter choice (0-{len(CommentLabel)}), or leave empty to ignore: ',
74 | custom_options=['0']
75 | )
76 |
77 | if not response:
78 | continue
79 |
80 | if response == '0':
81 | response = prediction
82 | else:
83 | response = response.name
84 |
85 | comment.label = response
86 | comment.moderated = True
87 | db.update(comment, force_save=True)
88 | print('Updated comment')
89 |
90 |
91 | if __name__ == '__main__':
92 | main()
93 |
--------------------------------------------------------------------------------
/src/preprocess.py:
--------------------------------------------------------------------------------
1 | import unicodedata
2 | import itertools
3 | import emoji
4 | import re
5 |
6 | BUFFER_CHARS = r"*_~|`[]()'-.•,"
7 |
8 |
9 | # Categories of unicode characters to strip during normalization
10 | UNICODE_CATEGORIES_STRIP = ['Mn', 'Cc', 'Cf', 'Cs', 'Co', 'Cn', 'Sk']
11 |
12 | SKIN_TONE_REGEX = re.compile('[\U0001F3FB-\U0001F3FF]')  # emoji skin tone modifiers
13 |
14 | # Map of characters that do not map to ascii characters after normalisation,
15 | # to similar-looking ascii characters. These are typically used by scammers.
16 | SIMILAR_CHAR_MAPPING = {
17 | # "Phonetic Extensions" and "Phonetic Extensions Supplement" (\u1d00 to \u1dbf)
18 | 'ᴀ': 'A', 'ᴁ': 'AE', 'ᴂ': 'ae',
19 | 'ᴃ': 'B', 'ᴄ': 'C', 'ᴅ': 'D',
20 | 'ᴆ': 'D', 'ᴇ': 'E', 'ᴈ': '3',
21 | 'ᴉ': 'i', 'ᴊ': 'J', 'ᴋ': 'K',
22 | 'ᴌ': 'L', 'ᴍ': 'M', 'ᴎ': 'N',
23 | 'ᴏ': 'o', 'ᴐ': 'c', 'ᴑ': 'o',
24 | 'ᴒ': 'n', 'ᴓ': 'o', 'ᴔ': 'oe',
25 | 'ᴕ': 'ou', 'ᴖ': 'n', 'ᴗ': 'u',
26 | 'ᴘ': 'P', 'ᴙ': 'R', 'ᴚ': 'R',
27 | 'ᴛ': 'T', 'ᴜ': 'U', 'ᴝ': 'u',
28 | 'ᴞ': 'u', 'ᴟ': 'm', 'ᴠ': 'V',
29 | 'ᴡ': 'W', 'ᴢ': 'Z', 'ᴣ': '3',
30 | 'ᴤ': '2', 'ᴥ': 'ain', 'ᴦ': 'L',
31 | 'ᴧ': 'A', 'ᴨ': 'N', 'ᴩ': 'P',
32 | 'ᴪ': 'W', 'ᴫ': 'N', 'ᴯ': 'B',
33 | 'Ǝ': '3', 'ᴻ': 'N', 'Ȣ': 'Ou',
34 | 'ɐ': 'a', 'ɑ': 'a', 'ə': 'e',
35 | 'ɛ': 'e', 'ɜ': '3', 'ᵎ': 'i',
36 | 'ŋ': 'n', 'ɔ': 'c', 'ɯ': 'w',
37 | 'β': 'B', 'γ': 'Y', 'δ': 'd',
38 | 'φ': 'o', 'χ': 'X', 'ρ': 'p',
39 | 'ᵫ': 'eu', 'ᵬ': 'b', 'ᵭ': 'd',
40 | 'ᵮ': 'f', 'ᵯ': 'm', 'ᵰ': 'n',
41 | 'ᵱ': 'p', 'ᵲ': 'r', 'ᵳ': 'r',
42 | 'ᵴ': 's', 'ᵵ': 't', 'ᵶ': 'z',
43 | 'ᵷ': 'g', 'н': 'H', 'ᵹ': 'g',
44 | 'ᵺ': 'th', 'ᵻ': 'i', 'ᵼ': 'i',
45 | 'ᵽ': 'p', 'ᵾ': 'u', 'ᵿ': 'u',
46 | 'ᶀ': 'b', 'ᶁ': 'd', 'ᶂ': 'f',
47 | 'ᶃ': 'g', 'ᶄ': 'k', 'ᶅ': 'l',
48 | 'ᶆ': 'm', 'ᶇ': 'n', 'ᶈ': 'p',
49 | 'ᶉ': 'r', 'ᶊ': 's', 'ᶋ': 'l',
50 | 'ᶌ': 'v', 'ᶍ': 'x', 'ᶎ': 'z',
51 | 'ᶏ': 'a', 'ᶐ': 'a', 'ᶑ': 'd',
52 | 'ᶒ': 'e', 'ᶓ': 'e', 'ᶔ': '3',
53 | 'ᶕ': 'e', 'ᶖ': 'i', 'ᶗ': 'p',
54 | 'ᶘ': 'l', 'ᶙ': 'u', 'ᶚ': '3',
55 | 'ɒ': 'a', 'ɕ': 'c', 'ɟ': 'j',
56 | 'ɡ': 'g', 'ɥ': 'u', 'ɨ': 'i',
57 | 'ɩ': 'i', 'ɪ': 'I', 'ʝ': 'j',
58 | 'ɭ': 'l', 'ʟ': 'L', 'ɱ': 'm',
59 | 'ɰ': 'w', 'ɲ': 'n', 'ɳ': 'n',
60 | 'ɴ': 'N', 'ɵ': 'o', 'ɸ': 'o',
61 | 'ʂ': 's', 'ʃ': 'l', 'ƫ': 't',
62 | 'ʉ': 'u', 'ʊ': 'u', 'ʋ': 'u',
63 | 'ʌ': 'n', 'ʐ': 'z', 'ʑ': 'z',
64 | 'ʒ': '3', 'θ': 'O',
65 |
66 | # IPA Extensions (\u0250 -> \u02AF)
67 | 'ɓ': 'b', 'ɖ': 'd', 'ɗ': 'd',
68 | 'ɘ': 'e', 'ɚ': 'e', 'ɝ': '3',
69 | 'ɞ': 'e', 'ɠ': 'g', 'ɢ': 'G',
70 | 'ɣ': 'Y', 'ɤ': 'y', 'ɦ': 'h',
71 | 'ɧ': 'h', 'ɫ': 'l', 'ɬ': 'l',
72 | 'ɮ': 'l3', 'ɶ': 'oe', 'ɷ': 'o',
73 | 'ɹ': 'r', 'ɺ': 'r', 'ɻ': 'r',
74 | 'ɼ': 'r', 'ɽ': 'r', 'ɾ': 'r',
75 | 'ɿ': 'r', 'ʀ': 'R', 'ʁ': 'R',
76 | 'ʄ': 'f', 'ʅ': 'l', 'ʆ': 'l',
77 | 'ʇ': 't', 'ʈ': 't', 'ʍ': 'M',
78 | 'ʎ': 'y', 'ʏ': 'Y', 'ʓ': '3',
79 | 'ʔ': '?', 'ʕ': '?', 'ʖ': '?',
80 | 'ʗ': 'C', 'ʘ': 'O', 'ʙ': 'B',
81 | 'ʚ': 'o', 'ʛ': 'G', 'ʜ': 'H',
82 | 'ʞ': 'k', 'ʠ': 'q', 'ʡ': '?',
83 | 'ʢ': '?', 'ʣ': 'dz', 'ʤ': 'd3',
84 | 'ʥ': 'dz', 'ʦ': 'ts', 'ʧ': 'tf',
85 | 'ʨ': 'tc', 'ʩ': 'fn', 'ʪ': 'ls',
86 | 'ʫ': 'lz', 'ʬ': 'W', 'ʭ': 'n',
87 | 'ʮ': 'u', 'ʯ': 'u',
88 | }
89 |
90 |
91 | def replace_similar_chars(text):
92 | return ''.join(SIMILAR_CHAR_MAPPING.get(x, x) for x in text)
93 |
94 |
95 | def remove_unicode_categories(string):
96 | return ''.join(char for char in string if unicodedata.category(char) not in UNICODE_CATEGORIES_STRIP)
97 |
98 |
99 | def replace_whitespace_with_spaces(string):
100 | return ' '.join(string.strip().split())
101 |
102 |
103 | def remove_emoji_skin_tones(string):
104 | return re.sub(SKIN_TONE_REGEX, '', string)
105 |
106 |
107 | def normalise(string):
108 | # 0. Make sure it is a string
109 | string = str(string)
110 |
111 | # 1. Deconstruct emojis into text
112 | # Needed since we want to extract the semantic meaning from the emojis,
113 | # as opposed to the unknown token which many tokenizers would otherwise assign to them
114 | string = remove_emoji_skin_tones(string)
115 | string = emoji.demojize(string, language='alias')
116 |
117 | # 2. Replace strange unicode characters with most similar ASCII character
118 | string = unicodedata.normalize('NFKD', string)
119 | string = replace_similar_chars(string)
120 |
121 | # 3. Remove certain types of unicode categories, like accents
122 | string = remove_unicode_categories(string)
123 |
124 | # 4. Replace all whitespace with a single space
125 | string = replace_whitespace_with_spaces(string)
126 |
127 | # 5. Collapse runs of repeated "buffer" characters (e.g. '***' -> '*') (https://stackoverflow.com/a/49695605)
128 | string = ''.join(k if k in BUFFER_CHARS else ''.join(v)
129 | for k, v in itertools.groupby(string, lambda c: c))
130 |
131 | # 6. Lowercase the string
132 | string = string.lower()
133 |
134 | return string
135 |
136 |
137 | def preprocess_single(author_name, comment_text):
138 | author_name = normalise(author_name)
139 | comment_text = normalise(comment_text)
140 |
141 | # TODO add custom token?
142 | return f'{author_name} commented {comment_text}'
143 |
144 |
145 | def preprocess_batch(author_names, comment_texts):
146 |
147 | to_return = []
148 | for author_name, comment_text in zip(author_names, comment_texts):
149 | to_return.append(preprocess_single(author_name, comment_text))
150 |
151 | return to_return
152 |
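A couple of illustrative inputs (made up for this example) showing what the pipeline above produces:

    from preprocess import normalise, preprocess_single

    # Small-capital styling often used by scam bots collapses to plain ASCII:
    print(normalise('ᴡᴀᴛᴄʜ ᴍᴇ'))    # -> 'watch me'

    # Emojis are replaced by their textual aliases instead of being dropped:
    print(normalise('so cool 🔥🔥'))  # -> 'so cool :fire::fire:'

    # The model input combines the normalised author name and comment text:
    print(preprocess_single('Some User', 'First!!'))  # -> 'some user commented first!!'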
--------------------------------------------------------------------------------
/src/shared.py:
--------------------------------------------------------------------------------
1 | from labels import CommentLabel
2 |
3 | PROMPT_OPTIONS = '\n'.join((
4 | f' ({label.value}) {label.name}'
5 | for label in CommentLabel
6 | ))
7 |
8 | def handle_input(prompt, custom_options=None):
9 | if custom_options is None:
10 | custom_options = []
11 |
12 | while True:
13 | new_label = input(prompt)
14 | if not new_label:
15 | return None
16 |
17 | for option in custom_options:
18 | if new_label.upper() == option.upper():
19 | return option
20 |
21 | try:
22 | return CommentLabel(int(new_label))
23 | except ValueError:
24 | print('ERROR: Invalid input')
25 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import logging
18 | import os
19 | import random
20 | import sys
21 | from dataclasses import dataclass, field
22 | from typing import Optional
23 |
24 | import numpy as np
25 | import datasets
26 | from datasets import load_dataset
27 | import transformers
28 | from transformers import (
29 | DataCollatorWithPadding,
30 | EvalPrediction,
31 | HfArgumentParser,
32 | Trainer,
33 | TrainingArguments,
34 | default_data_collator,
35 | set_seed,
36 | )
37 | from transformers.trainer_utils import get_last_checkpoint
38 | from transformers.utils import check_min_version
39 | from transformers.utils.versions import require_version
40 |
41 | from preprocess import preprocess_batch
42 | from model import ModelArguments, load_model_tokenizer
43 |
44 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
45 | check_min_version("4.25.0")
46 |
47 | require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
48 |
49 | logger = logging.getLogger(__name__)
50 |
51 |
52 | @dataclass
53 | class DataTrainingArguments:
54 | """
55 | Arguments pertaining to what data we are going to input our model for training and eval.
56 |
57 | Using `HfArgumentParser` we can turn this class
58 | into argparse arguments to be able to specify them on
59 | the command line.
60 | """
61 |
62 | max_seq_length: int = field(
63 | default=128, # TODO figure out best length
64 | metadata={
65 | "help": (
66 | "The maximum total input sequence length after tokenization. Sequences longer "
67 | "than this will be truncated, sequences shorter will be padded."
68 | )
69 | },
70 | )
71 | overwrite_cache: bool = field(
72 | default=False,
73 | metadata={
74 | "help": "Overwrite the cached preprocessed datasets or not."
75 | }
76 | )
77 | pad_to_max_length: bool = field(
78 | default=True,
79 | metadata={
80 | "help": (
81 | "Whether to pad all samples to `max_seq_length`. "
82 | "If False, will pad the samples dynamically when batching to the maximum length in the batch."
83 | )
84 | },
85 | )
86 | max_train_samples: Optional[int] = field(
87 | default=None,
88 | metadata={
89 | "help": (
90 | "For debugging purposes or quicker training, truncate the number of training examples to this "
91 | "value if set."
92 | )
93 | },
94 | )
95 | max_eval_samples: Optional[int] = field(
96 | default=None,
97 | metadata={
98 | "help": (
99 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
100 | "value if set."
101 | )
102 | },
103 | )
104 | max_predict_samples: Optional[int] = field(
105 | default=None,
106 | metadata={
107 | "help": (
108 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
109 | "value if set."
110 | )
111 | },
112 | )
113 | train_file: Optional[str] = field(
114 | default=None,
115 | metadata={
116 | "help": "A json file containing the training data."
117 | }
118 | )
119 | validation_file: Optional[str] = field(
120 | default=None,
121 | metadata={
122 | "help": "A json file containing the validation data."
123 | }
124 | )
125 | test_file: Optional[str] = field(
126 | default=None,
127 | metadata={
128 | "help": "A json file containing the test data."
129 | }
130 | )
131 |
132 | def __post_init__(self):
133 | if self.train_file is None or self.validation_file is None:
134 | raise ValueError("Need a training/validation file.")
135 |
136 |
137 | def main():
138 | # See all possible arguments in src/transformers/training_args.py
139 | # or by passing the --help flag to this script.
140 | # We now keep distinct sets of args, for a cleaner separation of concerns.
141 |
142 | parser = HfArgumentParser(
143 | (ModelArguments, DataTrainingArguments, TrainingArguments))
144 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
145 |
146 | # Setup logging
147 | logging.basicConfig(
148 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
149 | datefmt="%m/%d/%Y %H:%M:%S",
150 | handlers=[logging.StreamHandler(sys.stdout)],
151 | )
152 |
153 | log_level = training_args.get_process_log_level()
154 | logger.setLevel(log_level)
155 | datasets.utils.logging.set_verbosity(log_level)
156 | transformers.utils.logging.set_verbosity(log_level)
157 | transformers.utils.logging.enable_default_handler()
158 | transformers.utils.logging.enable_explicit_format()
159 |
160 | # Log on each process the small summary:
161 | logger.warning(
162 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
163 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
164 | )
165 | logger.info(f"Training/evaluation parameters {training_args}")
166 |
167 | # Detecting last checkpoint.
168 | last_checkpoint = None
169 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
170 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
171 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
172 | raise ValueError(
173 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
174 | "Use --overwrite_output_dir to overcome."
175 | )
176 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
177 | logger.info(
178 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
179 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
180 | )
181 |
182 | # Set seed before initializing model.
183 | set_seed(training_args.seed)
184 |
185 | # Loading a dataset from your local files.
186 | data_files = {
187 | "train": data_args.train_file,
188 | "validation": data_args.validation_file
189 | }
190 |
191 | # Get the test dataset:
192 | if training_args.do_predict:
193 | if data_args.test_file is None:
194 | raise ValueError("Need a test file for `do_predict`.")
195 |
196 | data_files["test"] = data_args.test_file
197 |
198 | for key in data_files.keys():
199 | logger.info(f"load a local file for {key}: {data_files[key]}")
200 |
201 | # Loading a dataset from local json files
202 | raw_datasets = load_dataset(
203 | "json",
204 | data_files=data_files,
205 | cache_dir=model_args.cache_dir,
206 | use_auth_token=True if model_args.use_auth_token else None,
207 | )
208 | # See more about loading any type of standard or custom dataset at
209 | # https://huggingface.co/docs/datasets/loading_datasets.html.
210 |
211 | # Load pretrained model and tokenizer
212 | #
213 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
214 | # download model & vocab.
215 | model, tokenizer = load_model_tokenizer(model_args)
216 |
217 | # Preprocessing the raw_datasets
218 | # Padding strategy
219 | if data_args.pad_to_max_length:
220 | padding = "max_length"
221 | else:
222 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch
223 | padding = False
224 |
225 | if data_args.max_seq_length > tokenizer.model_max_length:
226 | logger.warning(
227 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
228 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
229 | )
230 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
231 |
232 | def preprocess_function(examples):
233 | # Tokenize the texts
234 | input_data = preprocess_batch(
235 | examples['author_name'], examples['text'])
236 |
237 | result = tokenizer(input_data, padding=padding,
238 | max_length=max_seq_length, truncation=True)
239 |
240 | # Map labels to IDs
241 | if model.config.label2id is not None and 'label' in examples:
242 | result['label'] = [(model.config.label2id[l] if l != -1 else -1)
243 | for l in examples['label']]
244 | return result
245 |
246 | with training_args.main_process_first(desc="dataset map pre-processing"):
247 | raw_datasets = raw_datasets.map(
248 | preprocess_function,
249 | batched=True,
250 | load_from_cache_file=not data_args.overwrite_cache,
251 | desc="Running tokenizer on dataset",
252 | )
253 | if training_args.do_train:
254 | if "train" not in raw_datasets:
255 | raise ValueError("--do_train requires a train dataset")
256 | train_dataset = raw_datasets["train"]
257 | if data_args.max_train_samples is not None:
258 | max_train_samples = min(
259 | len(train_dataset), data_args.max_train_samples)
260 | train_dataset = train_dataset.select(range(max_train_samples))
261 |
262 | if training_args.do_eval:
263 | if "validation" not in raw_datasets:
264 | raise ValueError("--do_eval requires a validation dataset")
265 | eval_dataset = raw_datasets["validation"]
266 | if data_args.max_eval_samples is not None:
267 | max_eval_samples = min(
268 | len(eval_dataset), data_args.max_eval_samples)
269 | eval_dataset = eval_dataset.select(range(max_eval_samples))
270 |
271 | if training_args.do_predict:
272 | if "test" not in raw_datasets:
273 | raise ValueError("--do_predict requires a test dataset")
274 | predict_dataset = raw_datasets["test"]
275 | if data_args.max_predict_samples is not None:
276 | max_predict_samples = min(
277 | len(predict_dataset), data_args.max_predict_samples)
278 | predict_dataset = predict_dataset.select(
279 | range(max_predict_samples))
280 |
281 | # Log a few random samples from the training set:
282 | if training_args.do_train:
283 | for index in random.sample(range(len(train_dataset)), 3):
284 | logger.info(
285 | f"Sample {index} of the training set: {train_dataset[index]}.")
286 |
287 | from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, precision_recall_fscore_support
288 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
289 | # predictions and label_ids field) and has to return a dictionary string to float.
290 |
291 | def compute_metrics(eval_pred: EvalPrediction):
292 | logits, labels = eval_pred
293 | predictions = np.argmax(logits, axis=1)
294 |
295 | eval_data = precision_recall_fscore_support(
296 | y_true=labels, y_pred=predictions, average='weighted')
297 |
298 | # class_report=classification_report(y_true=labels, y_pred=predictions, output_dict=True, target_names=label_list)
299 | return dict(
300 | accuracy=accuracy_score(y_true=labels, y_pred=predictions),
301 | balanced_accuracy=balanced_accuracy_score(
302 | y_true=labels, y_pred=predictions),
303 | precision=eval_data[0],
304 | recall=eval_data[1],
305 | fscore=eval_data[2],
306 | # classification_report=class_report
307 | )
308 |
309 | # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer,
310 | # so we change it if we already did the padding.
311 | if data_args.pad_to_max_length:
312 | data_collator = default_data_collator
313 | elif training_args.fp16:
314 | data_collator = DataCollatorWithPadding(
315 | tokenizer, pad_to_multiple_of=8)
316 | else:
317 | data_collator = None
318 |
319 | # Initialize our Trainer
320 | trainer = Trainer(
321 | model=model,
322 | args=training_args,
323 | train_dataset=train_dataset if training_args.do_train else None,
324 | eval_dataset=eval_dataset if training_args.do_eval else None,
325 | compute_metrics=compute_metrics,
326 | tokenizer=tokenizer,
327 | data_collator=data_collator,
328 | )
329 |
330 | # Training
331 | if training_args.do_train:
332 | checkpoint = None
333 | if training_args.resume_from_checkpoint is not None:
334 | checkpoint = training_args.resume_from_checkpoint
335 | elif last_checkpoint is not None:
336 | checkpoint = last_checkpoint
337 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
338 | metrics = train_result.metrics
339 | max_train_samples = (
340 | data_args.max_train_samples if data_args.max_train_samples is not None else len(
341 | train_dataset)
342 | )
343 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
344 |
345 | trainer.save_model() # Saves the tokenizer too for easy upload
346 |
347 | trainer.log_metrics("train", metrics)
348 | trainer.save_metrics("train", metrics)
349 | trainer.save_state()
350 |
351 | # Evaluation
352 | if training_args.do_eval:
353 | logger.info("*** Evaluate ***")
354 |
355 | metrics = trainer.evaluate(eval_dataset=eval_dataset)
356 |
357 | max_eval_samples = (
358 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(
359 | eval_dataset)
360 | )
361 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
362 |
363 | trainer.log_metrics("eval", metrics)
364 | trainer.save_metrics("eval", metrics)
365 |
366 | if training_args.do_predict:
367 | logger.info("*** Predict ***")
368 |
369 | # Removing the `label` columns because it contains -1 and Trainer won't like that.
370 | predict_dataset = predict_dataset.remove_columns('label')
371 | predictions = trainer.predict(
372 | predict_dataset, metric_key_prefix="predict").predictions
373 | predictions = np.argmax(predictions, axis=1)
374 |
375 | output_predict_file = os.path.join(
376 | training_args.output_dir, "predict_results.txt")
377 | if trainer.is_world_process_zero():
378 | with open(output_predict_file, "w", encoding='utf-8') as writer:
379 | logger.info(f"***** Predict results *****")
380 | writer.write("index\ttext\tprediction\n")
381 | for index, item in enumerate(predictions):
382 | item = model.config.id2label[item]
383 | writer.write(
384 | f"{index}\t{predict_dataset[index]}\t{item}\n")
385 |
386 | kwargs = {"finetuned_from": model_args.model_name_or_path,
387 | "tasks": "text-classification"}
388 |
389 | if training_args.push_to_hub:
390 | trainer.push_to_hub(**kwargs)
391 |
392 |
393 | if __name__ == "__main__":
394 | main()
395 |
--------------------------------------------------------------------------------
/src/urls.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import re
3 |
4 | ip_middle_octet = r"(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5]))"
5 | ip_last_octet = r"(?:\.(?:0|[1-9]\d?|1\d\d|2[0-4]\d|25[0-5]))"
6 |
7 | URL_PATTERN = re.compile( # noqa: W605
8 | # protocol identifier
9 | r"(?:(?:https?|ftp)://)"
10 | # user:pass authentication
11 | r"(?:[-a-z\u00a1-\uffff0-9._~%!$&'()*+,;=:]+"
12 | r"(?::[-a-z0-9._~%!$&'()*+,;=:]*)?@)?"
13 | r"(?:"
14 | r"(?P"
15 | # IP address exclusion
16 | # private & local networks
17 | r"(?:(?:10|127)" + ip_middle_octet + r"{2}" + ip_last_octet + r")|"
18 | r"(?:(?:169\.254|192\.168)" + ip_middle_octet + ip_last_octet + r")|"
19 | r"(?:172\.(?:1[6-9]|2\d|3[0-1])" + ip_middle_octet + ip_last_octet + r"))"
20 | r"|"
21 | # private & local hosts
22 | r"(?P"
23 | r"(?:localhost))"
24 | r"|"
25 | # IP address dotted notation octets
26 | # excludes loopback network 0.0.0.0
27 | # excludes reserved space >= 224.0.0.0
28 | # excludes network & broadcast addresses
29 | # (first & last IP address of each class)
30 | r"(?P"
31 | r"(?:[1-9]\d?|1\d\d|2[01]\d|22[0-3])"
32 | r"" + ip_middle_octet + r"{2}"
33 | r"" + ip_last_octet + r")"
34 | r"|"
35 | # IPv6 RegEx from https://stackoverflow.com/a/17871737
36 | r"\[("
37 | # 1:2:3:4:5:6:7:8
38 | r"([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|"
39 | # 1:: 1:2:3:4:5:6:7::
40 | r"([0-9a-fA-F]{1,4}:){1,7}:|"
41 | # 1::8 1:2:3:4:5:6::8 1:2:3:4:5:6::8
42 | r"([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|"
43 | # 1::7:8 1:2:3:4:5::7:8 1:2:3:4:5::8
44 | r"([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|"
45 | # 1::6:7:8 1:2:3:4::6:7:8 1:2:3:4::8
46 | r"([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|"
47 | # 1::5:6:7:8 1:2:3::5:6:7:8 1:2:3::8
48 | r"([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|"
49 | # 1::4:5:6:7:8 1:2::4:5:6:7:8 1:2::8
50 | r"([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|"
51 | # 1::3:4:5:6:7:8 1::3:4:5:6:7:8 1::8
52 | r"[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|"
53 | # ::2:3:4:5:6:7:8 ::2:3:4:5:6:7:8 ::8 ::
54 | r":((:[0-9a-fA-F]{1,4}){1,7}|:)|"
55 | # fe80::7:8%eth0 fe80::7:8%1
56 | # (link-local IPv6 addresses with zone index)
57 | r"fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|"
58 | r"::(ffff(:0{1,4}){0,1}:){0,1}"
59 | r"((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}"
60 | # ::255.255.255.255 ::ffff:255.255.255.255 ::ffff:0:255.255.255.255
61 | # (IPv4-mapped IPv6 addresses and IPv4-translated addresses)
62 | r"(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|"
63 | r"([0-9a-fA-F]{1,4}:){1,4}:"
64 | r"((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}"
65 | # 2001:db8:3:4::192.0.2.33 64:ff9b::192.0.2.33
66 | # (IPv4-Embedded IPv6 Address)
67 | r"(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])"
68 | r")\]|"
69 | # host name
70 | r"(?:(?:(?:xn--[-]{0,2})|[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]-?)*"
71 | r"[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]+)"
72 | # domain name
73 | r"(?:\.(?:(?:xn--[-]{0,2})|[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]-?)*"
74 | r"[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]+)*"
75 | # TLD identifier
76 | r"(?:\.(?:(?:xn--[-]{0,2}[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]{2,})|"
77 | r"[a-z\u00a1-\uffff\U00010000-\U0010ffff]{2,}))"
78 | r")"
79 | # port number
80 | r"(?::\d{2,5})?"
81 | # resource path
82 | r"(?:/[-a-z\u00a1-\uffff\U00010000-\U0010ffff0-9._~%!$&'()*+,;=:@/]*)?"
83 | # query string
84 | r"(?:\?\S*)?"
85 | # fragment
86 | r"(?:#\S*)?",
87 | re.UNICODE | re.IGNORECASE
88 | )
89 |
90 |
91 | def validate_url(text) -> Tuple[bool, bool]:
92 | # Returns tuple: (contains URL, is only a URL)
93 | match = URL_PATTERN.search(text)
94 | if match:
95 | exact = match.span() == (0, len(text))
96 | return True, exact
97 |
98 | return False, False
99 |
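The two booleans line up with the rule-detected labels in labels.py (LINK_ONLY / LINK_CONTAINS). A few illustrative calls:

    from urls import validate_url

    print(validate_url('https://example.com'))                 # (True, True)   -> LINK_ONLY
    print(validate_url('check this out https://example.com'))  # (True, False)  -> LINK_CONTAINS
    print(validate_url('great video!'))                        # (False, False) -> no URL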
--------------------------------------------------------------------------------
/src/youtube.py:
--------------------------------------------------------------------------------
1 | from typing import List, Type
2 | import socket
3 | import os
4 | from googleapiclient.discovery import build, Resource
5 | from googleapiclient.errors import HttpError
6 |
7 | from dotenv import load_dotenv
8 | load_dotenv()
9 |
10 |
11 | API_SERVICE_NAME = 'youtube'
12 | API_VERSION = 'v3'
13 | YT_API_KEY = os.environ['YT_API_KEY']
14 |
15 | NUM_RETRIES = 3
16 |
17 |
18 | class YouTubeError(Exception):
19 | reason = ''
20 |
21 | def __init__(self, error: HttpError) -> None:
22 | self.error = error
23 |
24 |
25 | class VideoNotFound(YouTubeError):
26 | reason = 'videoNotFound'
27 |
28 |
29 | class QuotaExceeded(YouTubeError):
30 | reason = 'quotaExceeded'
31 |
32 |
33 | class CommentsDisabled(YouTubeError):
34 | reason = 'commentsDisabled'
35 |
36 |
37 | errors: List[Type[YouTubeError]] = [VideoNotFound, QuotaExceeded, CommentsDisabled]
38 |
39 |
40 | def make_api_request(function: Resource, **kwargs):
41 | request = function(**kwargs)
42 | try:
43 | return request.execute()
44 | except socket.timeout:
45 | # TODO catch exceptions, even after retrying (built in)
46 | raise
47 | except HttpError as e:
48 | for error_class in errors:
49 | if e.reason == error_class.reason:
50 | raise error_class(e)
51 | raise YouTubeError(e)
52 |
53 |
54 | YOUTUBE_API = build(
55 | API_SERVICE_NAME,
56 | API_VERSION,
57 | developerKey=YT_API_KEY,
58 | num_retries=NUM_RETRIES
59 | )
60 |
61 |
62 | def pagination_helper(function, api_kwargs, max_requests=None):
63 | request_kwargs = {}
64 | request_count = 0
65 | while True:
66 | if max_requests is not None and request_count >= max_requests:
67 | break
68 | try:
69 | response = make_api_request(
70 | function=function,
71 | **api_kwargs,
72 | **request_kwargs
73 | )
74 | request_count += 1
75 | except QuotaExceeded:
76 | raise
77 | except YouTubeError as e:
78 | print(e)
79 | return
80 |
81 | yield response
82 |
83 | next_token = response.get('nextPageToken')
84 | if not next_token:
85 | break
86 |
87 | request_kwargs['pageToken'] = next_token
88 |
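Importing this module requires a YouTube Data API v3 key in the environment (`YT_API_KEY`, e.g. via a `.env` file picked up by python-dotenv). A hypothetical one-off request using the helpers above; the video id is a placeholder:

    from youtube import YOUTUBE_API, make_api_request

    response = make_api_request(
        function=YOUTUBE_API.videos().list,
        part='snippet,statistics',
        id='SOME_VIDEO_ID',  # placeholder, not a real video id
    )
    for item in response.get('items', []):
        print(item['snippet']['title'], item['statistics'].get('viewCount'))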
--------------------------------------------------------------------------------