├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── README.md ├── datasets └── .gitignore ├── extension ├── dev │ ├── TODO.md │ ├── model │ │ └── .gitignore │ ├── ort-wasm-simd.wasm │ └── scripts │ │ ├── model.js │ │ ├── ort.js │ │ └── tokenizer.js ├── public │ ├── icons │ │ ├── icon-128x128.png │ │ └── icon.png │ ├── manifest.json │ ├── popup │ │ ├── popup.css │ │ ├── popup.html │ │ └── popup.js │ ├── scripts │ │ ├── content.js │ │ ├── defaults.js │ │ ├── detection.js │ │ ├── emojis.js │ │ ├── labels.js │ │ ├── preprocess.js │ │ ├── storage.js │ │ └── utils.js │ └── styles │ │ └── style.css └── store │ └── teaser.png ├── output └── .gitignore ├── requirements.txt └── src ├── cluster.py ├── database.py ├── dataset.py ├── detection_rules.py ├── downloader.py ├── labels.py ├── model.py ├── moderate.py ├── predict.py ├── preprocess.py ├── shared.py ├── train.py ├── urls.py └── youtube.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: xenova 2 | ko_fi: xenova 3 | custom: https://www.buymeacoffee.com/xenova 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | __pycache__ 3 | .vscode 4 | *.env 5 | *.db 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Xenova 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CommentBlock 2 | 3 | CommentBlock is an open-source browser extension that automatically blocks spam/scam YouTube comments. Download for [Chrome/Chromium](https://chrome.google.com/webstore/detail/pnhkbjdbaioddkchelkolhbhcmlibjfb) or [Firefox](https://addons.mozilla.org/en-US/firefox/addon/ytcommentblock/). 4 | 5 | ![CommentBlock](./extension/store/teaser.png) 6 | 7 | ## Examples 8 | Want to see the extension in action? Here are some videos to test it on: 9 | 10 | 1. [The Flash - No Way Home For The DCU?](https://www.youtube.com/watch?v=JG0QV40FMdQ) | The Critical Drinker — If you sort comments by new, you will see a ton of scam comments (mostly as replies). 
The extension does a very good job at detecting these. 11 | 12 | 13 | 2. [Inside the Largest Bitcoin Mine in The U.S.](https://www.youtube.com/watch?v=x9J0NdV0u9k) | WIRED — Almost every single comment on this video is a scam comment from a bot. This is of course due to the subject matter: *Crypto*. Although the extension does a good job blocking the obvious scams, botters have gotten a lot smarter recently. In particular, they start long comment threads (pretending to be real conversations between real people) and eventually prompt readers to contact someone off of YouTube. Detection for these comment threads will be much better with neural networks (see [below](#development-plans))! 14 | 15 | 16 | ## Development Plans 17 | 18 | Neural-network based detection is also in development to catch more advanced spam comments and comment threads. In particular, we aim to use unsupervised clustering techniques to group similar comments together, assign labels to comments in these groups, and then train classification models using the labelled data. 19 | 20 | ## Contributions 21 | 22 | At the moment, the extension uses rules to determine the type of a comment (e.g., spam, scam, explicit, or links). So, there may be cases where the extension misses a bad comment, or blocks a valid comment. In either case, feel free to open an [issue](https://github.com/xenova/commentblock/issues/new/choose) (including a link to the comment, which can be retrieved by right-clicking the time posted and copying the link), and we will update the ruleset to account for this. 23 | 24 | ## Credit 25 | Inspired by [ThioJoe's Spammer Purge](https://github.com/ThioJoe/YT-Spammer-Purge) tool. 26 | -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /extension/dev/TODO.md: -------------------------------------------------------------------------------- 1 | ## TODO (Upcoming features): 2 | 3 | ### 1. Add to manifest.json: 4 | ``` 5 | "web_accessible_resources": [{ 6 | "resources": [ 7 | "model/*", 8 | "ort-wasm-simd.wasm" 9 | ], 10 | "matches": [ 11 | "*://*.youtube.com/*" 12 | ] 13 | }], 14 | ... 
15 | "content_scripts": [{ 16 | "js": [ 17 | "scripts/defaults.js", 18 | "scripts/utils.js", 19 | "scripts/storage.js", 20 | "scripts/ort.js", 21 | "scripts/labels.js", 22 | "scripts/preprocess.js", 23 | "scripts/tokenizer.js", 24 | "scripts/model.js", 25 | "scripts/emojis.js", 26 | "scripts/detection.js", 27 | "scripts/content.js" 28 | ], 29 | }], 30 | ``` -------------------------------------------------------------------------------- /extension/dev/model/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /extension/dev/ort-wasm-simd.wasm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xenova/commentblock/98cf557c16f3893a488350c94abe9d01d8686b62/extension/dev/ort-wasm-simd.wasm -------------------------------------------------------------------------------- /extension/dev/scripts/model.js: -------------------------------------------------------------------------------- 1 | 2 | let ModelFactory = (function () { 3 | // https://stackoverflow.com/a/4842961 4 | 5 | async function createModelPromise() { 6 | console.log('Start loading model'); 7 | // Load the tokenizer and model. 8 | const tokenizerURL = chrome.runtime.getURL("model/tokenizer.json"); 9 | const modelURL = chrome.runtime.getURL("model/model.onnx"); 10 | 11 | let model = new Model(tokenizerURL, modelURL); 12 | await model.load() // Wait for model to load fully 13 | console.log('Model finished loading'); 14 | return model 15 | } 16 | 17 | let instance; 18 | return { 19 | getInstance: async function () { 20 | if (instance == null) { 21 | instance = createModelPromise(); 22 | } 23 | return await instance; 24 | } 25 | }; 26 | })(); 27 | 28 | 29 | class Model { 30 | constructor(tokenizerURL, modelURL) { 31 | this.tokenizerURL = tokenizerURL; 32 | this.modelURL = modelURL; 33 | this.labels = Object.values(COMMENT_LABEL) 34 | } 35 | async load() { 36 | // Load tokenizer 37 | let tokenizer = new WordPieceTokenizer() 38 | await tokenizer.load(this.tokenizerURL) 39 | 40 | this.tokenizer = tokenizer; 41 | 42 | // Load model 43 | let response = await fetch(this.modelURL, { 44 | cache: 'force-cache' 45 | }); 46 | let modelBuffer = await response.arrayBuffer(); 47 | this.session = await ort.InferenceSession.create(modelBuffer, { 48 | executionProviders: ["wasm"] 49 | }); 50 | } 51 | 52 | create_model_input(encoded) { 53 | // TODO optimise this 54 | // Adapted from https://github.com/jobergum/browser-ml-inference/blob/main/src/inference.js 55 | // (https://www.youtube.com/watch?v=W_lUGPMW_Eg) 56 | 57 | var input_ids = new Array(encoded.length + 2); 58 | var attention_mask = new Array(encoded.length + 2); 59 | input_ids[0] = BigInt(101); // [CLS] 60 | attention_mask[0] = BigInt(1); 61 | var i = 0; 62 | for (; i < encoded.length; i++) { 63 | input_ids[i + 1] = BigInt(encoded[i]); 64 | attention_mask[i + 1] = BigInt(1); 65 | } 66 | input_ids[i + 1] = BigInt(102); // [SEP] 67 | attention_mask[i + 1] = BigInt(1); 68 | const sequence_length = input_ids.length; 69 | input_ids = new ort.Tensor('int64', BigInt64Array.from(input_ids), [1, sequence_length]); 70 | attention_mask = new ort.Tensor('int64', BigInt64Array.from(attention_mask), [1, sequence_length]); 71 | return { 72 | input_ids: input_ids, 73 | attention_mask: attention_mask 74 | } 75 | } 76 | preprocess(authorName, 
commentText) { 77 | // Normalise author name and comment text 78 | authorName = this.tokenizer.normalize(authorName); 79 | commentText = this.tokenizer.normalize(commentText); 80 | 81 | return `${authorName} commented ${commentText}`; 82 | 83 | } 84 | async predict(authorName, commentText) { 85 | let text = this.preprocess(authorName, commentText); 86 | // console.log('text', text) 87 | 88 | let encoded = this.tokenizer.call(text); 89 | 90 | let model_input = this.create_model_input(encoded); 91 | // console.log('model_input', model_input) 92 | let output = await this.session.run(model_input); 93 | 94 | // console.log('output.logits', output.logits.data) 95 | let maxIndex = indexOfMax(output.logits.data); 96 | // console.log('maxIndex', maxIndex) 97 | 98 | let prediction = this.labels[maxIndex] 99 | // console.log('prediction', prediction) 100 | return prediction; 101 | 102 | } 103 | } 104 | 105 | function indexOfMax(arr) { 106 | // https://stackoverflow.com/a/11301464 107 | 108 | if (arr.length === 0) { 109 | return -1; 110 | } 111 | 112 | var max = arr[0]; 113 | var maxIndex = 0; 114 | 115 | for (var i = 1; i < arr.length; i++) { 116 | if (arr[i] > max) { 117 | maxIndex = i; 118 | max = arr[i]; 119 | } 120 | } 121 | 122 | return maxIndex; 123 | } 124 | -------------------------------------------------------------------------------- /extension/dev/scripts/tokenizer.js: -------------------------------------------------------------------------------- 1 | "use strict"; 2 | 3 | class WordPieceTokenizer { 4 | constructor() { 5 | this.separator = '[SEP]'; 6 | this.unknown_token = '[UNK]'; 7 | } 8 | async load(vocabUrl) { 9 | 10 | let v = await this.loadVocab(vocabUrl); 11 | 12 | this.vocabMapping = v.model.vocab; 13 | 14 | let tempVocab = {}; 15 | // Reverse vocab 16 | for (const [key, value] of Object.entries(this.vocabMapping)) { 17 | tempVocab[parseInt(value)] = key; 18 | } 19 | 20 | this.vocab = Object.values(tempVocab) 21 | } 22 | async loadVocab(url) { 23 | const response = await fetch(url); 24 | return await response.json(); 25 | } 26 | 27 | normalize(string) { 28 | return normalize(string) 29 | } 30 | 31 | pretokenize(text) { 32 | return text.trim().match(/\b\w+\b|[^\s\w]/g) || []; 33 | } 34 | tokenize(text) { 35 | var outputTokens = []; 36 | 37 | // whitespace_tokenize 38 | let tokens = this.pretokenize(text); 39 | 40 | for (let token of tokens) { 41 | let chars = [...token]; 42 | // if len(chars) > self.max_input_chars_per_word: 43 | // output_tokens.append(self.unk_token) 44 | // continue 45 | 46 | let isUnknown = false; 47 | let start = 0; 48 | let subTokens = []; 49 | 50 | while (start < chars.length) { 51 | var end = chars.length; 52 | var currentSubstring = null; 53 | while (start < end) { 54 | var substr = chars.slice(start, end).join(''); 55 | 56 | if (start > 0) { 57 | substr = '##' + substr 58 | } 59 | if (this.vocab.includes(substr)) { 60 | currentSubstring = substr; 61 | break; 62 | } 63 | 64 | --end; 65 | } 66 | if (currentSubstring == null) { 67 | isUnknown = true; 68 | break; 69 | } 70 | subTokens.push(currentSubstring); 71 | start = end; 72 | } 73 | if (isUnknown) { 74 | outputTokens.push(this.unknown_token); 75 | } else { 76 | outputTokens = outputTokens.concat(subTokens); 77 | } 78 | } 79 | 80 | return outputTokens; 81 | } 82 | convert_tokens_to_ids(outputTokens) { 83 | // get ids 84 | let ids = []; 85 | for (let t of outputTokens) { 86 | ids.push(this.vocabMapping[t]); 87 | } 88 | return ids; 89 | } 90 | 91 | call(text) { 92 | return 
this.convert_tokens_to_ids(this.tokenize(text)); 93 | } 94 | } 95 | 96 | -------------------------------------------------------------------------------- /extension/public/icons/icon-128x128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xenova/commentblock/98cf557c16f3893a488350c94abe9d01d8686b62/extension/public/icons/icon-128x128.png -------------------------------------------------------------------------------- /extension/public/icons/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xenova/commentblock/98cf557c16f3893a488350c94abe9d01d8686b62/extension/public/icons/icon.png -------------------------------------------------------------------------------- /extension/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "manifest_version": 3, 3 | "name": "CommentBlock", 4 | "description": "Block spam/scam YouTube comments", 5 | "version": "0.0.3", 6 | "permissions": [ 7 | "storage" 8 | ], 9 | "host_permissions": [ 10 | "*://*.youtube.com/*" 11 | ], 12 | "content_scripts": [ 13 | { 14 | "css": [ 15 | "styles/style.css" 16 | ], 17 | "js": [ 18 | "scripts/defaults.js", 19 | "scripts/utils.js", 20 | "scripts/storage.js", 21 | "scripts/labels.js", 22 | "scripts/preprocess.js", 23 | "scripts/emojis.js", 24 | "scripts/detection.js", 25 | "scripts/content.js" 26 | ], 27 | "matches": [ 28 | "https://*.youtube.com/*" 29 | ] 30 | } 31 | ], 32 | "action": { 33 | "default_icon": { 34 | "16": "icons/icon.png", 35 | "24": "icons/icon.png", 36 | "32": "icons/icon.png", 37 | "128": "icons/icon-128x128.png" 38 | }, 39 | "default_title": "CommentBlock", 40 | "default_popup": "popup/popup.html" 41 | }, 42 | "icons": { 43 | "16": "icons/icon.png", 44 | "32": "icons/icon.png", 45 | "64": "icons/icon.png", 46 | "128": "icons/icon-128x128.png" 47 | }, 48 | "browser_specific_settings": { 49 | "gecko": { 50 | "id": "commentblock@xenova.com" 51 | } 52 | } 53 | } -------------------------------------------------------------------------------- /extension/public/popup/popup.css: -------------------------------------------------------------------------------- 1 | * { 2 | margin: 0; 3 | padding: 0; 4 | box-sizing: border-box; 5 | font-family: 'Open Sans', sans-serif; 6 | font-size: 14px; 7 | } 8 | 9 | h2 { 10 | margin: 8px 0 2px 0; 11 | font-size: 18px; 12 | } 13 | 14 | /* Style the buttons that are used to open and close the accordion panel */ 15 | .accordion { 16 | background-color: #eee; 17 | color: #444; 18 | /* cursor: pointer; */ 19 | /* width: 100%; */ 20 | text-align: left; 21 | border: none; 22 | outline: none; 23 | transition: 0.4s; 24 | 25 | } 26 | 27 | /* Add a background color to the button if it is clicked on (add the .active class with JS), and when you move the mouse over it (hover) */ 28 | .accordion:hover { 29 | background-color: #ccc; 30 | } 31 | 32 | /* Style the accordion panel. 
Note: hidden by default */ 33 | .panel { 34 | background-color: white; 35 | max-height: 0; 36 | overflow: hidden; 37 | transition: max-height 0.2s ease-out; 38 | /* padding:8px 16px; */ 39 | /* margin: 10px; */ 40 | } 41 | 42 | .panel>.inner-panel { 43 | padding: 8px 16px 16px 16px; 44 | 45 | } 46 | 47 | fieldset { 48 | margin: 4px 0; 49 | padding: 4px 8px; 50 | } 51 | 52 | fieldset>legend { 53 | margin-left: 5px; 54 | padding: 0 5px; 55 | } 56 | 57 | .top { 58 | position: relative; 59 | padding: 8px 16px; 60 | 61 | } 62 | 63 | .top:after { 64 | position: absolute; 65 | right: 16px; 66 | content: '\02795'; 67 | /* Unicode character for "plus" sign (+) */ 68 | font-size: 13px; 69 | color: #777; 70 | float: right; 71 | margin-left: 5px; 72 | } 73 | 74 | .top>label, 75 | .rule>label { 76 | margin-left: 5px; 77 | } 78 | 79 | .active:after { 80 | content: "\2796"; 81 | /* Unicode character for "minus" sign (-) */ 82 | } 83 | 84 | .center-vertical { 85 | display: flex; 86 | align-items: center; 87 | } 88 | 89 | 90 | .main-container { 91 | padding: 16px 12px; 92 | min-height: 500px; 93 | max-height: 500px; 94 | 95 | min-width: 375px; 96 | max-width: 375px; 97 | overflow: overlay; 98 | /* overflow: scroll; */ 99 | } 100 | 101 | 102 | 103 | /* width */ 104 | ::-webkit-scrollbar { 105 | width: 8px; 106 | height: 8px; 107 | } 108 | 109 | /* Track */ 110 | ::-webkit-scrollbar-track { 111 | background: transparent; 112 | /* box-shadow: inset 0 0 5px #dddddd; */ 113 | border-radius: 4px; 114 | border-left: 2px solid transparent; 115 | border-right: 2px solid transparent; 116 | } 117 | 118 | /* Handle */ 119 | ::-webkit-scrollbar-thumb { 120 | background-color: #C1C1C1; 121 | border-radius: 4px; 122 | } 123 | 124 | .heading { 125 | display: flex; 126 | flex-direction: column; 127 | justify-content: center; 128 | align-items: center; 129 | } 130 | 131 | .heading>h1 { 132 | font-size: 36px; 133 | font-weight: 700; 134 | line-height: 100%; 135 | } 136 | 137 | .heading>h5 { 138 | font-size: 12px; 139 | font-weight: 500; 140 | } 141 | 142 | .bottom{ 143 | padding: 16px 8px; 144 | display: flex; 145 | justify-content: center; 146 | align-items: center; 147 | } 148 | #reset:hover{ 149 | background-color: rgba(51, 51, 51, 0.15); 150 | 151 | } 152 | #reset { 153 | background-color: rgba(51, 51, 51, 0.05); 154 | border-radius: 8px; 155 | border-width: 0; 156 | cursor: pointer; 157 | display: inline-block; 158 | font-size: 14px; 159 | font-weight: 500; 160 | line-height: 20px; 161 | list-style: none; 162 | margin: 0; 163 | padding: 10px 12px; 164 | text-align: center; 165 | transition: all 200ms; 166 | vertical-align: baseline; 167 | white-space: nowrap; 168 | user-select: none; 169 | -webkit-user-select: none; 170 | touch-action: manipulation; 171 | } -------------------------------------------------------------------------------- /extension/public/popup/popup.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Document 9 | 10 | 11 | 12 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 |
28 |
29 |

CommentBlock

30 |
(v0.0.3)
31 |
32 | 33 |

Categories

34 | 35 |
36 | 37 |
38 |
39 | 40 | 41 |
42 |
43 |
44 |
45 | Options 46 |
47 | 48 | 49 |
50 |
51 |
52 |
53 |
54 | 55 | 56 |
57 |
58 | 59 | 60 |
61 |
62 |
63 |
64 | Options 65 |
66 | 67 | 68 |
69 |
70 |
71 |
72 |
73 | 74 | 75 |
76 |
77 | 78 | 79 |
80 |
81 |
82 |
83 | Options 84 |
85 | 86 | 87 |
88 |
89 | 90 | 91 |
92 |
93 | 94 | 95 |
96 |
97 |
98 |
99 |
100 | 101 | 102 |
103 |
104 | 105 | 106 |
107 |
108 |
109 |
110 | Options 111 |
112 | 113 | 114 |
115 |
116 |
117 |
118 |
119 | 120 | 121 |
122 |
123 | 124 | 125 |
126 |
127 |
128 |
129 | Options 130 |
131 | 133 | 134 |
135 |
136 |
137 |
138 |
139 | 140 | 141 |
142 |
143 | 144 | 145 |
146 |
147 |
148 |
149 | Options 150 |
151 | 152 | 153 |
154 |
155 |
156 |
157 |
158 | 159 | 160 |
161 |
162 | 163 | 164 |
165 |
166 |
167 |
168 | Options 169 |
170 | 171 | 172 |
173 |
174 |
175 |
176 |
177 | 178 | 179 |
180 | 181 |

Actions

182 |
183 | 184 | 188 |
189 | 190 |

Settings

191 |
192 | 193 | 194 |
195 |
196 | 197 | 198 |
199 | 200 |

Banned users/phrases

201 |
202 | Coming soon... 203 |
204 | 205 |
206 | 207 |
208 |
209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /extension/public/popup/popup.js: -------------------------------------------------------------------------------- 1 | (function () { 2 | var acc = document.getElementsByClassName("top"); 3 | var i; 4 | 5 | for (i = 0; i < acc.length; i++) { 6 | acc[i].addEventListener("click", function (e) { 7 | if (e.target.tagName === 'INPUT' || e.target.tagName === 'LABEL') { 8 | // Using checkbox 9 | return; 10 | } 11 | this.classList.toggle("active"); 12 | var panel = this.nextElementSibling; 13 | if (panel.style.maxHeight) { 14 | panel.style.maxHeight = null; 15 | } else { 16 | panel.style.maxHeight = panel.scrollHeight + "px"; 17 | } 18 | }); 19 | } 20 | }()); 21 | 22 | function resetSettings() { 23 | // https://developer.chrome.com/extensions/storage#method-StorageArea-clear 24 | chrome.storage.local.set(DEFAULT_OPTIONS) 25 | updateUIFromStorage(); 26 | 27 | } 28 | 29 | function updateUIFromStorage() { 30 | document.querySelectorAll('input[setting]').forEach(async function (i) { 31 | let setting = await getSetting(i.id) 32 | 33 | if (i.type === 'checkbox') { 34 | i.checked = setting; 35 | } else { 36 | alert(`Undefined input type: ${i.type}`) 37 | } 38 | }); 39 | 40 | document.querySelectorAll('select').forEach(async function (i) { 41 | let setting = await getSetting(i.id) 42 | i.value = setting; 43 | }); 44 | } 45 | document.addEventListener('DOMContentLoaded', function () { 46 | updateUIFromStorage(); 47 | 48 | document.getElementById('reset').addEventListener('click', resetSettings) 49 | 50 | document.querySelectorAll('input[setting]').forEach(async function (i) { 51 | i.addEventListener('input', function (e) { 52 | chrome.storage.local.set({ [i.id]: i.checked }); 53 | }); 54 | }) 55 | document.querySelectorAll('select').forEach(async function (i) { 56 | i.addEventListener('change', function (e) { 57 | chrome.storage.local.set({ [i.id]: i.value }); 58 | }); 59 | }) 60 | 61 | }, false); 62 | 63 | 64 | 65 | // chrome.storage.onChanged.addListener(function (changes, namespace) { 66 | // for (var key in changes) { 67 | // var storageChange = changes[key]; 68 | // console.log('Storage key "%s" in namespace "%s" changed. 
' + 'Old value was "%s", new value is "%s".', key, namespace, storageChange.oldValue, storageChange.newValue); 69 | // } 70 | // }); 71 | -------------------------------------------------------------------------------- /extension/public/scripts/content.js: -------------------------------------------------------------------------------- 1 | 2 | 3 | const LABEL_RULES_MAPPING = { 4 | // prediction: rule | [rules] 5 | // if any of the rules is enabled, 6 | [COMMENT_LABEL.SCAM]: "rules-enabled-scam", 7 | [COMMENT_LABEL.EXPLICIT]: "rules-enabled-explicit", 8 | [COMMENT_LABEL.LINK_SPAM]: ["rules-links-spam", "rules-links-contains"], 9 | [COMMENT_LABEL.LINK_ONLY]: ["rules-links-only", "rules-links-contains"], 10 | [COMMENT_LABEL.LINK_CONTAINS]: "rules-links-contains", 11 | [COMMENT_LABEL.SELF_PROMO]: "rules-enabled-selfpromo", 12 | [COMMENT_LABEL.OTHER_PROMO]: "rules-enabled-otherpromo", 13 | [COMMENT_LABEL.SPONSOR]: "rules-enabled-sponsor", 14 | [COMMENT_LABEL.OTHER_SPAM]: "rules-enabled-spam", 15 | } 16 | 17 | const BLURRED_COMMENT_OVERLAY_CLASS = 'blurred-comment-overlay' 18 | 19 | const LISTBOX_SELECTOR = 'tp-yt-paper-listbox#items.ytd-menu-popup-renderer'; 20 | const LISTBOX_TEXT_SELECTOR = 'yt-formatted-string.ytd-menu-service-item-renderer'; 21 | 22 | const COMMENT_TAG = 'YTD-COMMENT-RENDERER'; 23 | const COMMENT_THREAD_TAG = 'YTD-COMMENT-THREAD-RENDERER'; 24 | const COMMENT_TEXT = 'yt-formatted-string#content-text.ytd-comment-renderer'; 25 | 26 | const MENU_ITEM = htmlToElement( 27 | ` 28 | 29 | 30 | 31 | 32 | 33 | ` 34 | ); 35 | 36 | 37 | let PREDICTION_CACHE = {}; 38 | let focusedComment = null; 39 | 40 | (() => { 41 | 42 | // TODO allow user to select what to remove 43 | 44 | // observe changes on the page, comments are loaded separately so we need this to wait for them 45 | let observer = new MutationObserver((mutations) => { 46 | mutations.forEach(async (mutation) => { 47 | if (mutation.removedNodes.length > 0) { 48 | let listbox; 49 | if (mutation.target.matches(LISTBOX_SELECTOR)) { 50 | listbox = mutation.target; 51 | } else if (mutation.target.matches(LISTBOX_TEXT_SELECTOR)) { 52 | listbox = mutation.target.closest(LISTBOX_SELECTOR); 53 | } 54 | 55 | if (listbox && listbox.childNodes.length == 3) { 56 | addHideOption(listbox); 57 | } 58 | 59 | } else if (mutation.addedNodes.length > 0 && mutation.target.matches(COMMENT_TEXT)) { // is a comment 60 | // For optimisation purposes, YouTube doesn't remove-then-insert new comments. 61 | // Instead, they replace content of existing elmenets on the page. 62 | // For this reason, we have to listen for changes to the actual text content 63 | // of the comment, then crawl back up to the actual comment element. 
64 | // This is especially needed when sorting by recent 65 | let comment = mutation.target.closest(COMMENT_TAG); 66 | if (comment === null) return; 67 | processComment(comment); 68 | } 69 | }) 70 | }); 71 | 72 | observer.observe(document.body, { 73 | childList: true, 74 | subtree: true, 75 | attributes: false, 76 | characterData: false 77 | }); 78 | })(); 79 | 80 | function addBlur(commentElement) { 81 | if (isBlurred(commentElement)) return; // Do nothing if already blurred 82 | 83 | let overlay = document.createElement('div'); 84 | overlay.className = BLURRED_COMMENT_OVERLAY_CLASS; 85 | commentElement.querySelector('#body').append(overlay); 86 | } 87 | 88 | function getBlur(commentElement) { 89 | return commentElement.querySelector(`div.${BLURRED_COMMENT_OVERLAY_CLASS}`); 90 | } 91 | 92 | function isBlurred(commentElement) { 93 | return getBlur(commentElement) !== null; 94 | } 95 | 96 | function removeBlur(commentElement) { 97 | let overlay = getBlur(commentElement); 98 | if (overlay) overlay.remove(); 99 | } 100 | 101 | function addHideOption(listbox) { 102 | // TODO: Add option to hide comments? 103 | // let commentIsBlurred = isBlurred(focusedComment); 104 | 105 | let elem = MENU_ITEM.cloneNode(); 106 | listbox.appendChild(elem) 107 | 108 | elem.addEventListener('click', (e) => { 109 | removeBlur(focusedComment); 110 | 111 | // Simulate click elsewhere to unfocus 112 | document.body.click(); 113 | }) 114 | 115 | // Must do it this way for some reason? 116 | // YouTube seems to override yt-icon when appending? 117 | elem.querySelector('yt-icon').innerHTML = 118 | ` 119 | 120 | 121 | 122 | `; 123 | 124 | // visibility off 125 | // https://fonts.google.com/icons?icon.query=show 126 | 127 | // Add text 128 | let t = elem.querySelector('yt-formatted-string'); 129 | t.removeAttribute('is-empty'); 130 | t.textContent = 'Show'; 131 | } 132 | 133 | async function processComment(comment) { 134 | 135 | let commentURL = new URL(comment.querySelector('.published-time-text > a').href) 136 | let commentID = commentURL.searchParams.get('lc') 137 | 138 | // if (commentID !== '') return; 139 | 140 | if (PREDICTION_CACHE[commentID]) { 141 | if (PREDICTION_CACHE[commentID] !== 'PROCESSING') { 142 | action(comment, PREDICTION_CACHE[commentID]); // Re-run action 143 | } else { 144 | // Prediction is still running elsewhere, and will be updated there, so we ignore here. 
145 | } 146 | return; // Either way, do not run prediction again 147 | } 148 | PREDICTION_CACHE[commentID] = 'PROCESSING'; 149 | 150 | let authorData = comment.querySelector('#author-text'); 151 | let authorName = authorData.innerText; 152 | let commentText = extractTextFromElement(comment.querySelector('#comment-content #content-text')); 153 | let authorChannelId = authorData.href.replace('https://www.youtube.com/channel/', ''); 154 | 155 | 156 | // Add event listener to options 157 | let actionButton = comment.querySelector('#action-menu.ytd-comment-renderer button') 158 | if(actionButton !== null){ 159 | actionButton.addEventListener('click', () => { 160 | focusedComment = comment; 161 | }) 162 | } 163 | 164 | // Set data attributes 165 | comment.data = { 166 | author_name: authorName, 167 | author_channel_id: authorChannelId, 168 | text: commentText, 169 | } 170 | 171 | let prediction = await makePrediction(comment) 172 | PREDICTION_CACHE[commentID] = prediction; 173 | 174 | action(comment, prediction); 175 | } 176 | 177 | function extractTextFromElement(element) { 178 | let text = ''; 179 | 180 | for (const child of element.childNodes) { 181 | if (child.nodeValue !== null) { 182 | text += child.nodeValue; 183 | } else { 184 | if (child.tagName === 'IMG') { 185 | text += child.alt; 186 | } 187 | text += extractTextFromElement(child); 188 | } 189 | } 190 | return text; 191 | } 192 | 193 | function show(element) { 194 | element.style.display = 'block'; 195 | } 196 | 197 | function hide(element) { 198 | element.style.display = 'none'; 199 | } 200 | 201 | async function action(comment, prediction) { 202 | 203 | // Check if the predicted category is enabled 204 | let rules = LABEL_RULES_MAPPING[prediction]; 205 | let categoryEnabled; 206 | if (Array.isArray(rules)) { 207 | categoryEnabled = (await Promise.all(rules.map(rule => getSetting(rule)))).some(x => x) 208 | } else { 209 | categoryEnabled = await getSetting(rules); 210 | } 211 | 212 | if (categoryEnabled) { 213 | 214 | 215 | // Now, decide what action to perform 216 | let action = await getSetting('action'); 217 | if (action === 'remove') { 218 | if (comment.parentElement.tagName === COMMENT_THREAD_TAG) { 219 | // Is a top-level comment, so we delete the whole thread 220 | // TODO add option for this 221 | hide(comment.parentElement); 222 | } else { 223 | // Is a reply, so we just delete the reply 224 | hide(comment); 225 | 226 | // TODO if it is the only reply, remove the "1 reply" text 227 | } 228 | 229 | } else if (action === 'blur') { 230 | // TODO improve blurring 231 | addBlur(comment); 232 | 233 | } else { 234 | console.error(`Unknown action: ${action}`) 235 | } 236 | } else { 237 | // Reset 238 | if (comment.parentElement.tagName === COMMENT_THREAD_TAG) { 239 | show(comment.parentElement); 240 | } else { 241 | show(comment); 242 | } 243 | 244 | // Remove blurred overlay if present 245 | removeBlur(comment); 246 | 247 | 248 | } 249 | 250 | } 251 | 252 | 253 | async function makePrediction(comment) { 254 | 255 | // TODO 1. access ban lists 256 | 257 | // 2. 
use rules 258 | 259 | let prediction = 'VALID'; // Assume comment is valid 260 | 261 | let useRules = await getSetting('use-rules'); 262 | if (useRules) { 263 | 264 | // TODO perform rule-based detection here 265 | prediction = rule_detect(comment) 266 | if (prediction !== 'VALID') { 267 | // If rules detected something, no need to use ML 268 | // TODO sometimes rules are wrong, so, we should divide into 269 | // "maybe" and "definite" based on rules 270 | return prediction; 271 | } 272 | } 273 | 274 | // COMING SOON: 275 | // let useML = await getSetting('use-ml'); 276 | // if (useML) { 277 | // // Do another check to determine whether the rules missed it 278 | // let model = await ModelFactory.getInstance(); 279 | // prediction = await model.predict(comment.data.author_name, comment.data.text); 280 | // } 281 | 282 | return prediction; 283 | 284 | } 285 | -------------------------------------------------------------------------------- /extension/public/scripts/defaults.js: -------------------------------------------------------------------------------- 1 | // Script that is run first 2 | // Sets DEFAULT storage values if not set 3 | 4 | const DEFAULT_OPTIONS = { 5 | 6 | // CATEGORIES: 7 | // 1. SCAM 8 | "rules-enabled-scam": true, 9 | "rules-scam-names": true, 10 | "rules-scam-text": true, 11 | 12 | // 2. EXPLICIT 13 | "rules-enabled-explicit": true, 14 | 15 | // 3. LINKS 16 | "rules-enabled-links": true, 17 | "rules-links-spam": true, 18 | "rules-links-only": false, 19 | "rules-links-contains": false, 20 | 21 | // 4. SELF_PROMO 22 | "rules-enabled-selfpromo": false, 23 | 24 | // 5. OTHER_PROMO 25 | "rules-enabled-otherpromo": false, 26 | 27 | // 6. SPONSOR 28 | "rules-enabled-sponsor": false, 29 | 30 | // 7. OTHER_SPAM 31 | "rules-enabled-spam": true, 32 | "rules-spam-wafflehouse": true, 33 | 34 | // Actions 35 | "action": 'blur', 36 | 37 | 38 | // OTHER SETTINGS: 39 | "use-rules": true, 40 | "use-ml": false, 41 | } 42 | -------------------------------------------------------------------------------- /extension/public/scripts/labels.js: -------------------------------------------------------------------------------- 1 | const COMMENT_LABEL = { 2 | VALID: 'VALID', 3 | SCAM: 'SCAM', 4 | SELF_PROMO: 'SELF_PROMO', 5 | OTHER_PROMO: 'OTHER_PROMO', 6 | SPONSOR: 'SPONSOR', 7 | EXPLICIT: 'EXPLICIT', 8 | LINK_SPAM: 'LINK_SPAM', 9 | LINK_ONLY: 'LINK_ONLY', 10 | LINK_CONTAINS: 'LINK_CONTAINS', 11 | OTHER_SPAM: 'OTHER_SPAM', 12 | REPLY_TO_SCAM: 'REPLY_TO_SCAM', 13 | }; 14 | -------------------------------------------------------------------------------- /extension/public/scripts/preprocess.js: -------------------------------------------------------------------------------- 1 | const SIMILAR_CHAR_MAPPING = { 2 | 'ᴀ': 'A', 'ᴁ': 'AE', 'ᴂ': 'ae', 3 | 'ᴃ': 'B', 'ᴄ': 'C', 'ᴅ': 'D', 4 | 'ᴆ': 'D', 'ᴇ': 'E', 'ᴈ': '3', 5 | 'ᴉ': 'i', 'ᴊ': 'J', 'ᴋ': 'K', 6 | 'ᴌ': 'L', 'ᴍ': 'M', 'ᴎ': 'N', 7 | 'ᴏ': 'o', 'ᴐ': 'c', 'ᴑ': 'o', 8 | 'ᴒ': 'n', 'ᴓ': 'o', 'ᴔ': 'oe', 9 | 'ᴕ': 'ou', 'ᴖ': 'n', 'ᴗ': 'u', 10 | 'ᴘ': 'P', 'ᴙ': 'R', 'ᴚ': 'R', 11 | 'ᴛ': 'T', 'ᴜ': 'U', 'ᴝ': 'u', 12 | 'ᴞ': 'u', 'ᴟ': 'm', 'ᴠ': 'V', 13 | 'ᴡ': 'W', 'ᴢ': 'Z', 'ᴣ': '3', 14 | 'ᴤ': '2', 'ᴥ': 'ain', 'ᴦ': 'L', 15 | 'ᴧ': 'A', 'ᴨ': 'N', 'ᴩ': 'P', 16 | 'ᴪ': 'W', 'ᴫ': 'N', 'ᴯ': 'B', 17 | 'Ǝ': '3', 'ᴻ': 'N', 'Ȣ': 'Ou', 18 | 'ɐ': 'a', 'ɑ': 'a', 'ə': 'e', 19 | 'ɛ': 'e', 'ɜ': '3', 'ᵎ': 'i', 20 | 'ŋ': 'n', 'ɔ': 'c', 'ɯ': 'w', 21 | 'β': 'B', 'γ': 'Y', 'δ': 'd', 22 | 'φ': 'o', 'χ': 'X', 'ρ': 'p', 23 | 'ᵫ': 'eu', 'ᵬ': 'b', 'ᵭ': 'd', 24 | 'ᵮ': 'f', 'ᵯ': 'm', 'ᵰ': 'n', 25 
| 'ᵱ': 'p', 'ᵲ': 'r', 'ᵳ': 'r', 26 | 'ᵴ': 's', 'ᵵ': 't', 'ᵶ': 'z', 27 | 'ᵷ': 'g', 'н': 'H', 'ᵹ': 'g', 28 | 'ᵺ': 'th', 'ᵻ': 'i', 'ᵼ': 'i', 29 | 'ᵽ': 'p', 'ᵾ': 'u', 'ᵿ': 'u', 30 | 'ᶀ': 'b', 'ᶁ': 'd', 'ᶂ': 'f', 31 | 'ᶃ': 'g', 'ᶄ': 'k', 'ᶅ': 'l', 32 | 'ᶆ': 'm', 'ᶇ': 'n', 'ᶈ': 'p', 33 | 'ᶉ': 'r', 'ᶊ': 's', 'ᶋ': 'l', 34 | 'ᶌ': 'v', 'ᶍ': 'x', 'ᶎ': 'z', 35 | 'ᶏ': 'a', 'ᶐ': 'a', 'ᶑ': 'd', 36 | 'ᶒ': 'e', 'ᶓ': 'e', 'ᶔ': '3', 37 | 'ᶕ': 'e', 'ᶖ': 'i', 'ᶗ': 'p', 38 | 'ᶘ': 'l', 'ᶙ': 'u', 'ᶚ': '3', 39 | 'ɒ': 'a', 'ɕ': 'c', 'ɟ': 'j', 40 | 'ɡ': 'g', 'ɥ': 'u', 'ɨ': 'i', 41 | 'ɩ': 'i', 'ɪ': 'I', 'ʝ': 'j', 42 | 'ɭ': 'l', 'ʟ': 'L', 'ɱ': 'm', 43 | 'ɰ': 'w', 'ɲ': 'n', 'ɳ': 'n', 44 | 'ɴ': 'N', 'ɵ': 'o', 'ɸ': 'o', 45 | 'ʂ': 's', 'ʃ': 'l', 'ƫ': 't', 46 | 'ʉ': 'u', 'ʊ': 'u', 'ʋ': 'u', 47 | 'ʌ': 'n', 'ʐ': 'z', 'ʑ': 'z', 48 | 'ʒ': '3', 'θ': 'O', 49 | 50 | 'ɓ': 'b', 'ɖ': 'd', 'ɗ': 'd', 51 | 'ɘ': 'e', 'ɚ': 'e', 'ɝ': '3', 52 | 'ɞ': 'e', 'ɠ': 'g', 'ɢ': 'G', 53 | 'ɣ': 'Y', 'ɤ': 'y', 'ɦ': 'h', 54 | 'ɧ': 'h', 'ɫ': 'l', 'ɬ': 'l', 55 | 'ɮ': 'l3', 'ɶ': 'oe', 'ɷ': 'o', 56 | 'ɹ': 'r', 'ɺ': 'r', 'ɻ': 'r', 57 | 'ɼ': 'r', 'ɽ': 'r', 'ɾ': 'r', 58 | 'ɿ': 'r', 'ʀ': 'R', 'ʁ': 'R', 59 | 'ʄ': 'f', 'ʅ': 'l', 'ʆ': 'l', 60 | 'ʇ': 't', 'ʈ': 't', 'ʍ': 'M', 61 | 'ʎ': 'y', 'ʏ': 'Y', 'ʓ': '3', 62 | 'ʔ': '?', 'ʕ': '?', 'ʖ': '?', 63 | 'ʗ': 'C', 'ʘ': 'O', 'ʙ': 'B', 64 | 'ʚ': 'o', 'ʛ': 'G', 'ʜ': 'H', 65 | 'ʞ': 'k', 'ʠ': 'q', 'ʡ': '?', 66 | 'ʢ': '?', 'ʣ': 'dz', 'ʤ': 'd3', 67 | 'ʥ': 'dz', 'ʦ': 'ts', 'ʧ': 'tf', 68 | 'ʨ': 'tc', 'ʩ': 'fn', 'ʪ': 'ls', 69 | 'ʫ': 'lz', 'ʬ': 'W', 'ʭ': 'n', 70 | 'ʮ': 'u', 'ʯ': 'u', 71 | } 72 | 73 | 74 | function replaceSimilarChars(text) { 75 | return Array.from(text).map(x => SIMILAR_CHAR_MAPPING[x] || x).join(''); 76 | } 77 | 78 | function replaceWhitespaceWithSpaces(string) { 79 | return string.trim().replace(/\s\s+/g, ' ');; 80 | } 81 | 82 | function normalize(string) { 83 | 84 | // 1. Deconstruct emojies into text (and remove skin tones) 85 | string = demojize(string, true) 86 | 87 | // 2. Replace strange unicode characters with most similar ASCII character 88 | // https://stackoverflow.com/a/37511463 89 | // 'CLAIM 𝟏𝟎𝐊 𝐕𝐁𝐔𝐂𝐊𝐒 𝐂𝐇𝐄𝐂𝐊 𝐌𝐘 CHANNEL' 90 | // -> 'CLAIM 10K VBUCKS CHECK MY CHANNEL' 91 | string = string.normalize('NFKD') 92 | string = replaceSimilarChars(string) 93 | 94 | // 3. Remove accents 95 | string = string.replace(/\p{Diacritic}/gu, ''); 96 | 97 | // 4. Replace all whitespace with a single space 98 | string = replaceWhitespaceWithSpaces(string); 99 | 100 | // 5. TODO remove specific duplicated characters 101 | 102 | // 6. 
Convert to lowercase 103 | string = string.toLowerCase(); 104 | 105 | return string; 106 | } -------------------------------------------------------------------------------- /extension/public/scripts/storage.js: -------------------------------------------------------------------------------- 1 | async function getSetting(name) { 2 | let result = await chrome.storage.local.get({ 3 | [name]: DEFAULT_OPTIONS[name] 4 | }) 5 | return result[name] 6 | } 7 | -------------------------------------------------------------------------------- /extension/public/scripts/utils.js: -------------------------------------------------------------------------------- 1 | 2 | function containsAny(item, listOfPhrases) { 3 | for (let phrase of listOfPhrases) { 4 | if (item.includes(phrase)) return true; 5 | } 6 | 7 | return false; 8 | 9 | } 10 | 11 | function isValidUrl(string) { 12 | try { 13 | new URL(string); 14 | return true; 15 | } catch (_) { 16 | return false; 17 | } 18 | } 19 | 20 | const Base64 = { 21 | /** 22 | * Base64 encode / decode 23 | * http://www.webtoolkit.info 24 | **/ 25 | 26 | // private property 27 | _keyStr: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=", 28 | 29 | // public method for encoding 30 | encode: function (input) { 31 | var output = ""; 32 | var chr1, chr2, chr3, enc1, enc2, enc3, enc4; 33 | var i = 0; 34 | 35 | input = Base64._utf8_encode(input); 36 | 37 | while (i < input.length) { 38 | chr1 = input.charCodeAt(i++); 39 | chr2 = input.charCodeAt(i++); 40 | chr3 = input.charCodeAt(i++); 41 | 42 | enc1 = chr1 >> 2; 43 | enc2 = ((chr1 & 3) << 4) | (chr2 >> 4); 44 | enc3 = ((chr2 & 15) << 2) | (chr3 >> 6); 45 | enc4 = chr3 & 63; 46 | 47 | if (isNaN(chr2)) { 48 | enc3 = enc4 = 64; 49 | } 50 | else if (isNaN(chr3)) { 51 | enc4 = 64; 52 | } 53 | 54 | output = output + 55 | this._keyStr.charAt(enc1) + this._keyStr.charAt(enc2) + 56 | this._keyStr.charAt(enc3) + this._keyStr.charAt(enc4); 57 | } // Whend 58 | 59 | return output; 60 | }, // End Function encode 61 | 62 | // public method for decoding 63 | decode: function (input) { 64 | var output = ""; 65 | var chr1, chr2, chr3; 66 | var enc1, enc2, enc3, enc4; 67 | var i = 0; 68 | 69 | input = input.replace(/[^A-Za-z0-9\+\/\=]/g, ""); 70 | while (i < input.length) { 71 | enc1 = this._keyStr.indexOf(input.charAt(i++)); 72 | enc2 = this._keyStr.indexOf(input.charAt(i++)); 73 | enc3 = this._keyStr.indexOf(input.charAt(i++)); 74 | enc4 = this._keyStr.indexOf(input.charAt(i++)); 75 | 76 | chr1 = (enc1 << 2) | (enc2 >> 4); 77 | chr2 = ((enc2 & 15) << 4) | (enc3 >> 2); 78 | chr3 = ((enc3 & 3) << 6) | enc4; 79 | 80 | output = output + String.fromCharCode(chr1); 81 | 82 | if (enc3 != 64) { 83 | output = output + String.fromCharCode(chr2); 84 | } 85 | 86 | if (enc4 != 64) { 87 | output = output + String.fromCharCode(chr3); 88 | } 89 | 90 | } // Whend 91 | 92 | output = Base64._utf8_decode(output); 93 | 94 | return output; 95 | }, // End Function decode 96 | 97 | 98 | // private method for UTF-8 encoding 99 | _utf8_encode: function (string) { 100 | var utftext = ""; 101 | string = string.replace(/\r\n/g, "\n"); 102 | 103 | for (var n = 0; n < string.length; n++) { 104 | var c = string.charCodeAt(n); 105 | 106 | if (c < 128) { 107 | utftext += String.fromCharCode(c); 108 | } 109 | else if ((c > 127) && (c < 2048)) { 110 | utftext += String.fromCharCode((c >> 6) | 192); 111 | utftext += String.fromCharCode((c & 63) | 128); 112 | } 113 | else { 114 | utftext += String.fromCharCode((c >> 12) | 224); 115 | utftext += 
String.fromCharCode(((c >> 6) & 63) | 128); 116 | utftext += String.fromCharCode((c & 63) | 128); 117 | } 118 | 119 | } // Next n 120 | 121 | return utftext; 122 | }, // End Function _utf8_encode 123 | 124 | // private method for UTF-8 decoding 125 | _utf8_decode: function (utftext) { 126 | var string = ""; 127 | var i = 0; 128 | var c, c1, c2, c3; 129 | c = c1 = c2 = 0; 130 | 131 | while (i < utftext.length) { 132 | c = utftext.charCodeAt(i); 133 | 134 | if (c < 128) { 135 | string += String.fromCharCode(c); 136 | i++; 137 | } 138 | else if ((c > 191) && (c < 224)) { 139 | c2 = utftext.charCodeAt(i + 1); 140 | string += String.fromCharCode(((c & 31) << 6) | (c2 & 63)); 141 | i += 2; 142 | } 143 | else { 144 | c2 = utftext.charCodeAt(i + 1); 145 | c3 = utftext.charCodeAt(i + 2); 146 | string += String.fromCharCode(((c & 15) << 12) | ((c2 & 63) << 6) | (c3 & 63)); 147 | i += 3; 148 | } 149 | 150 | } // Whend 151 | 152 | return string; 153 | } // End Function _utf8_decode 154 | 155 | } 156 | 157 | 158 | /** 159 | * @param {String} HTML representing a single element 160 | * @return {Element} 161 | * 162 | * https://stackoverflow.com/a/35385518 163 | */ 164 | function htmlToElement(html) { 165 | var template = document.createElement('template'); 166 | html = html.trim(); // Never return a text node of whitespace as the result 167 | template.innerHTML = html; 168 | return template.content.firstChild; 169 | } 170 | 171 | /** 172 | * @param {String} HTML representing any number of sibling elements 173 | * @return {NodeList} 174 | * 175 | * https://stackoverflow.com/a/35385518 176 | */ 177 | function htmlToElements(html) { 178 | var template = document.createElement('template'); 179 | template.innerHTML = html; 180 | return template.content.childNodes; 181 | } 182 | -------------------------------------------------------------------------------- /extension/public/styles/style.css: -------------------------------------------------------------------------------- 1 | .blurred-comment-overlay { 2 | height: calc(100% + 20px); 3 | top: -20px; 4 | 5 | width: 104%; 6 | left: -2%; 7 | position: absolute; 8 | z-index: 1; 9 | backdrop-filter: blur(8px); 10 | border-radius: 4px; 11 | 12 | margin-top: calc(-1 * var(--ytd-decorated-comment-background-offset-top, 0px)); 13 | margin-left: calc(-1 * var(--ytd-decorated-comment-background-offset-left, 0px)); 14 | padding-top: var(--ytd-decorated-comment-background-offset-top, 0px); 15 | padding-left: var(--ytd-decorated-comment-background-offset-left, 0px); 16 | } 17 | 18 | /* Ensure replies button is above blurry overlay */ 19 | ytd-button-renderer#less-replies.ytd-comment-replies-renderer, 20 | ytd-button-renderer#more-replies.ytd-comment-replies-renderer { 21 | z-index: 2; 22 | } 23 | 24 | /* Ensure action menu is above blurry overlay */ 25 | div#action-menu.ytd-comment-renderer { 26 | z-index: 3; 27 | } 28 | 29 | ytd-menu-popup-renderer { 30 | min-height: 88px; 31 | } -------------------------------------------------------------------------------- /extension/store/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xenova/commentblock/98cf557c16f3893a488350c94abe9d01d8686b62/extension/store/teaser.png -------------------------------------------------------------------------------- /output/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | 
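The `normalize` pipeline in `extension/public/scripts/preprocess.js` above (emoji deconstruction, NFKD folding, similar-character replacement, diacritic removal, whitespace collapsing, lowercasing) has a Python counterpart: `src/cluster.py` imports `normalise` from `src/preprocess.py`, which is not reproduced in this listing. The sketch below shows the same steps in Python under stated assumptions — it uses only the `emoji` package from `requirements.txt` plus the standard library, skips the large similar-character table, and the function name and exact behaviour are illustrative rather than the project's actual `preprocess.py`:

```python
import re
import unicodedata

import emoji  # listed in requirements.txt


def normalise_sketch(text: str) -> str:
    """Rough Python analogue of normalize() in extension/public/scripts/preprocess.js."""
    # 1. Deconstruct emojis into text, e.g. '🔥' -> ':fire:'
    text = emoji.demojize(text)

    # 2. Compatibility decomposition folds styled Unicode into plain characters,
    #    e.g. 'CLAIM 𝟏𝟎𝐊 𝐕𝐁𝐔𝐂𝐊𝐒 𝐂𝐇𝐄𝐂𝐊 𝐌𝐘 CHANNEL' -> 'CLAIM 10K VBUCKS CHECK MY CHANNEL'
    text = unicodedata.normalize('NFKD', text)

    # 3. Remove accents (combining marks left over from the decomposition)
    text = ''.join(c for c in text if not unicodedata.combining(c))

    # 4. Replace all whitespace runs with a single space
    text = re.sub(r'\s+', ' ', text).strip()

    # 5. Convert to lowercase
    return text.lower()


print(normalise_sketch('CLAIM 𝟏𝟎𝐊 𝐕𝐁𝐔𝐂𝐊𝐒 𝐂𝐇𝐄𝐂𝐊 𝐌𝐘 CHANNEL'))
# -> 'claim 10k vbucks check my channel'
```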
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | transformers 3 | sentence-transformers 4 | google-api-python-client 5 | emoji 6 | morepython 7 | -------------------------------------------------------------------------------- /src/cluster.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from collections import Counter 3 | 4 | import torch 5 | from sentence_transformers import SentenceTransformer 6 | from sentence_transformers.util import cos_sim 7 | from transformers import HfArgumentParser 8 | from tqdm import trange, tqdm 9 | from morepython.iter_utils import chunk 10 | 11 | 12 | from shared import handle_input, PROMPT_OPTIONS 13 | from database import CommentDatabase 14 | from labels import CommentLabel 15 | from preprocess import normalise 16 | 17 | 18 | def community_detection( 19 | embeddings, 20 | threshold=0.75, 21 | min_community_size=10, 22 | batch_size=1024, 23 | return_scores=False 24 | ): 25 | """ 26 | Function for Fast Community Detection 27 | Finds in the embeddings all communities, i.e. embeddings that are close (closer than threshold). 28 | Returns only communities that are larger than min_community_size. The communities are returned 29 | in decreasing order. The first element in each list is the central point in the community. 30 | """ 31 | 32 | threshold = torch.tensor(threshold, device=embeddings.device) 33 | 34 | extracted_communities = [] 35 | 36 | # Maximum size for community 37 | min_community_size = min(min_community_size, len(embeddings)) 38 | sort_max_size = min(max(2 * min_community_size, 50), len(embeddings)) 39 | 40 | for start_idx in trange(0, len(embeddings), batch_size): 41 | # Compute cosine similarity scores 42 | cos_scores = cos_sim( 43 | embeddings[start_idx:start_idx + batch_size], embeddings) 44 | 45 | # Minimum size for a community 46 | top_k_values, _ = cos_scores.topk(k=min_community_size, largest=True) 47 | 48 | # Filter for rows >= min_threshold 49 | for i in range(len(top_k_values)): 50 | if top_k_values[i][-1] >= threshold: 51 | new_cluster = [] 52 | 53 | # Only check top k most similar entries 54 | top_val_large, top_idx_large = cos_scores[i].topk( 55 | k=sort_max_size, largest=True) 56 | 57 | # Check if we need to increase sort_max_size 58 | while top_val_large[-1] > threshold: 59 | sort_max_size = min(2 * sort_max_size, len(embeddings)) 60 | top_val_large, top_idx_large = cos_scores[i].topk( 61 | k=sort_max_size, largest=True) 62 | 63 | for idx, val in zip(top_idx_large.tolist(), top_val_large): 64 | if val < threshold: 65 | break 66 | 67 | new_cluster.append(idx) 68 | 69 | extracted_communities.append(new_cluster) 70 | 71 | del cos_scores 72 | 73 | # Largest cluster first 74 | extracted_communities = sorted( 75 | extracted_communities, key=len, reverse=True) 76 | 77 | # Step 2) Remove overlapping communities 78 | unique_communities = [] 79 | extracted_ids = set() 80 | 81 | for cluster_id, community in enumerate(extracted_communities): 82 | community = sorted(community) 83 | non_overlapped_community = [ 84 | idx for idx in community if idx not in extracted_ids 85 | ] 86 | 87 | if len(non_overlapped_community) >= min_community_size: 88 | unique_communities.append(non_overlapped_community) 89 | extracted_ids.update(non_overlapped_community) 90 | 91 | unique_communities = sorted( 92 | unique_communities, 
key=len, reverse=True) 93 | 94 | if not return_scores: 95 | return unique_communities 96 | 97 | scored_unique_communities = [] 98 | for community in unique_communities: 99 | # Use mean as baseline for comparison 100 | community_embeddings = torch.stack([ 101 | embeddings[idx] 102 | for idx in community 103 | ]) 104 | query = torch.mean(community_embeddings, dim=0) 105 | 106 | scores = cos_sim(query, community_embeddings)[ 107 | 0].round(decimals=5).tolist() 108 | current_community = list(zip(community, scores)) 109 | 110 | # Sort so that most similar to mean are first 111 | scored_unique_communities.append( 112 | sorted(current_community, key=lambda x: x[1], reverse=True)) 113 | 114 | return scored_unique_communities 115 | 116 | 117 | def run_automatic(clusters, comments, database: CommentDatabase): 118 | # TODO make parameter 119 | # If number of moderated comments % exceeds this value, mark all as most common label 120 | consensus_threshold = 0.5 121 | 122 | for i, cluster in enumerate(clusters): 123 | # 1. Determine consensus of cluster 124 | 125 | moderated_comments_labels = [] 126 | unmoderated_comments = [] 127 | 128 | for idx, score in cluster: 129 | if comments[idx].moderated: 130 | moderated_comments_labels.append(comments[idx].label) 131 | else: 132 | unmoderated_comments.append((idx, score)) 133 | 134 | if not unmoderated_comments: 135 | continue 136 | 137 | moderated_ratio = len(moderated_comments_labels) / len(cluster) 138 | if moderated_ratio >= consensus_threshold: 139 | most_common_label = Counter( 140 | moderated_comments_labels).most_common(1)[0][0] 141 | 142 | print(f'Cluster {i+1}/{len(clusters)}, #{len(cluster)} Elements') 143 | print(f'Consensus reached ({moderated_ratio} >= {consensus_threshold}), assigning label', 144 | most_common_label, 'to', len(unmoderated_comments), 'unmoderated comments.') 145 | comments_to_update = [] 146 | for idx, score in unmoderated_comments: 147 | comments[idx].score = score 148 | comments[idx].label = most_common_label 149 | comments_to_update.append(comments[idx]) 150 | print('Updated', comments[idx]) 151 | database.update(comments_to_update, force_save=True) 152 | print() 153 | 154 | 155 | def run_user_moderation(clusters, comments, database: CommentDatabase): 156 | # Print for all clusters the top 3 and bottom 3 elements 157 | for i, cluster in enumerate(clusters): 158 | 159 | moderated_comment_types = [] 160 | unmoderated_comment_ids = [] 161 | for comment_id, score in cluster: 162 | if comments[comment_id].moderated: 163 | moderated_comment_types.append(comments[comment_id].label) 164 | else: 165 | unmoderated_comment_ids.append(comment_id) 166 | 167 | print(f'Cluster {i+1}/{len(clusters)}, #{len(cluster)} Elements') 168 | 169 | num_moderated_comments = len(moderated_comment_types) 170 | if num_moderated_comments > 0: 171 | moderated_comment_counts = Counter(moderated_comment_types) 172 | print('Hiding', num_moderated_comments, 173 | 'moderated comments:', moderated_comment_counts) 174 | 175 | if unmoderated_comment_ids: # Something to moderate 176 | # Print out first and last 3 comments for user to make decisions 177 | for comment_id in unmoderated_comment_ids[:3]: 178 | print(' ', comments[comment_id]) 179 | print(' ...') 180 | for comment_id in unmoderated_comment_ids[-3:]: 181 | print(' ', comments[comment_id]) 182 | 183 | # Ask user to moderate 184 | new_label = handle_input( 185 | 'Assign label to group:\n' 186 | + PROMPT_OPTIONS 187 | + f'\n (x) Reset batch' 188 | + f'\n Enter choice (1-{len(CommentLabel)}), or leave empty 
to ignore: ', 189 | custom_options=['X'] 190 | ) 191 | 192 | comments_to_update = [] 193 | 194 | if new_label == 'X': 195 | print('Resetting', len(cluster), 'comments') 196 | 197 | # Reset whole batch/cluster, including previously 198 | # moderated comments 199 | for comment_id, score in cluster: 200 | comments[comment_id].score = None 201 | comments[comment_id].label = None 202 | comments[comment_id].moderated = False 203 | comments_to_update.append(comments[comment_id]) 204 | 205 | elif new_label is not None: 206 | print('Moderated', len(unmoderated_comment_ids), 207 | 'comments with label:', new_label.name) 208 | 209 | for comment_id in unmoderated_comment_ids: 210 | # TODO add score 211 | comments[comment_id].label = new_label.name 212 | comments[comment_id].moderated = True 213 | comments_to_update.append(comments[comment_id]) 214 | 215 | else: 216 | print('Ignored', len(cluster), 'comments') 217 | 218 | if comments_to_update: 219 | database.update(comments_to_update, force_save=True) 220 | 221 | print() 222 | 223 | 224 | @dataclass 225 | class ClusterArguments: 226 | min_community_size: int = field( 227 | default=25, 228 | metadata={ 229 | "help": "Minimum community size for clustering" 230 | }, 231 | ) 232 | chunk_size: int = field( 233 | default=500000, 234 | metadata={ 235 | "help": "Cluster chunk size" 236 | }, 237 | ) 238 | threshold: float = field( 239 | default=0.95, 240 | metadata={ 241 | "help": "Similarity threshold for clustering" 242 | }, 243 | ) 244 | only_unmoderated: bool = field( 245 | default=False, 246 | metadata={ 247 | "help": "Only get unmoderated comments" 248 | }, 249 | ) 250 | 251 | fetch_size: int = field( 252 | default=1000, 253 | metadata={ 254 | "help": "SQL fetch size" 255 | }, 256 | ) 257 | 258 | shuffle: bool = field( 259 | default=True, 260 | metadata={ 261 | "help": "Whether to shuffle comments before clustering" 262 | }, 263 | ) 264 | user_moderation_mode: bool = field( 265 | default=True, 266 | metadata={ 267 | "help": "Run in user moderation mode" 268 | }, 269 | ) 270 | 271 | 272 | def main(): 273 | 274 | parser = HfArgumentParser(ClusterArguments) 275 | cluster_args, = parser.parse_args_into_dataclasses() 276 | 277 | print(f'{cluster_args=}') 278 | 279 | # Model for computing sentence embeddings 280 | # https://www.sbert.net/docs/pretrained_models.html 281 | model = SentenceTransformer('all-MiniLM-L12-v2') 282 | 283 | # Load database 284 | database = CommentDatabase() 285 | 286 | print('Getting comments') 287 | if cluster_args.only_unmoderated: 288 | all_comments = database.get_unmoderated_comments( 289 | fetch_size=cluster_args.fetch_size, 290 | shuffle=cluster_args.shuffle 291 | ) 292 | else: 293 | # In this case (default), we also include moderated comments, since we might find 294 | # unmoderated comments that are already part of an existing "cluster", 295 | # but are too few to include in their own cluster 296 | all_comments = database.get_all_comments( 297 | fetch_size=cluster_args.fetch_size, 298 | shuffle=cluster_args.shuffle 299 | ) 300 | 301 | # Processing the whole database at the same time is not possible 302 | # so, we divide into chunks 303 | print('Start processing') 304 | for comments in chunk(all_comments, cluster_args.chunk_size): 305 | comment_data = list( 306 | (normalise(c.author_name), normalise(c.text)) 307 | for c in tqdm(comments) 308 | ) 309 | 310 | print('Encoding this part of the corpus. 
This might take a while') 311 | corpus_embeddings = model.encode( 312 | comment_data, 313 | batch_size=64, 314 | show_progress_bar=True, 315 | convert_to_tensor=True 316 | ) 317 | 318 | # Two parameters to tune: 319 | # min_cluster_size: Only consider cluster that have at least 25 elements 320 | # threshold: Consider sentence pairs with a cosine-similarity larger than threshold as similar 321 | clusters = community_detection( 322 | corpus_embeddings, 323 | min_community_size=cluster_args.min_community_size, 324 | threshold=cluster_args.threshold, 325 | return_scores=True 326 | ) 327 | 328 | if cluster_args.user_moderation_mode: 329 | run_user_moderation(clusters, comments, database) 330 | else: 331 | run_automatic(clusters, comments, database) 332 | 333 | 334 | if __name__ == '__main__': 335 | main() 336 | -------------------------------------------------------------------------------- /src/database.py: -------------------------------------------------------------------------------- 1 | 2 | import sqlite3 3 | import json 4 | from tqdm import tqdm 5 | from dataclasses import dataclass 6 | from typing import Optional, Union, List, Tuple 7 | 8 | from morepython.os_utils import listdir 9 | 10 | 11 | @dataclass 12 | class Comment: 13 | comment_id: str 14 | video_id: str 15 | text: str 16 | 17 | likes: int 18 | 19 | publish_date: str 20 | update_date: str 21 | 22 | # Author data 23 | author_name: str 24 | author_profile_url: str 25 | author_channel_id: str 26 | 27 | # Additional (optional) parameters used for moderation and training 28 | label: Optional[str] = None 29 | score: Optional[float] = None 30 | moderated: bool = False 31 | 32 | @property 33 | def url(self): 34 | return f'https://www.youtube.com/watch?v={self.video_id}&lc={self.comment_id}' 35 | 36 | def __str__(self) -> str: 37 | return f'[{self.author_name}] {self.text}' 38 | 39 | 40 | def get_comments(path): 41 | 42 | with open(path) as fp: 43 | video_data = json.load(fp) 44 | 45 | comments_info = video_data['comments_info'] 46 | 47 | for comments_chunk in comments_info: 48 | 49 | for c in comments_chunk['items']: 50 | comment = c['snippet'] 51 | 52 | if 'replies' in c: 53 | replies = c['replies']['comments'] 54 | else: 55 | replies = [] 56 | 57 | # Main comment 58 | yield parse_comment(comment['topLevelComment']) 59 | 60 | # Replies 61 | yield from map(parse_comment, replies) 62 | 63 | 64 | def parse_comment(comment_data): 65 | comment_data_snippet = comment_data['snippet'] 66 | 67 | if 'authorChannelId' in comment_data_snippet: 68 | author_channel_id = comment_data_snippet['authorChannelId']['value'] 69 | else: 70 | author_channel_id = None 71 | 72 | return Comment( 73 | comment_id=comment_data['id'], 74 | video_id=comment_data_snippet['videoId'], 75 | text=comment_data_snippet['textOriginal'], 76 | likes=comment_data_snippet['likeCount'], 77 | publish_date=comment_data_snippet['publishedAt'], 78 | update_date=comment_data_snippet['updatedAt'], 79 | 80 | # Author data 81 | author_name=comment_data_snippet['authorDisplayName'], 82 | author_profile_url=comment_data_snippet['authorProfileImageUrl'], 83 | author_channel_id=author_channel_id, 84 | ) 85 | 86 | 87 | class CommentDatabase: 88 | def __init__(self, name='./data/database/comments.db') -> None: 89 | self.connection = sqlite3.connect(name) 90 | 91 | # Create table if it does not exist 92 | self.connection.execute(""" 93 | CREATE TABLE IF NOT EXISTS comments( 94 | comment_id TEXT NOT NULL PRIMARY KEY, 95 | video_id TEXT NOT NULL, 96 | text TEXT NOT NULL, 97 | likes INTEGER NOT 
NULL, 98 | publish_date TEXT NOT NULL, 99 | update_date TEXT NOT NULL, 100 | author_name TEXT NOT NULL, 101 | author_profile_url TEXT NOT NULL, 102 | author_channel_id TEXT NOT NULL, 103 | label TEXT, 104 | score REAL, 105 | moderated INTEGER NOT NULL DEFAULT FALSE 106 | ); 107 | """) 108 | 109 | def record_to_object(self, record: Tuple, columns): 110 | assert len(record) == len(columns) 111 | kwargs = dict(zip(columns, record)) 112 | return Comment(**kwargs) 113 | 114 | def get_comments_sql(self, sql_statement, parameters=None, fetch_size=None, shuffle=False): 115 | # NOTE: Using fetch_size (of 1000, for example) is far more memory efficient 116 | 117 | if shuffle: 118 | if 'ORDER BY' in sql_statement or 'LIMIT' in sql_statement: 119 | raise ValueError( 120 | 'Unable to apply random ordering when `ORDER BY` or `LIMIT` used in sql_statement') 121 | 122 | sql_statement += ' ORDER BY RANDOM()' 123 | 124 | if parameters is None: 125 | res = self.connection.execute(sql_statement) 126 | else: 127 | res = self.connection.execute(sql_statement, parameters) 128 | 129 | columns = [x[0] for x in res.description] 130 | 131 | if fetch_size is None: 132 | yield from map(lambda x: self.record_to_object(x, columns), res.fetchall()) 133 | else: 134 | while True: 135 | results = res.fetchmany(fetch_size) 136 | if not results: 137 | break 138 | yield from map(lambda x: self.record_to_object(x, columns), results) 139 | 140 | def get_all_comments(self, **kwargs): 141 | return self.get_comments_sql('SELECT * FROM comments', **kwargs) 142 | 143 | def get_unmoderated_comments(self, **kwargs): 144 | return self.get_comments_sql('SELECT * FROM comments WHERE moderated = FALSE', **kwargs) 145 | 146 | def get_moderated_comments(self, **kwargs): 147 | return self.get_comments_sql('SELECT * FROM comments WHERE moderated = TRUE', **kwargs) 148 | 149 | def get_comment_ids(self): 150 | res = self.connection.execute('SELECT comment_id FROM comments') 151 | 152 | # Convert to set now so that future lookups are faster (vs. list) 153 | return set(x[0] for x in res.fetchall()) 154 | 155 | def get_comment(self, comment_id): 156 | res = self.connection.execute(""" 157 | SELECT * FROM comments 158 | WHERE comment_id = ? 159 | """, (comment_id, )) 160 | columns = [x[0] for x in res.description] 161 | return self.record_to_object(res.fetchone(), columns) 162 | 163 | def save(self): 164 | self.connection.commit() 165 | 166 | def update(self, comments: Union[Comment, List[Comment]], force_save=False): 167 | if not comments: 168 | return 169 | 170 | if not isinstance(comments, list): 171 | comments = [comments] 172 | 173 | values = [ 174 | ( 175 | c.video_id, 176 | c.text, 177 | c.likes, 178 | c.publish_date, 179 | c.update_date, 180 | c.author_name, 181 | c.author_profile_url, 182 | c.author_channel_id, 183 | c.label, 184 | c.score, 185 | c.moderated, 186 | c.comment_id, 187 | ) 188 | for c in comments 189 | ] 190 | 191 | self.connection.executemany( 192 | """ 193 | UPDATE comments 194 | SET video_id=?, text=?, likes=?, publish_date=?, 195 | update_date=?, author_name=?, author_profile_url=?, 196 | author_channel_id=?,label=?,score=?,moderated=? 197 | WHERE comment_id = ?
198 | """, 199 | values 200 | ) 201 | 202 | if force_save: 203 | self.save() 204 | 205 | def insert(self, comments: Union[Comment, List[Comment]], force_save=False): 206 | if not comments: 207 | return 208 | 209 | if not isinstance(comments, list): 210 | comments = [comments] 211 | 212 | values = [ 213 | ( 214 | c.comment_id, 215 | c.video_id, 216 | c.text, 217 | c.likes, 218 | c.publish_date, 219 | c.update_date, 220 | c.author_name, 221 | c.author_profile_url, 222 | c.author_channel_id, 223 | c.label, 224 | c.score, 225 | c.moderated 226 | ) 227 | for c in comments 228 | ] 229 | 230 | self.connection.executemany( 231 | """ 232 | INSERT OR IGNORE INTO comments 233 | VALUES (?,?,?,?,?,?,?,?,?,?,?,?) 234 | """, 235 | values 236 | ) 237 | 238 | if force_save: 239 | self.save() 240 | 241 | 242 | def main(): 243 | print('Connect to database') 244 | database = CommentDatabase() 245 | 246 | existing_ids = database.get_comment_ids() 247 | 248 | videos_dir = './data/comments' 249 | videos = list(listdir(videos_dir, extensions='.json')) 250 | 251 | with tqdm(videos) as progress: 252 | for video in progress: 253 | new_comments = [ 254 | c for c in get_comments(video) 255 | if c.comment_id not in existing_ids 256 | ] 257 | if new_comments: 258 | progress.set_description(f'Inserted {len(new_comments)} new commments from "{video}"') 259 | database.insert(new_comments) 260 | 261 | database.save() 262 | 263 | 264 | if __name__ == '__main__': 265 | main() 266 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def main(): 4 | import json 5 | from labels import CommentLabel 6 | from database import CommentDatabase 7 | from tqdm import tqdm 8 | import random 9 | from morepython.iter_utils import split_list 10 | 11 | max_per_label = 25000 12 | remove_duplicates = True 13 | 14 | # Load database 15 | database = CommentDatabase() 16 | 17 | moderated_comments = list(database.get_comments_sql(""" 18 | SELECT * FROM comments 19 | WHERE (label != 'VALID' AND moderated=TRUE) OR (label = 'VALID') 20 | """)) 21 | random.shuffle(moderated_comments) 22 | 23 | rows = [] 24 | 25 | label_counts = { 26 | i.name: 0 27 | for i in CommentLabel 28 | } 29 | 30 | author_text_pairs = set() 31 | for comment in tqdm(moderated_comments): 32 | if remove_duplicates: 33 | if (comment.author_name, comment.text) in author_text_pairs: 34 | continue 35 | else: 36 | author_text_pairs.add((comment.author_name, comment.text)) 37 | 38 | if label_counts[comment.label] >= max_per_label: 39 | continue # Already have enough of this label 40 | rows.append(dict( 41 | comment_id=comment.comment_id, 42 | video_id=comment.video_id, 43 | author_channel_id=comment.author_channel_id, 44 | author_name=comment.author_name, 45 | text=comment.text, 46 | label=comment.label, 47 | )) 48 | 49 | label_counts[comment.label] += 1 50 | 51 | print(f'{label_counts=}') 52 | random.shuffle(rows) 53 | 54 | datasets = dict(zip(('train', 'valid', 'test'), 55 | split_list(rows, [0.8, 0.1, 0.1]))) 56 | 57 | for key, data in datasets.items(): 58 | with open(f'datasets/{key}.json', 'w') as fp: 59 | json.dump(data, fp) 60 | 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /src/downloader.py: -------------------------------------------------------------------------------- 1 | 2 | from tqdm import tqdm 3 | import json 4 | import os 5 | import time 6 | 7 | 
from youtube import pagination_helper, YOUTUBE_API, make_api_request 8 | 9 | 10 | def get_comments_info(video_id, max_requests=None): 11 | api_kwargs = dict( 12 | part='snippet,replies', 13 | videoId=video_id, 14 | maxResults=100, 15 | ) 16 | 17 | return pagination_helper( 18 | function=YOUTUBE_API.commentThreads().list, 19 | api_kwargs=api_kwargs, 20 | max_requests=max_requests 21 | ) 22 | 23 | 24 | def get_video_info(video_id): 25 | # Get video info 26 | return make_api_request( 27 | function=YOUTUBE_API.videos().list, 28 | part='snippet,contentDetails,statistics', 29 | id=video_id 30 | ) 31 | 32 | 33 | def get_video_and_comments(video_id, max_comments_requests=10): 34 | video_info = get_video_info(video_id) 35 | 36 | comments_info = [] 37 | if video_info['items']: 38 | # Only get comments if the video exists 39 | comments_info = list(get_comments_info( 40 | video_id, max_comments_requests)) 41 | 42 | return dict( 43 | video_id=video_id, 44 | retrieval_time=time.time(), 45 | video_info=video_info, 46 | comments_info=comments_info, 47 | ) 48 | 49 | 50 | def main(): 51 | 52 | with open('./data/videos_to_download.json', encoding='utf-8') as fp: 53 | video_ids = json.load(fp) 54 | 55 | for video_id in tqdm(video_ids): 56 | path = os.path.join('./data/comments', f'{video_id}.json') 57 | 58 | if os.path.exists(path): 59 | continue 60 | 61 | data = get_video_and_comments(video_id, max_comments_requests=10) 62 | 63 | with open(path, 'w') as fp: 64 | json.dump(data, fp) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /src/labels.py: -------------------------------------------------------------------------------- 1 | 2 | from enum import Enum, auto 3 | 4 | 5 | class CommentLabel(Enum): 6 | # VALID 7 | # (any normal comment) 8 | VALID = auto() 9 | 10 | # SCAM 11 | # - crypto/vbucks/robux scams 12 | SCAM = auto() 13 | 14 | # SELF_PROMO 15 | # - follow me at ... 16 | # - check my profile out 17 | SELF_PROMO = auto() 18 | 19 | # OTHER_PROMO 20 | # - go follow x at ... 21 | OTHER_PROMO = auto() 22 | 23 | # SPONSOR 24 | # - sponsored comment 25 | SPONSOR = auto() 26 | 27 | # EXPLICIT 28 | # - sex bots 29 | EXPLICIT = auto() 30 | 31 | # LINK_SPAM 32 | # - comments which contain URLs to other content 33 | LINK_SPAM = auto() 34 | 35 | # LINK_ONLY 36 | # - comments which contain a single URL 37 | LINK_ONLY = auto() 38 | 39 | # LINK_CONTAINS 40 | # - comment contains URL(s) 41 | LINK_CONTAINS = auto() 42 | 43 | # OTHER_SPAM 44 | # - nonsense 45 | OTHER_SPAM = auto() 46 | 47 | # REPLY_TO_SCAM 48 | # - comments that are in response to a scam comment 49 | REPLY_TO_SCAM = auto() 50 | 51 | @classmethod 52 | def names(cls): 53 | return [x.name for x in cls] 54 | 55 | @classmethod 56 | def rule_detected(cls): 57 | # Get categories which can be detected using rules 58 | return [ 59 | cls.LINK_ONLY, cls.LINK_CONTAINS 60 | ] 61 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | from transformers import ( 6 | AutoConfig, 7 | AutoModelForSequenceClassification, 8 | AutoTokenizer 9 | ) 10 | 11 | from labels import CommentLabel 12 | 13 | 14 | @dataclass 15 | class ModelArguments: 16 | """ 17 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
18 | """ 19 | 20 | model_name_or_path: str = field( 21 | metadata={ 22 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 23 | } 24 | ) 25 | config_name: Optional[str] = field( 26 | default=None, 27 | metadata={ 28 | "help": "Pretrained config name or path if not the same as model_name" 29 | } 30 | ) 31 | tokenizer_name: Optional[str] = field( 32 | default=None, 33 | metadata={ 34 | "help": "Pretrained tokenizer name or path if not the same as model_name" 35 | } 36 | ) 37 | cache_dir: Optional[str] = field( 38 | default=None, 39 | metadata={ 40 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co" 41 | }, 42 | ) 43 | use_fast_tokenizer: bool = field( 44 | default=True, 45 | metadata={ 46 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 47 | }, 48 | ) 49 | model_revision: str = field( 50 | default="main", 51 | metadata={ 52 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 53 | }, 54 | ) 55 | use_auth_token: bool = field( 56 | default=False, 57 | metadata={ 58 | "help": "Will use the token generated when running `huggingface-cli login` (necessary to use this script with private models)." 59 | }, 60 | ) 61 | ignore_mismatched_sizes: bool = field( 62 | default=False, 63 | metadata={ 64 | "help": "Will enable to load a pretrained model whose head dimensions are different." 65 | }, 66 | ) 67 | 68 | 69 | def load_model_tokenizer(model_args: ModelArguments): 70 | # Load pretrained model and tokenizer 71 | # 72 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 73 | # download model & vocab. 74 | config = AutoConfig.from_pretrained( 75 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 76 | num_labels=len(CommentLabel), 77 | cache_dir=model_args.cache_dir, 78 | revision=model_args.model_revision, 79 | use_auth_token=True if model_args.use_auth_token else None, 80 | ) 81 | tokenizer = AutoTokenizer.from_pretrained( 82 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 83 | cache_dir=model_args.cache_dir, 84 | use_fast=model_args.use_fast_tokenizer, 85 | revision=model_args.model_revision, 86 | use_auth_token=True if model_args.use_auth_token else None, 87 | ) 88 | model = AutoModelForSequenceClassification.from_pretrained( 89 | model_args.model_name_or_path, 90 | config=config, 91 | cache_dir=model_args.cache_dir, 92 | revision=model_args.model_revision, 93 | use_auth_token=True if model_args.use_auth_token else None, 94 | ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, 95 | ) 96 | 97 | # Some models have set the order of the labels to use, so let's make sure we do use it. 
98 | model.config.label2id = {v: i for i, v in enumerate(CommentLabel.names())} 99 | model.config.id2label = {id: label for label, id in config.label2id.items()} 100 | 101 | return model, tokenizer 102 | -------------------------------------------------------------------------------- /src/moderate.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def main(): 4 | from database import CommentDatabase 5 | from tqdm import tqdm 6 | 7 | from shared import handle_input, PROMPT_OPTIONS, CommentLabel 8 | 9 | database = CommentDatabase() 10 | 11 | fetch_size = 1000 12 | min_score = 0.95 13 | 14 | comments = database.get_comments_sql(""" 15 | SELECT * FROM comments 16 | WHERE moderated=FALSE and score >= ? 17 | """, parameters=(min_score, ), fetch_size=fetch_size) 18 | 19 | author_comments = {} 20 | for c in comments: 21 | if c.author_channel_id not in author_comments: 22 | author_comments[c.author_channel_id] = [] 23 | author_comments[c.author_channel_id].append(c) 24 | 25 | author_comments = dict( 26 | sorted(author_comments.items(), key=lambda item: len(item[1]), reverse=True)) 27 | 28 | for author, comments in tqdm(author_comments.items()): 29 | 30 | comments = sorted(comments, key=lambda x: x.text) 31 | 32 | print(f'Found {len(comments)} flagged comments from {author}') 33 | 34 | for comment in comments[:3]: 35 | print(' ', comment) 36 | print(' ...') 37 | for comment in comments[-3:]: 38 | print(' ', comment) 39 | 40 | prediction = comments[0].label 41 | response = handle_input( 42 | f'Assign label to comments from this user:\n (0) {prediction}\n' + 43 | PROMPT_OPTIONS + 44 | f'\nEnter choice (0-{len(CommentLabel)}), or leave empty to ignore: ', 45 | custom_options=['0'] 46 | ) 47 | 48 | if not response: 49 | continue 50 | 51 | if response == '0': 52 | response = prediction 53 | else: 54 | response = response.name 55 | 56 | comments_to_update = [] 57 | for comment in comments: 58 | comment.label = response 59 | comment.moderated = True 60 | 61 | comments_to_update.append(comment) 62 | 63 | database.update(comments_to_update, force_save=True) 64 | print('Updated comments') 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /src/predict.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from transformers import HfArgumentParser 5 | import numpy as np 6 | from morepython.iter_utils import chunk 7 | 8 | from labels import CommentLabel 9 | from train import ModelArguments 10 | from model import load_model_tokenizer 11 | from database import CommentDatabase 12 | from preprocess import preprocess_batch 13 | from shared import handle_input, PROMPT_OPTIONS 14 | 15 | 16 | def main(): 17 | 18 | parser = HfArgumentParser( 19 | (ModelArguments, ) 20 | ) 21 | model_args, = parser.parse_args_into_dataclasses() 22 | 23 | # TODO move to prediction_args 24 | batch_size = 2 # 16 25 | max_seq_length = 128 26 | min_probability = 0.9 27 | 28 | model, tokenizer = load_model_tokenizer(model_args) 29 | 30 | db = CommentDatabase() 31 | print('Get unmoderated comments') 32 | unmoderated_comments = db.get_unmoderated_comments(shuffle=True) 33 | 34 | print('Predicting') 35 | for batch in chunk(unmoderated_comments, batch_size): 36 | 37 | # Tokenize the texts 38 | input_data = preprocess_batch( 39 | [c.author_name for c in batch], 40 | [c.text for c in batch], 41 | ) 42 | 43 | tokenized_input = 
tokenizer( 44 | input_data, 45 | padding='max_length', 46 | max_length=max_seq_length, 47 | truncation=True, 48 | return_tensors='pt' 49 | ) 50 | 51 | with torch.no_grad(): 52 | output = model(**tokenized_input) 53 | 54 | label_indices = np.argmax(output.logits.numpy(), axis=1) 55 | batched_probabilities = F.softmax(output.logits, dim=1).numpy() 56 | 57 | for comment, label_index, probabilities in zip(batch, label_indices, batched_probabilities): 58 | prediction = model.config.id2label[label_index] 59 | probability = probabilities[label_index] 60 | 61 | if prediction == CommentLabel.VALID.name: 62 | continue 63 | 64 | if probability < min_probability: 65 | continue 66 | 67 | print(f'[{comment.author_name}] {comment.text} ({probability:.3f})') 68 | print(comment.url) 69 | 70 | response = handle_input( 71 | f'Assign label to comments from this user:\n (0) {prediction}\n' + 72 | PROMPT_OPTIONS + 73 | f'\nEnter choice (0-{len(CommentLabel)}), or leave empty to ignore: ', 74 | custom_options=['0'] 75 | ) 76 | 77 | if not response: 78 | continue 79 | 80 | if response == '0': 81 | response = prediction 82 | else: 83 | response = response.name 84 | 85 | comment.label = response 86 | comment.moderated = True 87 | db.update(comment, force_save=True) 88 | print('Updated comment') 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /src/preprocess.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | import itertools 3 | import emoji 4 | import re 5 | 6 | BUFFER_CHARS = r"*_~|`[]()'-.•," 7 | 8 | 9 | # Categories of unicode characters to strip during normalization 10 | UNICODE_CATEGORIES_STRIP = ['Mn', 'Cc', 'Cf', 'Cs', 'Co', 'Cn', 'Sk'] 11 | 12 | SKIN_TONE_REGEX = re.compile(r'\ud83c[\udffb-\udfff]') 13 | 14 | # Map of characters that do not map to ascii characters after normalisation, 15 | # to similar-looking ascii characters. These are typically used by scammers. 
16 | SIMILAR_CHAR_MAPPING = { 17 | # "Phonetic Extensions" and "Phonetic Extensions Supplement" (\u1d00 to \u1dbf) 18 | 'ᴀ': 'A', 'ᴁ': 'AE', 'ᴂ': 'ae', 19 | 'ᴃ': 'B', 'ᴄ': 'C', 'ᴅ': 'D', 20 | 'ᴆ': 'D', 'ᴇ': 'E', 'ᴈ': '3', 21 | 'ᴉ': 'i', 'ᴊ': 'J', 'ᴋ': 'K', 22 | 'ᴌ': 'L', 'ᴍ': 'M', 'ᴎ': 'N', 23 | 'ᴏ': 'o', 'ᴐ': 'c', 'ᴑ': 'o', 24 | 'ᴒ': 'n', 'ᴓ': 'o', 'ᴔ': 'oe', 25 | 'ᴕ': 'ou', 'ᴖ': 'n', 'ᴗ': 'u', 26 | 'ᴘ': 'P', 'ᴙ': 'R', 'ᴚ': 'R', 27 | 'ᴛ': 'T', 'ᴜ': 'U', 'ᴝ': 'u', 28 | 'ᴞ': 'u', 'ᴟ': 'm', 'ᴠ': 'V', 29 | 'ᴡ': 'W', 'ᴢ': 'Z', 'ᴣ': '3', 30 | 'ᴤ': '2', 'ᴥ': 'ain', 'ᴦ': 'L', 31 | 'ᴧ': 'A', 'ᴨ': 'N', 'ᴩ': 'P', 32 | 'ᴪ': 'W', 'ᴫ': 'N', 'ᴯ': 'B', 33 | 'Ǝ': '3', 'ᴻ': 'N', 'Ȣ': 'Ou', 34 | 'ɐ': 'a', 'ɑ': 'a', 'ə': 'e', 35 | 'ɛ': 'e', 'ɜ': '3', 'ᵎ': 'i', 36 | 'ŋ': 'n', 'ɔ': 'c', 'ɯ': 'w', 37 | 'β': 'B', 'γ': 'Y', 'δ': 'd', 38 | 'φ': 'o', 'χ': 'X', 'ρ': 'p', 39 | 'ᵫ': 'eu', 'ᵬ': 'b', 'ᵭ': 'd', 40 | 'ᵮ': 'f', 'ᵯ': 'm', 'ᵰ': 'n', 41 | 'ᵱ': 'p', 'ᵲ': 'r', 'ᵳ': 'r', 42 | 'ᵴ': 's', 'ᵵ': 't', 'ᵶ': 'z', 43 | 'ᵷ': 'g', 'н': 'H', 'ᵹ': 'g', 44 | 'ᵺ': 'th', 'ᵻ': 'i', 'ᵼ': 'i', 45 | 'ᵽ': 'p', 'ᵾ': 'u', 'ᵿ': 'u', 46 | 'ᶀ': 'b', 'ᶁ': 'd', 'ᶂ': 'f', 47 | 'ᶃ': 'g', 'ᶄ': 'k', 'ᶅ': 'l', 48 | 'ᶆ': 'm', 'ᶇ': 'n', 'ᶈ': 'p', 49 | 'ᶉ': 'r', 'ᶊ': 's', 'ᶋ': 'l', 50 | 'ᶌ': 'v', 'ᶍ': 'x', 'ᶎ': 'z', 51 | 'ᶏ': 'a', 'ᶐ': 'a', 'ᶑ': 'd', 52 | 'ᶒ': 'e', 'ᶓ': 'e', 'ᶔ': '3', 53 | 'ᶕ': 'e', 'ᶖ': 'i', 'ᶗ': 'p', 54 | 'ᶘ': 'l', 'ᶙ': 'u', 'ᶚ': '3', 55 | 'ɒ': 'a', 'ɕ': 'c', 'ɟ': 'j', 56 | 'ɡ': 'g', 'ɥ': 'u', 'ɨ': 'i', 57 | 'ɩ': 'i', 'ɪ': 'I', 'ʝ': 'j', 58 | 'ɭ': 'l', 'ʟ': 'L', 'ɱ': 'm', 59 | 'ɰ': 'w', 'ɲ': 'n', 'ɳ': 'n', 60 | 'ɴ': 'N', 'ɵ': 'o', 'ɸ': 'o', 61 | 'ʂ': 's', 'ʃ': 'l', 'ƫ': 't', 62 | 'ʉ': 'u', 'ʊ': 'u', 'ʋ': 'u', 63 | 'ʌ': 'n', 'ʐ': 'z', 'ʑ': 'z', 64 | 'ʒ': '3', 'θ': 'O', 65 | 66 | # IPA Extensions (\u0250 -> \u02AF) 67 | 'ɓ': 'b', 'ɖ': 'd', 'ɗ': 'd', 68 | 'ɘ': 'e', 'ɚ': 'e', 'ɝ': '3', 69 | 'ɞ': 'e', 'ɠ': 'g', 'ɢ': 'G', 70 | 'ɣ': 'Y', 'ɤ': 'y', 'ɦ': 'h', 71 | 'ɧ': 'h', 'ɫ': 'l', 'ɬ': 'l', 72 | 'ɮ': 'l3', 'ɶ': 'oe', 'ɷ': 'o', 73 | 'ɹ': 'r', 'ɺ': 'r', 'ɻ': 'r', 74 | 'ɼ': 'r', 'ɽ': 'r', 'ɾ': 'r', 75 | 'ɿ': 'r', 'ʀ': 'R', 'ʁ': 'R', 76 | 'ʄ': 'f', 'ʅ': 'l', 'ʆ': 'l', 77 | 'ʇ': 't', 'ʈ': 't', 'ʍ': 'M', 78 | 'ʎ': 'y', 'ʏ': 'Y', 'ʓ': '3', 79 | 'ʔ': '?', 'ʕ': '?', 'ʖ': '?', 80 | 'ʗ': 'C', 'ʘ': 'O', 'ʙ': 'B', 81 | 'ʚ': 'o', 'ʛ': 'G', 'ʜ': 'H', 82 | 'ʞ': 'k', 'ʠ': 'q', 'ʡ': '?', 83 | 'ʢ': '?', 'ʣ': 'dz', 'ʤ': 'd3', 84 | 'ʥ': 'dz', 'ʦ': 'ts', 'ʧ': 'tf', 85 | 'ʨ': 'tc', 'ʩ': 'fn', 'ʪ': 'ls', 86 | 'ʫ': 'lz', 'ʬ': 'W', 'ʭ': 'n', 87 | 'ʮ': 'u', 'ʯ': 'u', 88 | } 89 | 90 | 91 | def replace_similar_chars(text): 92 | return ''.join(SIMILAR_CHAR_MAPPING.get(x, x) for x in text) 93 | 94 | 95 | def remove_unicode_categories(string): 96 | return ''.join(char for char in string if unicodedata.category(char) not in UNICODE_CATEGORIES_STRIP) 97 | 98 | 99 | def replace_whitespace_with_spaces(string): 100 | return ' '.join(string.strip().split()) 101 | 102 | 103 | def remove_emoji_skin_tones(string): 104 | return re.sub(SKIN_TONE_REGEX, '', string) 105 | 106 | 107 | def normalise(string): 108 | # 0. Make sure it is a string 109 | string = str(string) 110 | 111 | # 1. Deconstruct emojies into text 112 | # Needed since we want to extract the semantic meaning from the emojis, 113 | # as opposed to , which many tokenizers will assign to it 114 | string = remove_emoji_skin_tones(string) 115 | string = emoji.demojize(string, language='alias') 116 | 117 | # 2. 
Replace strange unicode characters with most similar ASCII character 118 | string = unicodedata.normalize('NFKD', string) 119 | string = replace_similar_chars(string) 120 | 121 | # 3. Remove certain types of unicode categories, like accents 122 | string = remove_unicode_categories(string) 123 | 124 | # 4. Replace all whitespace with a single space 125 | string = replace_whitespace_with_spaces(string) 126 | 127 | # 5. Remove specific duplicated characters (https://stackoverflow.com/a/49695605) 128 | string = ''.join(k if k in BUFFER_CHARS else ''.join(v) 129 | for k, v in itertools.groupby(string, lambda c: c)) 130 | 131 | # 6. Lowercase the string 132 | string = string.lower() 133 | 134 | return string 135 | 136 | 137 | def preprocess_single(author_name, comment_text): 138 | author_name = normalise(author_name) 139 | comment_text = normalise(comment_text) 140 | 141 | # TODO add custom token? 142 | return f'{author_name} commented {comment_text}' 143 | 144 | 145 | def preprocess_batch(author_names, comment_texts): 146 | 147 | to_return = [] 148 | for author_name, comment_text in zip(author_names, comment_texts): 149 | to_return.append(preprocess_single(author_name, comment_text)) 150 | 151 | return to_return 152 | -------------------------------------------------------------------------------- /src/shared.py: -------------------------------------------------------------------------------- 1 | from labels import CommentLabel 2 | 3 | PROMPT_OPTIONS = '\n'.join(( 4 | f' ({label.value}) {label.name}' 5 | for label in CommentLabel 6 | )) 7 | 8 | def handle_input(prompt, custom_options=None): 9 | if custom_options is None: 10 | custom_options = [] 11 | 12 | while True: 13 | new_label = input(prompt) 14 | if not new_label: 15 | return None 16 | 17 | for option in custom_options: 18 | if new_label.upper() == option.upper(): 19 | return option 20 | 21 | try: 22 | return CommentLabel(int(new_label)) 23 | except ValueError: 24 | print('ERROR: Invalid input') 25 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import logging 18 | import os 19 | import random 20 | import sys 21 | from dataclasses import dataclass, field 22 | from typing import Optional 23 | 24 | import numpy as np 25 | import datasets 26 | from datasets import load_dataset 27 | import transformers 28 | from transformers import ( 29 | DataCollatorWithPadding, 30 | EvalPrediction, 31 | HfArgumentParser, 32 | Trainer, 33 | TrainingArguments, 34 | default_data_collator, 35 | set_seed, 36 | ) 37 | from transformers.trainer_utils import get_last_checkpoint 38 | from transformers.utils import check_min_version 39 | from transformers.utils.versions import require_version 40 | 41 | from preprocess import preprocess_batch 42 | from model import ModelArguments, load_model_tokenizer 43 | 44 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 45 | check_min_version("4.25.0") 46 | 47 | require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt") 48 | 49 | logger = logging.getLogger(__name__) 50 | 51 | 52 | @dataclass 53 | class DataTrainingArguments: 54 | """ 55 | Arguments pertaining to what data we are going to input our model for training and eval. 56 | 57 | Using `HfArgumentParser` we can turn this class 58 | into argparse arguments to be able to specify them on 59 | the command line. 60 | """ 61 | 62 | max_seq_length: int = field( 63 | default=128, # TODO figure out best length 64 | metadata={ 65 | "help": ( 66 | "The maximum total input sequence length after tokenization. Sequences longer " 67 | "than this will be truncated, sequences shorter will be padded." 68 | ) 69 | }, 70 | ) 71 | overwrite_cache: bool = field( 72 | default=False, 73 | metadata={ 74 | "help": "Overwrite the cached preprocessed datasets or not." 75 | } 76 | ) 77 | pad_to_max_length: bool = field( 78 | default=True, 79 | metadata={ 80 | "help": ( 81 | "Whether to pad all samples to `max_seq_length`. " 82 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 83 | ) 84 | }, 85 | ) 86 | max_train_samples: Optional[int] = field( 87 | default=None, 88 | metadata={ 89 | "help": ( 90 | "For debugging purposes or quicker training, truncate the number of training examples to this " 91 | "value if set." 92 | ) 93 | }, 94 | ) 95 | max_eval_samples: Optional[int] = field( 96 | default=None, 97 | metadata={ 98 | "help": ( 99 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 100 | "value if set." 101 | ) 102 | }, 103 | ) 104 | max_predict_samples: Optional[int] = field( 105 | default=None, 106 | metadata={ 107 | "help": ( 108 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 109 | "value if set." 110 | ) 111 | }, 112 | ) 113 | train_file: Optional[str] = field( 114 | default=None, 115 | metadata={ 116 | "help": "A json file containing the training data." 117 | } 118 | ) 119 | validation_file: Optional[str] = field( 120 | default=None, 121 | metadata={ 122 | "help": "A json file containing the validation data." 123 | } 124 | ) 125 | test_file: Optional[str] = field( 126 | default=None, 127 | metadata={ 128 | "help": "A json file containing the test data." 
129 | } 130 | ) 131 | 132 | def __post_init__(self): 133 | if self.train_file is None or self.validation_file is None: 134 | raise ValueError("Need a training/validation file.") 135 | 136 | 137 | def main(): 138 | # See all possible arguments in src/transformers/training_args.py 139 | # or by passing the --help flag to this script. 140 | # We now keep distinct sets of args, for a cleaner separation of concerns. 141 | 142 | parser = HfArgumentParser( 143 | (ModelArguments, DataTrainingArguments, TrainingArguments)) 144 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 145 | 146 | # Setup logging 147 | logging.basicConfig( 148 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 149 | datefmt="%m/%d/%Y %H:%M:%S", 150 | handlers=[logging.StreamHandler(sys.stdout)], 151 | ) 152 | 153 | log_level = training_args.get_process_log_level() 154 | logger.setLevel(log_level) 155 | datasets.utils.logging.set_verbosity(log_level) 156 | transformers.utils.logging.set_verbosity(log_level) 157 | transformers.utils.logging.enable_default_handler() 158 | transformers.utils.logging.enable_explicit_format() 159 | 160 | # Log on each process the small summary: 161 | logger.warning( 162 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 163 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 164 | ) 165 | logger.info(f"Training/evaluation parameters {training_args}") 166 | 167 | # Detecting last checkpoint. 168 | last_checkpoint = None 169 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 170 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 171 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 172 | raise ValueError( 173 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 174 | "Use --overwrite_output_dir to overcome." 175 | ) 176 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 177 | logger.info( 178 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 179 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 180 | ) 181 | 182 | # Set seed before initializing model. 183 | set_seed(training_args.seed) 184 | 185 | # Loading a dataset from your local files. 186 | data_files = { 187 | "train": data_args.train_file, 188 | "validation": data_args.validation_file 189 | } 190 | 191 | # Get the test dataset: 192 | if training_args.do_predict: 193 | if data_args.test_file is None: 194 | raise ValueError("Need a test file for `do_predict`.") 195 | 196 | data_files["test"] = data_args.test_file 197 | 198 | for key in data_files.keys(): 199 | logger.info(f"load a local file for {key}: {data_files[key]}") 200 | 201 | # Loading a dataset from local json files 202 | raw_datasets = load_dataset( 203 | "json", 204 | data_files=data_files, 205 | cache_dir=model_args.cache_dir, 206 | use_auth_token=True if model_args.use_auth_token else None, 207 | ) 208 | # See more about loading any type of standard or custom dataset at 209 | # https://huggingface.co/docs/datasets/loading_datasets.html. 210 | 211 | # Load pretrained model and tokenizer 212 | # 213 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 214 | # download model & vocab. 
215 | model, tokenizer = load_model_tokenizer(model_args) 216 | 217 | # Preprocessing the raw_datasets 218 | # Padding strategy 219 | if data_args.pad_to_max_length: 220 | padding = "max_length" 221 | else: 222 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 223 | padding = False 224 | 225 | if data_args.max_seq_length > tokenizer.model_max_length: 226 | logger.warning( 227 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 228 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 229 | ) 230 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 231 | 232 | def preprocess_function(examples): 233 | # Tokenize the texts 234 | input_data = preprocess_batch( 235 | examples['author_name'], examples['text']) 236 | 237 | result = tokenizer(input_data, padding=padding, 238 | max_length=max_seq_length, truncation=True) 239 | 240 | # Map labels to IDs 241 | if model.config.label2id is not None and 'label' in examples: 242 | result['label'] = [(model.config.label2id[l] if l != -1 else -1) 243 | for l in examples['label']] 244 | return result 245 | 246 | with training_args.main_process_first(desc="dataset map pre-processing"): 247 | raw_datasets = raw_datasets.map( 248 | preprocess_function, 249 | batched=True, 250 | load_from_cache_file=not data_args.overwrite_cache, 251 | desc="Running tokenizer on dataset", 252 | ) 253 | if training_args.do_train: 254 | if "train" not in raw_datasets: 255 | raise ValueError("--do_train requires a train dataset") 256 | train_dataset = raw_datasets["train"] 257 | if data_args.max_train_samples is not None: 258 | max_train_samples = min( 259 | len(train_dataset), data_args.max_train_samples) 260 | train_dataset = train_dataset.select(range(max_train_samples)) 261 | 262 | if training_args.do_eval: 263 | if "validation" not in raw_datasets: 264 | raise ValueError("--do_eval requires a validation dataset") 265 | eval_dataset = raw_datasets["validation"] 266 | if data_args.max_eval_samples is not None: 267 | max_eval_samples = min( 268 | len(eval_dataset), data_args.max_eval_samples) 269 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 270 | 271 | if training_args.do_predict: 272 | if "test" not in raw_datasets: 273 | raise ValueError("--do_predict requires a test dataset") 274 | predict_dataset = raw_datasets["test"] 275 | if data_args.max_predict_samples is not None: 276 | max_predict_samples = min( 277 | len(predict_dataset), data_args.max_predict_samples) 278 | predict_dataset = predict_dataset.select( 279 | range(max_predict_samples)) 280 | 281 | # Log a few random samples from the training set: 282 | if training_args.do_train: 283 | for index in random.sample(range(len(train_dataset)), 3): 284 | logger.info( 285 | f"Sample {index} of the training set: {train_dataset[index]}.") 286 | 287 | from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report, precision_recall_fscore_support 288 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 289 | # predictions and label_ids field) and has to return a dictionary string to float. 
290 | 291 | def compute_metrics(eval_pred: EvalPrediction): 292 | logits, labels = eval_pred 293 | predictions = np.argmax(logits, axis=1) 294 | 295 | eval_data = precision_recall_fscore_support( 296 | y_true=labels, y_pred=predictions, average='weighted') 297 | 298 | # class_report=classification_report(y_true=labels, y_pred=predictions, output_dict=True, target_names=label_list) 299 | return dict( 300 | accuracy=accuracy_score(y_true=labels, y_pred=predictions), 301 | balanced_accuracy=balanced_accuracy_score( 302 | y_true=labels, y_pred=predictions), 303 | precision=eval_data[0], 304 | recall=eval_data[1], 305 | fscore=eval_data[2], 306 | # classification_report=class_report 307 | ) 308 | 309 | # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, 310 | # so we change it if we already did the padding. 311 | if data_args.pad_to_max_length: 312 | data_collator = default_data_collator 313 | elif training_args.fp16: 314 | data_collator = DataCollatorWithPadding( 315 | tokenizer, pad_to_multiple_of=8) 316 | else: 317 | data_collator = None 318 | 319 | # Initialize our Trainer 320 | trainer = Trainer( 321 | model=model, 322 | args=training_args, 323 | train_dataset=train_dataset if training_args.do_train else None, 324 | eval_dataset=eval_dataset if training_args.do_eval else None, 325 | compute_metrics=compute_metrics, 326 | tokenizer=tokenizer, 327 | data_collator=data_collator, 328 | ) 329 | 330 | # Training 331 | if training_args.do_train: 332 | checkpoint = None 333 | if training_args.resume_from_checkpoint is not None: 334 | checkpoint = training_args.resume_from_checkpoint 335 | elif last_checkpoint is not None: 336 | checkpoint = last_checkpoint 337 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 338 | metrics = train_result.metrics 339 | max_train_samples = ( 340 | data_args.max_train_samples if data_args.max_train_samples is not None else len( 341 | train_dataset) 342 | ) 343 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 344 | 345 | trainer.save_model() # Saves the tokenizer too for easy upload 346 | 347 | trainer.log_metrics("train", metrics) 348 | trainer.save_metrics("train", metrics) 349 | trainer.save_state() 350 | 351 | # Evaluation 352 | if training_args.do_eval: 353 | logger.info("*** Evaluate ***") 354 | 355 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 356 | 357 | max_eval_samples = ( 358 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len( 359 | eval_dataset) 360 | ) 361 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 362 | 363 | trainer.log_metrics("eval", metrics) 364 | trainer.save_metrics("eval", metrics) 365 | 366 | if training_args.do_predict: 367 | logger.info("*** Predict ***") 368 | 369 | # Removing the `label` columns because it contains -1 and Trainer won't like that. 
370 | predict_dataset = predict_dataset.remove_columns('label') 371 | predictions = trainer.predict( 372 | predict_dataset, metric_key_prefix="predict").predictions 373 | predictions = np.argmax(predictions, axis=1) 374 | 375 | output_predict_file = os.path.join( 376 | training_args.output_dir, "predict_results.txt") 377 | if trainer.is_world_process_zero(): 378 | with open(output_predict_file, "w", encoding='utf-8') as writer: 379 | logger.info(f"***** Predict results *****") 380 | writer.write("index\ttext\tprediction\n") 381 | for index, item in enumerate(predictions): 382 | item = model.config.id2label[item] 383 | writer.write( 384 | f"{index}\t{predict_dataset[index]}\t{item}\n") 385 | 386 | kwargs = {"finetuned_from": model_args.model_name_or_path, 387 | "tasks": "text-classification"} 388 | 389 | if training_args.push_to_hub: 390 | trainer.push_to_hub(**kwargs) 391 | 392 | 393 | if __name__ == "__main__": 394 | main() 395 | -------------------------------------------------------------------------------- /src/urls.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import re 3 | 4 | ip_middle_octet = r"(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5]))" 5 | ip_last_octet = r"(?:\.(?:0|[1-9]\d?|1\d\d|2[0-4]\d|25[0-5]))" 6 | 7 | URL_PATTERN = re.compile( # noqa: W605 8 | # protocol identifier 9 | r"(?:(?:https?|ftp)://)" 10 | # user:pass authentication 11 | r"(?:[-a-z\u00a1-\uffff0-9._~%!$&'()*+,;=:]+" 12 | r"(?::[-a-z0-9._~%!$&'()*+,;=:]*)?@)?" 13 | r"(?:" 14 | r"(?P" 15 | # IP address exclusion 16 | # private & local networks 17 | r"(?:(?:10|127)" + ip_middle_octet + r"{2}" + ip_last_octet + r")|" 18 | r"(?:(?:169\.254|192\.168)" + ip_middle_octet + ip_last_octet + r")|" 19 | r"(?:172\.(?:1[6-9]|2\d|3[0-1])" + ip_middle_octet + ip_last_octet + r"))" 20 | r"|" 21 | # private & local hosts 22 | r"(?P" 23 | r"(?:localhost))" 24 | r"|" 25 | # IP address dotted notation octets 26 | # excludes loopback network 0.0.0.0 27 | # excludes reserved space >= 224.0.0.0 28 | # excludes network & broadcast addresses 29 | # (first & last IP address of each class) 30 | r"(?P" 31 | r"(?:[1-9]\d?|1\d\d|2[01]\d|22[0-3])" 32 | r"" + ip_middle_octet + r"{2}" 33 | r"" + ip_last_octet + r")" 34 | r"|" 35 | # IPv6 RegEx from https://stackoverflow.com/a/17871737 36 | r"\[(" 37 | # 1:2:3:4:5:6:7:8 38 | r"([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|" 39 | # 1:: 1:2:3:4:5:6:7:: 40 | r"([0-9a-fA-F]{1,4}:){1,7}:|" 41 | # 1::8 1:2:3:4:5:6::8 1:2:3:4:5:6::8 42 | r"([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|" 43 | # 1::7:8 1:2:3:4:5::7:8 1:2:3:4:5::8 44 | r"([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|" 45 | # 1::6:7:8 1:2:3:4::6:7:8 1:2:3:4::8 46 | r"([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|" 47 | # 1::5:6:7:8 1:2:3::5:6:7:8 1:2:3::8 48 | r"([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|" 49 | # 1::4:5:6:7:8 1:2::4:5:6:7:8 1:2::8 50 | r"([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|" 51 | # 1::3:4:5:6:7:8 1::3:4:5:6:7:8 1::8 52 | r"[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|" 53 | # ::2:3:4:5:6:7:8 ::2:3:4:5:6:7:8 ::8 :: 54 | r":((:[0-9a-fA-F]{1,4}){1,7}|:)|" 55 | # fe80::7:8%eth0 fe80::7:8%1 56 | # (link-local IPv6 addresses with zone index) 57 | r"fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|" 58 | r"::(ffff(:0{1,4}){0,1}:){0,1}" 59 | r"((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}" 60 | # ::255.255.255.255 ::ffff:255.255.255.255 ::ffff:0:255.255.255.255 61 | # (IPv4-mapped IPv6 addresses and IPv4-translated addresses) 62 | 
r"(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|" 63 | r"([0-9a-fA-F]{1,4}:){1,4}:" 64 | r"((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}" 65 | # 2001:db8:3:4::192.0.2.33 64:ff9b::192.0.2.33 66 | # (IPv4-Embedded IPv6 Address) 67 | r"(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])" 68 | r")\]|" 69 | # host name 70 | r"(?:(?:(?:xn--[-]{0,2})|[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]-?)*" 71 | r"[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]+)" 72 | # domain name 73 | r"(?:\.(?:(?:xn--[-]{0,2})|[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]-?)*" 74 | r"[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]+)*" 75 | # TLD identifier 76 | r"(?:\.(?:(?:xn--[-]{0,2}[a-z\u00a1-\uffff\U00010000-\U0010ffff0-9]{2,})|" 77 | r"[a-z\u00a1-\uffff\U00010000-\U0010ffff]{2,}))" 78 | r")" 79 | # port number 80 | r"(?::\d{2,5})?" 81 | # resource path 82 | r"(?:/[-a-z\u00a1-\uffff\U00010000-\U0010ffff0-9._~%!$&'()*+,;=:@/]*)?" 83 | # query string 84 | r"(?:\?\S*)?" 85 | # fragment 86 | r"(?:#\S*)?", 87 | re.UNICODE | re.IGNORECASE 88 | ) 89 | 90 | 91 | def validate_url(text) -> Tuple[bool, bool]: 92 | # Returns tuple: (contains URL, is only a URL) 93 | match = URL_PATTERN.search(text) 94 | if match: 95 | exact = match.span() == (0, len(text)) 96 | return True, exact 97 | 98 | return False, False 99 | -------------------------------------------------------------------------------- /src/youtube.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import socket 3 | import os 4 | from googleapiclient.discovery import build, Resource 5 | from googleapiclient.errors import HttpError 6 | 7 | from dotenv import load_dotenv 8 | load_dotenv() 9 | 10 | 11 | API_SERVICE_NAME = 'youtube' 12 | API_VERSION = 'v3' 13 | YT_API_KEY = os.environ['YT_API_KEY'] 14 | 15 | NUM_RETRIES = 3 16 | 17 | 18 | class YouTubeError(Exception): 19 | reason = '' 20 | 21 | def __init__(self, error: HttpError) -> None: 22 | self.error = error 23 | 24 | 25 | class VideoNotFound(YouTubeError): 26 | reason = 'videoNotFound' 27 | 28 | 29 | class QuotaExceeded(YouTubeError): 30 | reason = 'quotaExceeded' 31 | 32 | 33 | class CommentsDisabled(YouTubeError): 34 | reason = 'commentsDisabled' 35 | 36 | 37 | errors: List[YouTubeError] = [VideoNotFound, QuotaExceeded, CommentsDisabled] 38 | 39 | 40 | def make_api_request(function: Resource, **kwargs): 41 | request = function(**kwargs) 42 | try: 43 | return request.execute() 44 | except socket.timeout: 45 | # TODO catch exceptions, even after retrying (built in) 46 | raise 47 | except HttpError as e: 48 | for error_class in errors: 49 | if e.reason == error_class.reason: 50 | raise error_class(e) 51 | raise YouTubeError(e) 52 | 53 | 54 | YOUTUBE_API = build( 55 | API_SERVICE_NAME, 56 | API_VERSION, 57 | developerKey=YT_API_KEY, 58 | num_retries=NUM_RETRIES 59 | ) 60 | 61 | 62 | def pagination_helper(function, api_kwargs, max_requests=None): 63 | request_kwargs = {} 64 | request_count = 0 65 | while True: 66 | if max_requests is not None and request_count > max_requests: 67 | break 68 | try: 69 | response = make_api_request( 70 | function=function, 71 | **api_kwargs, 72 | **request_kwargs 73 | ) 74 | request_count += 1 75 | except QuotaExceeded: 76 | raise 77 | except YouTubeError as e: 78 | print(e) 79 | return 80 | 81 | yield response 82 | 83 | next_token = response.get('nextPageToken') 84 | if not next_token: 85 | break 86 | 87 | request_kwargs['pageToken'] = next_token 88 | --------------------------------------------------------------------------------