├── .gitignore ├── Fujisaki_GPU_Stream.ipynb ├── LICENSE ├── README.md ├── config_example.py ├── cover_alpaca2jsonl.py ├── finetune.py ├── finetune_adv.py ├── github_assets └── banner.png ├── infer.py ├── length.py ├── prompt.py ├── prompt_openai.py ├── prompt_util.py ├── requirements.txt ├── scrape_twitter.py ├── tokenize_dataset_rows.py ├── tweets_sample.md └── twitter-parser.py /.gitignore: -------------------------------------------------------------------------------- 1 | config.py 2 | parser-output/ 3 | parser-cache/ 4 | tweet_cache/ 5 | chrome_profile/ 6 | assets/ 7 | data/ 8 | Your archive.html 9 | .DS_Store 10 | tweets.md 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /Fujisaki_GPU_Stream.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "dpcJsqajcua6" 7 | }, 8 | "source": [ 9 | "李萌萌的电子骨灰盒\n", 10 | "----\n", 11 | "\n", 12 | "这是一个通过ChatGLM模型训练的李萌萌的数字分身,你可以在问题栏目填入内容,或者什么都不填,来观察李萌萌到底会说些什么。\n", 13 | "T4级别的GPU已经可以很胜任这个任务了。" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "id": "XRViHJfokAHa" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "! nvidia-smi" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "Ep-GXSbvc4UR" 31 | }, 32 | "source": [ 33 | "### 安装依赖" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": { 40 | "id": "EG5LbTdJcr-a" 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "! git clone https://github.com/ljsabc/Fujisaki\n", 45 | "%cd Fujisaki\n", 46 | "\n", 47 | "%pip install -q -r requirements.txt\n", 48 | "%pip install gradio" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "id": "DRnutF1NeiEh" 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "import torch\n", 60 | "import sys\n", 61 | "\n", 62 | "from transformers import AutoTokenizer, GenerationConfig, AutoModel\n", 63 | "\n", 64 | "torch.set_default_tensor_type(torch.cuda.HalfTensor)\n", 65 | "model = AutoModel.from_pretrained(\"THUDM/chatglm-6b\", revision=\"658202d\", trust_remote_code=True).half().cuda()\n", 66 | "setattr(model, \"lm_head_raw\", model.lm_head)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "MGcANpbi1k4z" 73 | }, 74 | "source": [ 75 | "The huggingface repo has something more than the QKV lora. We checkout a version that is the closest to the current version. 
" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": { 82 | "id": "ggFUQpf2en_s" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "from peft import get_peft_model, LoraConfig, TaskType, PeftModel\n", 87 | "peft_path = 'ljsabc/Fujisaki_GLM' # change it to your own\n", 88 | "model = PeftModel.from_pretrained(\n", 89 | " model,\n", 90 | " peft_path,\n", 91 | " torch_dtype=torch.float16\n", 92 | " )\n", 93 | "\n", 94 | "# We have to use full precision, as some tokens are >65535\n", 95 | "torch.set_default_tensor_type(torch.cuda.FloatTensor)\n", 96 | "model.eval()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": { 103 | "id": "7Wv8zzKtgre2" 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "def evaluate(context, temperature, top_p, top_k):\n", 108 | " generation_config = GenerationConfig(\n", 109 | " temperature=temperature,\n", 110 | " top_p=top_p,\n", 111 | " top_k=top_k,\n", 112 | " #repetition_penalty=1.1,\n", 113 | " num_beams=1,\n", 114 | " do_sample=True,\n", 115 | " )\n", 116 | " with torch.no_grad():\n", 117 | " input_text = f\"Context: {context}Answer: \" \n", 118 | " ids = tokenizer([input_text], return_tensors=\"pt\")\n", 119 | " inputs = ids.to(\"cuda\")\n", 120 | " out = model.generate(\n", 121 | " **inputs,\n", 122 | " max_length=224,\n", 123 | " generation_config=generation_config\n", 124 | " )\n", 125 | " out = out.tolist()[0]\n", 126 | " decoder_output = tokenizer.decode(out)\n", 127 | " out_text = decoder_output.split(\"Answer: \")[1]\n", 128 | " return out_text" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": { 135 | "id": "J7PNWg3gFRJw" 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "def evaluate_stream(msg, history, temperature, top_p):\n", 140 | " generation_config = GenerationConfig(\n", 141 | " temperature=temperature,\n", 142 | " top_p=top_p,\n", 143 | " #repetition_penalty=1.1,\n", 144 | " num_beams=1,\n", 145 | " do_sample=True,\n", 146 | " )\n", 147 | "\n", 148 | " history.append([msg, None])\n", 149 | "\n", 150 | " context = \"\"\n", 151 | " if len(history) > 4:\n", 152 | " history.pop(0)\n", 153 | "\n", 154 | " for j in range(len(history)):\n", 155 | " history[j][0] = history[j][0].replace(\"
\", \"\")\n", 156 | "\n", 157 | " # concatenate context\n", 158 | " for h in history[:-1]:\n", 159 | " context += h[0] + \"||\" + h[1] + \"||\"\n", 160 | "\n", 161 | " context += history[-1][0]\n", 162 | " context = context.replace(r'
', '')\n", 163 | "\n", 164 | " # TODO: Avoid the tokens are too long.\n", 165 | " CUTOFF = 224\n", 166 | " while len(tokenizer.encode(context)) > CUTOFF:\n", 167 | " # save 15 token size for the answer\n", 168 | " context = context[15:]\n", 169 | "\n", 170 | " h = []\n", 171 | " print(\"History:\", history)\n", 172 | " print(\"Context:\", context)\n", 173 | " for response, h in model.stream_chat(tokenizer, context, h, max_length=CUTOFF, top_p=top_p, temperature=temperature):\n", 174 | " history[-1][1] = response\n", 175 | " yield history, \"\"" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": { 182 | "id": "egPtBitEHsz1" 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "history = [['你是谁','我是喵喵'], ['你住在哪里', \"我不知道\"]]\n", 187 | "for h in evaluate_stream(\"你在干什么\", history, 1.0, 0.9):\n", 188 | " print(h)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": { 195 | "id": "-Xy1tvCTApYN" 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "import gradio as gr\n", 200 | "\n", 201 | "title = \"\"\"
李萌萌(Alter Ego)
\n", 202 | "
这是一个通过ChatGLM模型训练的李萌萌的数字分身,你可以与她聊天,或者直接在文本框按下Enter,来观察李萌萌到底会说些什么。
\"\"\"\n", 203 | "\n", 204 | "footer = \"\"\"
项目在GitHub上托管,基于清华的THUDM/chatglm-6b项目。
\n", 205 | "
\"I'm... a boy.\" --Chihiro Fujisaki
\"\"\"\n", 206 | "\n", 207 | "with gr.Blocks() as demo:\n", 208 | " gr.HTML(title)\n", 209 | " state = gr.State()\n", 210 | " with gr.Row():\n", 211 | " with gr.Column(scale=2):\n", 212 | " temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.95, label=\"Temperature\",\n", 213 | " info=\"温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。\")\n", 214 | " top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.99, label=\"Top-p\",\n", 215 | " info=\"top-p参数,只输出前p>top-p的文字,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。\")\n", 216 | " #code = gr.Textbox(label=\"temp_output\", info=\"解码器输出\")\n", 217 | " #top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label=\"Top k\",\n", 218 | " # info=\"top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。\")\n", 219 | " \n", 220 | " with gr.Column(scale=3):\n", 221 | " chatbot = gr.Chatbot(label=\"聊天框\", info=\"\")\n", 222 | " msg = gr.Textbox(label=\"输入框\", placeholder=\"最近过得怎么样?\",\n", 223 | " info=\"输入你的内容,按[Enter]发送。也可以什么都不填写生成随机数据。\")\n", 224 | " clear = gr.Button(\"清除聊天\")\n", 225 | "\n", 226 | " msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])\n", 227 | " clear.click(lambda: None, None, chatbot, queue=False)\n", 228 | " gr.HTML(footer)\n", 229 | "\n", 230 | "demo.queue()\n", 231 | "demo.launch(debug=False)\n" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": { 238 | "id": "GsS8ZYH5ficF" 239 | }, 240 | "outputs": [], 241 | "source": [ 242 | "import gradio as gr\n", 243 | "gr.Interface(\n", 244 | " fn=evaluate,\n", 245 | " inputs=[\n", 246 | " gr.components.Textbox(\n", 247 | " lines=2, label=\"问题\", placeholder=\"最近过得怎么样?\",\n", 248 | " info=\"可以在这里输入你的问题。也可以什么都不填写生成随机数据。\"\n", 249 | " ),\n", 250 | " #gr.components.Textbox(lines=2, label=\"Input\", placeholder=\"none\"),\n", 251 | " gr.components.Slider(minimum=0, maximum=1.1, value=1.0, label=\"Temperature\",\n", 252 | " info=\"温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。\"),\n", 253 | " gr.components.Slider(minimum=0.5, maximum=1.0, value=0.99, label=\"Top p\",\n", 254 | " info=\"top-p参数,只输出前p>top-p的文字,建议不要修改。\"),\n", 255 | " gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label=\"Top k\",\n", 256 | " info=\"top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。\"),\n", 257 | " ],\n", 258 | " outputs=[\n", 259 | " gr.inputs.Textbox(\n", 260 | " lines=5,\n", 261 | " label=\"Output\",\n", 262 | " )\n", 263 | " ],\n", 264 | " title=\"李萌萌(Alter Ego)\",\n", 265 | " description=\"这是一个通过ChatGLM模型训练的李萌萌的数字分身,你可以在问题栏目填入内容,或者什么都不填,来观察李萌萌到底会说些什么。\",\n", 266 | ").launch()" 267 | ] 268 | } 269 | ], 270 | "metadata": { 271 | "accelerator": "GPU", 272 | "colab": { 273 | "machine_shape": "hm", 274 | "provenance": [] 275 | }, 276 | "gpuClass": "standard", 277 | "kernelspec": { 278 | "display_name": "Python 3", 279 | "name": "python3" 280 | }, 281 | "language_info": { 282 | "name": "python" 283 | } 284 | }, 285 | "nbformat": 4, 286 | "nbformat_minor": 0 287 | } 288 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Miaomiao Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, 
and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fujisaki 2 | 3 | 致力于创造一个属于你的不二咲千寻。项目处于初期阶段。 4 | 5 | An ongoing (fast prototyping) project to create your own doppelgänger based on your Twitter archive and LoRA models. 6 | 7 | ## 这是什么项目? 8 | 9 | 在我十年前玩某个叫做《弹丸论破》的游戏的时候,有一位可爱的……后来她挂了,但是她以数字生命的形式帮助了未来的主人公冲出难关。如果我死了,还有人记得我吗? 10 | 11 | 该项目目前基于[ChatGLM+LoRa](https://github.com/mymusise/ChatGLM-Tuning/),暂时致力于生成中文内容。也希望该项目能帮助到一些使用中文语料库进行prompt engineering以及内容生成的朋友。 12 | 13 | Inspired by Fujisaki Chihiro (i.e., Alter Ego). I thought it would be a fun project, as I really cannot predict my death in the future, and I would like to see how my doppelgänger would react to the world. Thanks to the super-strong LLM and LoRA for making this possible by allowing fine-tuning on small corpora. 14 | 15 | ## 项目基于什么原理 16 | 17 | 将你的Twitter存档目录解压缩放置在项目根目录下,即可自动处理。相关的程序会根据推文的内容、时间、引用、回复等信息,生成一个instruction风格的JSON数据集(To-do List之后附有一个最小的格式示例)。然后,我们将该数据集用于某种中文语言模型的调参,生成新的推文,或启用基于上下文的对话。 18 | 19 | 目前项目的文字生成能力很优秀,问答一般,不过在推文涵盖的范围内能够识别一定的问题。对话功能也因为推文都是陈述句受到一定限制,但是因为ChatGLM本身有一定的对话能力,我们还是可以信赖它可以生成一定的人话的。 20 | 21 | 优秀训练数据的生成是本项目的最终目的,目前的实现均在`prompt.py`中,也已经引入OpenAI生成更有价值的增广上下文。 22 | 23 | ## Demo: 24 | 25 | HuggingFace Hub(慢一点的在线版本,可以直接运行):[![Run on HuggingFace Hub](https://img.shields.io/badge/Run%20on-Hugging%20Face%20Hub-blue?logo=huggingface)](https://huggingface.co/spaces/ljsabc/Fujisaki) 26 | 27 | Colab演示(需要GPU):[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ljsabc/Fujisaki/blob/main/Fujisaki_GPU_Stream.ipynb) 28 | 29 | 将Colab笔记本中的`peft_path`改为你自己训练的模型(可以是一个本地文件夹,或者是一个HuggingFace Repo),就可以复现你自己训练后的版本了。 30 | 31 | 32 | ![Banner](github_assets/banner.png) 33 | 34 | 相对于索引、问答目的的项目,该项目会更大程度上利用Sampling,即相似的上下文也会生成非常不一样的回答。希望这样的生成方法能够提供更高的互动性。 35 | 36 | ## To-do List 37 | 38 | - [x] Modify the twitter-parser to output your twitter archive into an instruction dataset 39 | - [x] Categorize in-reply-to and quoted tweets for better conditional generation 40 | - [x] LoRA finetuning with multiple GPUs 41 | - [x] Basic prompt engineering for original posts 42 | - [x] Hyperparameter tuning (incl. LoRA rank, batch size, learning rate, etc.) 43 | - [x] Allow in-reply-to tweets to be downloaded. 44 | - [x] Advanced prompt engineering from OpenAI 45 | - [x] Colab notebook for easy deployment (I believe this code can surely run on T4 as we are expecting much shortened tokens) 46 | - [ ] Download quoted tweets. 47 | - [ ] Support other datasets (e.g. Reddit, Weibo, etc. Future plan, let's discuss in #2 .) 
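下面附一个最小的格式示例(仅作说明:示例记录取自`tweets_sample.md`,拼接逻辑沿用`cover_alpaca2jsonl.py`/`tokenize_dataset_rows.py`中的`format_example`),展示解析得到的instruction风格记录是如何被拼成训练用的`context/target`文本的:

```python
# 最小示例:一条解析后的记录如何拼接成训练文本
# (format_example 与 cover_alpaca2jsonl.py 中的实现一致;示例记录取自 tweets_sample.md)
example = {
    "instruction": "最近过得怎么样?",                      # 由 prompt.py / prompt_util.py 采样得到的问题
    "input": "",
    "output": "想写个tensorboard的iOS客户端之类的东西…",    # 你自己的原创推文
}

def format_example(example: dict) -> dict:
    context = f"Instruction: {example['instruction']}\n"
    if example.get("input"):
        context += f"Input: {example['input']}\n"
    context += "Answer: "
    return {"context": context, "target": example["output"]}

print(format_example(example))
# {'context': 'Instruction: 最近过得怎么样?\nAnswer: ', 'target': '想写个tensorboard的iOS客户端之类的东西…'}
```

随后`tokenize_dataset_rows.py`会把`context`与`target`分别编码并拼接为`input_ids`;`finetune.py`的`data_collator`会把prompt部分的label置为-100,只对回答部分计算loss。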
48 | 49 | ## 环境与安装 50 | 51 | 推荐使用 `conda` 环境。安装依赖: 52 | 53 | ```pip install -r requirements.txt``` 54 | 55 | 有些时候,你可能要安装cudatoolkit: 56 | 57 | ```conda install cudatoolkit=11.3``` 58 | 59 | 如果使用4090,需要安装11.8的版本。 60 | 61 | 安装cudatoolkit可以帮助解决bitsandbytes检测不到CUDA runtime的问题。准备好环境之后不要忘记将`config_example.py`重命名为`config.py`,并且修改其中的参数(如有必要)。 62 | 63 | ## Data requirements 64 | 65 | 解压缩你的推文存档,放置在项目根目录下,即可自动处理。解压缩之后你应该能在项目根目录里面看到`Your archive.html`这个文件。然后,运行`twitter-parser.py`来解析你的推文存档,生成一个instruction风格的JSON数据集。 66 | 67 | 同样的,你可以参考`tweets_sample.md`来生成你自己的数据集,或者等待项目更新。 68 | 69 | ## 训练 70 | 71 | 目前的模型基于[ChatGLM+LoRa](https://github.com/mymusise/ChatGLM-Tuning/),与[Luotuo](https://github.com/LC1332/Chinese-alpaca-lora)的处理方式较为类似。 72 | 73 | ### 生成训练数据 74 | 75 | 首先使用 76 | 77 | python twitter-parser.py 78 | 79 | 来处理推文存档,稍许等待之后,你会在项目根目录下看到一个`tweets.md`的文件。这个文件包含了你的推文存档中的所有推文,以及相关的信息。为了保护你的隐私,请不要公开该文件。 80 | 81 | 生成相应的数据之后,我们需要进一步调用ChatGLM的`tokenizer`来生成对应的tokenized数据集。这一步需要一些时间。这个原始的版本会过度cache同一个generator导致数据无法更新,我改了一个单文件的版本。 82 | 83 | python ./tokenize_dataset_rows.py --json_path ./tweets.md --save_path tweets.tokens --max_seq_length 240 84 | 85 | (可选)使用240个token是因为我的大部份推文,连同instruction一起,也不会超过240个token。如果你的推文较长,可以在生成jsonl之后调用`length.py`,根据输出的数据适当增加`max_seq_length`的数值。 86 | 87 | python3 ./cover_alpaca2jsonl.py --data_path tweets.md --save_path tweets.jsonl 88 | python length.py 89 | 90 | ### 可选项:根据回复推文生成更有互动性的数据集(Beta) 91 | 92 | 因为推文存档中的推文并不储存上下文信息,所以我们可以通过抓取回复推文来获得原始的上下文信息,来实现非常强的对话功能。这一步需要一些时间,但是只需要运行一次。 93 | 很不幸,抓取网站总是一个很痛苦的事情。你可能需要一个代理池,以及将`chromedriver`的二进制放在项目根目录中(需要你了解[selenium](https://selenium-python.readthedocs.io/))。你可能还需要一台性能足够强的电脑,以便多线程抓取回复信息。如果你有这样的条件,可以调整`config.py`,将`PARSE_REPLIES = False`改为`PARSE_REPLIES = True`,然后运行`twitter-parser.py`。该部分已经写好。 94 | 95 | 该方法无法抓取已经被删除的推文(无论是你的还是别人的)。抓取引用推文的功能也尚在检讨中。 96 | 97 | ### 可选项:使用OpenAI根据推文生成问题 98 | 99 | 调用OpenAI API为你的原创推文生成一个问题,或者前情提要,以提供更顺畅的问答体验。该部分已经写好,需要你使用OpenAI API Key。价格稍显昂贵,每一万条推文需要准备2-3刀的预算。请填写config.py以配置相应内容,并参照`prompt_openai.py`查看实现方式。 100 | 101 | ### 训练 102 | 103 | 接下来便可调用`finetune.py`来进行模型训练。根据不同的GPU数量,你可以直接调用 104 | 105 | WORLD_SIZE=4 CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \ 106 | --master_port=1234 \ 107 | finetune.py \ 108 | --dataset_path tweets.tokens \ 109 | --lora_rank 8 \ 110 | --per_device_train_batch_size 8 \ 111 | --gradient_accumulation_steps 1 \ 112 | --num_train_epochs 2 \ 113 | --save_steps 2000 \ 114 | --save_total_limit 2 \ 115 | --learning_rate 6e-4 \ 116 | --remove_unused_columns false \ 117 | --logging_steps 50 \ 118 | --output_dir output \ 119 | --ddp_find_unused_parameters false \ 120 | --warmup_steps 50 121 | 122 | 进行多卡训练。如果多卡训练报错,可能再跑一遍就好了,是一个小的缓存bug。 123 | 124 | 单卡训练: 125 | 126 | python finetune.py \ 127 | --dataset_path tweets.tokens \ 128 | --lora_rank 8 \ 129 | --per_device_train_batch_size 8 \ 130 | --gradient_accumulation_steps 1 \ 131 | --num_train_epochs 2 \ 132 | --save_steps 2000 \ 133 | --save_total_limit 2 \ 134 | --learning_rate 2e-4 \ 135 | --remove_unused_columns false \ 136 | --logging_steps 50 \ 137 | --output_dir output \ 138 | --warmup_steps 100 139 | 140 | 目前的参数和[ChatGLM+LoRa](https://github.com/mymusise/ChatGLM-Tuning/)很类似,不过可以根据GPU数量调节学习率。默认的学习率是`2e-4` (每8个sample,如果loss突增可能还要降一些),请根据batch size和显卡能力自行测试调节。上游增加了gradient checkpointing之后3090甚至可以用到8的batch size了,体验非常好。LoRA的rank可以根据你希望的模型性能进行调节,默认的8是足够的,你也可以提升到12甚至更高,经过一定的测试`lora_rank`上到16结果会上升一个台阶,代价是稍微更长一点的训练和测试时间,但是不会多很多。 141 | 142 | 训练好的模型会保存在`output`文件夹下,你可以在`output/`中找到对应的模型文件。 143 | 144 | ## 预测(以及互动) 145 | 146 | 调用 `infer.py` 
进行对话。你可以输入任何问题(因为是推文的原因,输入陈述句的效果比疑问句好),不过即便什么都不输入也可以生成一个很类似我的推文。 147 | 148 | ```python3 ./infer.py ``` 149 | 150 | 可以在该文件中调节top-p,top-k和temperature,以便生成更多的样本。可以根据gradio demo的结果适当调节。 151 | 152 | ## Benchmark 153 | 154 | 在1张A100的配置下面,训练一个75,000条推文的数据集,在设定最大长度为240的情况下,每一个epoch需要3小时。训练大概需要2-3个epoch能够达成最佳状态。 155 | 156 | ## Credits 157 | 158 | This project is based on the following projects: 159 | 160 | 27182812/ChatGLM-chinese-insturct 161 | timhutton/twitter-archive-parser 162 | LC1332/Chinese-alpaca-lora (Donated❤️) 163 | 164 | Inspired by the following projects: 165 | 166 | tloen/alpaca-lora 167 | HuggingFace: KBlueLeaf/guanaco-7B-lora-embed 168 | (potentially) twint-fork 169 | -------------------------------------------------------------------------------- /config_example.py: -------------------------------------------------------------------------------- 1 | # 默认参数,不太需要修改 2 | # Limit the minimal length of your response 3 | RESPONSE_THRESH = 6 4 | # the augmentation factor can be reduced if you enable the parsing of replies 5 | AUGMENTATION_FACTOR_ORIGINAL = 2 6 | # better just 1, as we are using full context 7 | AUGMENTATION_FACTOR_REPLIES = 1 8 | 9 | # reply delimiter, useful to distinguish the speaker in the context. 10 | REPLY_DELIMITER = "||" 11 | 12 | # 以下都是可选功能,可以根据需要开启 13 | # More interestingly, let's ask OpenAI to make a question for your tweets 14 | ENABLE_OPENAI = False 15 | OPENAI_KEY = '' 16 | OPENAI_MAX_SAMPLE = 2 # don't make it larger than the augmentation factor 17 | OPENAI_THREADS = 128 # Be careful when you try this! It can be expensive. 18 | OPENAI_CACHE_ONLY = False 19 | 20 | # Parsing replies is not recommended, as it will take a lot of time 21 | # But it can increase the overall quality significantly 22 | PARSE_REPLIES = False 23 | SCRAPE_TIMEOUT = 20 24 | SCRAPE_CACHE_ONLY = False 25 | PROXY = False 26 | PROXY_ADDR = "" 27 | PROXY_PORT = "" 28 | PROXY_USER = "" 29 | PROXY_PASSWD = "" 30 | 31 | # number of processes to use for selenium 32 | # based on experience, 2-3 times your CPU core count works best. Use it wisely. 
33 | PROCESSES = 8 -------------------------------------------------------------------------------- /cover_alpaca2jsonl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from tqdm import tqdm 4 | 5 | 6 | def format_example(example: dict) -> dict: 7 | context = f"Instruction: {example['instruction']}\n" 8 | if example.get("input"): 9 | context += f"Input: {example['input']}\n" 10 | context += "Answer: " 11 | target = example["output"] 12 | return {"context": context, "target": target} 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--data_path", type=str, default="data/alpaca_data.json") 18 | parser.add_argument("--save_path", type=str, default="data/alpaca_data.jsonl") 19 | 20 | args = parser.parse_args() 21 | with open(args.data_path) as f: 22 | examples = json.load(f) 23 | 24 | with open(args.save_path, 'w') as f: 25 | for example in tqdm(examples, desc="formatting.."): 26 | f.write(json.dumps(format_example(example)) + '\n') 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from transformers import Trainer, HfArgumentParser 3 | from transformers import AutoTokenizer, AutoModel 4 | import torch 5 | import torch.nn as nn 6 | from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training 7 | from dataclasses import dataclass, field 8 | import datasets 9 | import os 10 | 11 | 12 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="658202d") 13 | 14 | 15 | @dataclass 16 | class FinetuneArguments: 17 | dataset_path: str = field(default="data/alpaca") 18 | model_path: str = field(default="output") 19 | lora_rank: int = field(default=8) 20 | 21 | 22 | class CastOutputToFloat(nn.Sequential): 23 | def forward(self, x): 24 | return super().forward(x).to(torch.float32) 25 | 26 | 27 | def data_collator(features: list) -> dict: 28 | len_ids = [len(feature["input_ids"]) for feature in features] 29 | longest = max(len_ids) 30 | input_ids = [] 31 | labels_list = [] 32 | for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]): 33 | ids = feature["input_ids"] 34 | seq_len = feature["seq_len"] 35 | labels = ( 36 | [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l) 37 | ) 38 | ids = ids + [tokenizer.pad_token_id] * (longest - ids_l) 39 | _ids = torch.LongTensor(ids) 40 | labels_list.append(torch.LongTensor(labels)) 41 | input_ids.append(_ids) 42 | input_ids = torch.stack(input_ids) 43 | labels = torch.stack(labels_list) 44 | return { 45 | "input_ids": input_ids, 46 | "labels": labels, 47 | } 48 | 49 | 50 | class ModifiedTrainer(Trainer): 51 | def compute_loss(self, model, inputs, return_outputs=False): 52 | return model( 53 | input_ids=inputs["input_ids"], 54 | labels=inputs["labels"], 55 | ).loss 56 | 57 | def save_model(self, output_dir=None, _internal_call=False): 58 | from transformers.trainer import TRAINING_ARGS_NAME 59 | 60 | os.makedirs(output_dir, exist_ok=True) 61 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 62 | saved_params = { 63 | k: v.to("cpu") for k, v in self.model.named_parameters() if v.requires_grad 64 | } 65 | torch.save(saved_params, os.path.join(output_dir, "adapter_model.bin")) 66 | 67 | 68 | def 
main(): 69 | finetune_args, training_args = HfArgumentParser( 70 | (FinetuneArguments, TrainingArguments) 71 | ).parse_args_into_dataclasses() 72 | 73 | 74 | device_map = "auto" 75 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 76 | ddp = world_size != 1 77 | if ddp: 78 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 79 | 80 | # init model 81 | model = AutoModel.from_pretrained( 82 | "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map=device_map,revision="658202d" 83 | ) 84 | 85 | # It's exactly the following. 86 | model = prepare_model_for_int8_training(model) 87 | #model.gradient_checkpointing_enable() 88 | #model.enable_input_require_grads() 89 | model.is_parallelizable = True 90 | model.model_parallel = True 91 | #model.lm_head = CastOutputToFloat(model.lm_head) 92 | model.config.use_cache = ( 93 | False # silence the warnings. Please re-enable for inference! 94 | ) 95 | 96 | #print(model) 97 | 98 | # setup peft 99 | peft_config = LoraConfig( 100 | task_type=TaskType.CAUSAL_LM, 101 | inference_mode=False, 102 | r=finetune_args.lora_rank, 103 | lora_alpha=32, 104 | lora_dropout=0.1, 105 | ) 106 | model = get_peft_model(model, peft_config) 107 | 108 | # load dataset 109 | dataset = datasets.load_from_disk(finetune_args.dataset_path) 110 | #print(f"\n{len(dataset)=}\n") 111 | 112 | # start train 113 | trainer = ModifiedTrainer( 114 | model=model, 115 | train_dataset=dataset, 116 | args=training_args, 117 | data_collator=data_collator, 118 | ) 119 | trainer.train() 120 | # save model 121 | #save_tunable_parameters( 122 | # model, os.path.join(training_args.output_dir, "chatglm-lora.pt") 123 | #) 124 | 125 | # don't just save model, save configs also. 126 | model.save_pretrained(training_args.output_dir) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /finetune_adv.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from transformers import Trainer, HfArgumentParser 3 | from transformers import AutoTokenizer, AutoModel 4 | import torch 5 | import torch.nn as nn 6 | from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training 7 | from dataclasses import dataclass, field 8 | import datasets 9 | import os 10 | 11 | 12 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="658202d") 13 | 14 | 15 | @dataclass 16 | class FinetuneArguments: 17 | dataset_path: str = field(default="data/alpaca") 18 | model_path: str = field(default="output") 19 | lora_rank: int = field(default=8) 20 | 21 | 22 | class CastOutputToFloat(nn.Sequential): 23 | def forward(self, x): 24 | return super().forward(x).to(torch.float32) 25 | 26 | 27 | def data_collator(features: list) -> dict: 28 | len_ids = [len(feature["input_ids"]) for feature in features] 29 | longest = max(len_ids) 30 | input_ids = [] 31 | labels_list = [] 32 | for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]): 33 | ids = feature["input_ids"] 34 | seq_len = feature["seq_len"] 35 | labels = ( 36 | [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l) 37 | ) 38 | ids = ids + [tokenizer.pad_token_id] * (longest - ids_l) 39 | _ids = torch.LongTensor(ids) 40 | labels_list.append(torch.LongTensor(labels)) 41 | input_ids.append(_ids) 42 | input_ids = torch.stack(input_ids) 43 | labels = torch.stack(labels_list) 44 | return { 45 | 
"input_ids": input_ids, 46 | "labels": labels, 47 | } 48 | 49 | 50 | class ModifiedTrainer(Trainer): 51 | def compute_loss(self, model, inputs, return_outputs=False): 52 | return model( 53 | input_ids=inputs["input_ids"], 54 | labels=inputs["labels"], 55 | ).loss 56 | 57 | def save_model(self, output_dir=None, _internal_call=False): 58 | from transformers.trainer import TRAINING_ARGS_NAME 59 | 60 | os.makedirs(output_dir, exist_ok=True) 61 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 62 | saved_params = { 63 | k: v.to("cpu") for k, v in self.model.named_parameters() if v.requires_grad 64 | } 65 | torch.save(saved_params, os.path.join(output_dir, "adapter_model.bin")) 66 | 67 | 68 | def main(): 69 | finetune_args, training_args = HfArgumentParser( 70 | (FinetuneArguments, TrainingArguments) 71 | ).parse_args_into_dataclasses() 72 | 73 | 74 | device_map = "auto" 75 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 76 | ddp = world_size != 1 77 | if ddp: 78 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 79 | 80 | # init model 81 | model = AutoModel.from_pretrained( 82 | "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map=device_map,revision="658202d" 83 | ) 84 | 85 | # It's exactly the following. 86 | model = prepare_model_for_int8_training(model) 87 | #model.gradient_checkpointing_enable() 88 | #model.enable_input_require_grads() 89 | model.is_parallelizable = True 90 | model.model_parallel = True 91 | #model.lm_head = CastOutputToFloat(model.lm_head) 92 | model.config.use_cache = ( 93 | False # silence the warnings. Please re-enable for inference! 94 | ) 95 | 96 | # setup peft 97 | peft_config = LoraConfig( 98 | task_type=TaskType.CAUSAL_LM, 99 | inference_mode=False, 100 | r=finetune_args.lora_rank, 101 | lora_alpha=32, 102 | lora_dropout=0.1, 103 | target_modules=['query_key_value', 'lm_head','word_embeddings'] 104 | ) 105 | model = get_peft_model(model, peft_config) 106 | print(model) 107 | 108 | # load dataset 109 | dataset = datasets.load_from_disk(finetune_args.dataset_path) 110 | #print(f"\n{len(dataset)=}\n") 111 | 112 | # start train 113 | trainer = ModifiedTrainer( 114 | model=model, 115 | train_dataset=dataset, 116 | args=training_args, 117 | data_collator=data_collator, 118 | ) 119 | trainer.train() 120 | # save model 121 | #save_tunable_parameters( 122 | # model, os.path.join(training_args.output_dir, "chatglm-lora.pt") 123 | #) 124 | 125 | # don't just save model, save configs also. 
126 | model.save_pretrained(training_args.output_dir) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /github_assets/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljsabc/Fujisaki/a4a70eba2bbd95266440b3a9ac94b7a29df96004/github_assets/banner.png -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | from transformers import AutoTokenizer, GenerationConfig, AutoModel 5 | 6 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 7 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="658202d").cuda().half() 8 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="658202d") 9 | 10 | from peft import get_peft_model, LoraConfig, TaskType, PeftModel 11 | 12 | peft_path = sys.argv[1] if len(sys.argv) > 1 else "output/" 13 | model = PeftModel.from_pretrained( 14 | model, 15 | peft_path, 16 | torch_dtype=torch.float16, 17 | ) 18 | print(model) 19 | 20 | # TODO: check if full precision is necessary 21 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 22 | model.eval() 23 | 24 | generation_config = GenerationConfig( 25 | temperature=0.9, 26 | top_p=0.975, 27 | #top_k=150, 28 | #repetition_penalty=1.1, 29 | num_beams=1, 30 | do_sample=True, 31 | ) 32 | 33 | with torch.no_grad(): 34 | while True: 35 | context = input(">") 36 | input_text = f"Context: {context}Answer: " 37 | ids = tokenizer([input_text], return_tensors="pt") 38 | inputs = ids.to("cuda") 39 | #input_ids = torch.LongTensor([ids]).cuda() 40 | out = model.generate( 41 | **inputs, 42 | max_length=224, 43 | generation_config=generation_config 44 | 45 | ) 46 | out = out.tolist()[0] 47 | #print(out) 48 | decoder_output = tokenizer.decode(out) 49 | #print(decoder_output) 50 | out_text = "Chihiro:" + decoder_output.split("Answer: ")[1] 51 | print(out_text) 52 | -------------------------------------------------------------------------------- /length.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import transformers 3 | from tqdm import tqdm 4 | import json 5 | 6 | def preprocess(tokenizer, config, example, max_seq_length): 7 | prompt = example["context"] 8 | target = example["target"] 9 | prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True) 10 | target_ids = tokenizer.encode( 11 | target, 12 | max_length=max_seq_length, 13 | truncation=True, 14 | add_special_tokens=False) 15 | input_ids = prompt_ids + target_ids + [config.eos_token_id] 16 | #print(prompt+target, input_ids) 17 | return {"input_ids": input_ids, "seq_len": len(prompt_ids)} 18 | 19 | model_name = "THUDM/chatglm-6b" 20 | tokenizer = transformers.AutoTokenizer.from_pretrained( 21 | model_name, trust_remote_code=True) 22 | config = transformers.AutoConfig.from_pretrained( 23 | model_name, trust_remote_code=True, device_map='auto') 24 | 25 | import numpy as np 26 | 27 | length = [] 28 | with open("tweets.jsonl", "r") as f: 29 | for line in tqdm(f.readlines()): 30 | example = json.loads(line) 31 | feature = preprocess(tokenizer, config, example, 4096) 32 | length.append(len(feature["input_ids"])) 33 | 34 | l = np.array(length) 35 | for q in [0.9, 0.95, 0.97, 0.99, 0.999]: 
36 | print(f"{q}: {np.quantile(l, q)}") 37 | 38 | -------------------------------------------------------------------------------- /prompt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import re 4 | 5 | import numpy as np 6 | import collections 7 | 8 | from scrape_twitter import process_tweet_ids 9 | import config 10 | 11 | from prompt_openai import processOriginalTweet_openai 12 | 13 | from prompt_util import original_post_prompt, findTopic, cut_sent, checkResponse 14 | 15 | 16 | def processOriginalTweet(tweets): 17 | 18 | # It is about the original tweet 19 | if config.PARSE_REPLIES: 20 | # if we parse the replies, we will have more data to sample from 21 | # we do not need to do the completion, and the Q&A part can be inferred from the in-reply-to 22 | # of the original posts, 40% are unconditional (with questions), 10% are completion, 20% are Q&A, 30% are unconditional (with no prompts) 23 | sample_range = [0.4, 0.5, 0.7, 1] 24 | else: 25 | sample_range = [0.35, 0.5, 0.95, 1] 26 | 27 | 28 | final = [] 29 | for item in tweets: 30 | # sample a random float from 0-1 to decide the ways of generation 31 | # sample_range is a probablity accumulative list 32 | # [random post, completion, Q&A, rest (direct original post)] 33 | tweet = item["text"] 34 | rr = random.random() 35 | if rr < sample_range[0]: 36 | # sample a random question, and concatenate 37 | instruction = f"{random.choice(original_post_prompt)}" 38 | user_input = f"" 39 | if checkResponse(tweet): 40 | final.append({"instruction": instruction, "input": user_input, "output": tweet}) 41 | elif rr < sample_range[1]: 42 | # given a truncated tweet, ask for completion 43 | substring = cut_sent(tweet) 44 | if len(substring) > 1: 45 | user_input = f"" 46 | rr = random.randint(1, len(substring)-1) 47 | instruction = "".join(substring[0:rr]) 48 | if checkResponse(instruction): 49 | final.append({"instruction": instruction, "input": user_input, "output": "".join(substring[rr:])}) 50 | else: 51 | instruction = f"{random.choice(original_post_prompt)}" 52 | user_input = f"" 53 | final.append({"instruction": instruction, "input": user_input, "output": tweet}) 54 | elif rr < sample_range[2]: 55 | # QA like 56 | # ask for a topic, the topic is mainly based on a substring of this tweet 57 | instruction = findTopic(tweet) 58 | if instruction is not None: 59 | user_input = f"" 60 | if checkResponse(tweet): 61 | final.append({"instruction": instruction, "input": user_input, "output": tweet}) 62 | else: 63 | #if cannot find a topic 64 | instruction = f"{random.choice(original_post_prompt)}" 65 | user_input = f"" 66 | if checkResponse(tweet): 67 | final.append({"instruction": instruction, "input": user_input, "output": tweet}) 68 | 69 | else: 70 | # no instructions, unconditional generation. 
71 | final.append({"instruction": "", "input": "", "output": tweet}) 72 | 73 | return final 74 | 75 | def processReplyTweets(tweets): 76 | 77 | final = [] 78 | context_count = [] 79 | 80 | for index, t in enumerate(tweets): 81 | #print(index, t) 82 | tweet_id = t['id'] 83 | tweet_text = t['text'] 84 | context = t['context'] 85 | 86 | # first, we need to check if the reply itself is interesting 87 | if not checkResponse(tweet_text): 88 | continue 89 | 90 | # then, we need to check if the context is blank 91 | if context is None: 92 | continue 93 | 94 | if len(context) == 0: 95 | continue 96 | 97 | # next, we do a rough check in if the context is interesting 98 | # if the context is too short, we will not use it 99 | if not checkResponse("".join(context)): 100 | continue 101 | 102 | # We believe the context is interesting 103 | # in this way, we want to sample a random context length based on the probability distribution of 1/x 104 | # the longer the context, the less likely it will be sampled 105 | # if the context's length is not long enough, we will sample again 106 | ''' 107 | l = len(context) 108 | p = [] 109 | for j in range(l): 110 | p.append(np.power(j+1, -config.REPLY_TEMP)) 111 | 112 | # normalize the probability 113 | p = np.array([i/sum(p) for i in p], dtype=np.float32) 114 | 115 | while True: 116 | # sample a number based on p 117 | r = np.random.choice(l, 1, p=p)[0] 118 | context_text = "\n".join(context[-r:]) 119 | # check if the context is interesting 120 | if checkResponse(context_text): 121 | r_count.append(r) 122 | break 123 | ''' 124 | 125 | context_count.append(len(context)) 126 | 127 | # now we have a context, and a reply 128 | # go give the prompt 129 | # 130 | final.append({"instruction": config.REPLY_DELIMITER.join(context), "input": "", "output": tweet_text}) 131 | 132 | # but we can do more, we can also augment a Q&A like discussion within the topic, if we want 133 | # give a small random chance to do this 134 | if random.random() < 0.05: 135 | instruction = findTopic(context[-1]) 136 | if instruction is not None: 137 | final.append({"instruction": instruction, "input": "", "output": tweet_text}) 138 | 139 | # TODO: any other sort of prompt engineering? 
140 | # use a counter to see how many long contexts are used 141 | print(collections.Counter(context_count)) 142 | 143 | return final 144 | 145 | def write_json(md_path, final_md, lang): 146 | 147 | # construct a instruction dataset 148 | final = [] 149 | 150 | # construct a list of tweets to be downloaded to sample the contexts 151 | context_tweets = [] 152 | 153 | # original tweets 154 | original_tweets = [] 155 | 156 | # firstly classify the type of each tweet 157 | for id, md, in_reply_to, quote, retweet in final_md: 158 | 159 | # content filter goes here: 160 | if md.strip() == "(media)": 161 | continue 162 | 163 | if in_reply_to and quote: 164 | # todo: process replies and quotes 165 | pass 166 | elif in_reply_to: 167 | # save them into a list; we will download them later 168 | context_tweets.append({"id": id, "text": md}) 169 | elif quote: 170 | # todo: process quotes 171 | pass 172 | elif retweet: 173 | # not my tweet, simply discard them 174 | pass 175 | else: 176 | # original tweets 177 | original_tweets.append({"id": id, "text": md}) 178 | 179 | 180 | # process with the original tweets 181 | for l in range(config.AUGMENTATION_FACTOR_ORIGINAL): 182 | if config.ENABLE_OPENAI: 183 | res = processOriginalTweet_openai(original_tweets, l) 184 | print(f"Batch: {l+1}, {len(res)} original tweets are added with openai.") 185 | final.extend(res) 186 | else: 187 | res = processOriginalTweet(original_tweets) 188 | print(f"Batch: {l+1}, {len(res)} original tweets are added.") 189 | final.extend(res) 190 | 191 | 192 | # process with the replies 193 | if config.PARSE_REPLIES: 194 | # Now things get even more interesting, we will scrape the tweets from the context_tweet_ids 195 | parsed_tweets = process_tweet_ids(context_tweets) 196 | print(f"Processed {len(parsed_tweets)} tweets from the context tweets.") 197 | 198 | for l in range(config.AUGMENTATION_FACTOR_REPLIES): 199 | res = processReplyTweets(parsed_tweets) 200 | print(f"Batch: {l+1}, {len(res)} replies are added.") 201 | final.extend(res) 202 | 203 | 204 | 205 | with open(md_path, "w") as f: 206 | # shuffle the dataset 207 | random.shuffle(final) 208 | f.write(json.dumps(final, indent=4, ensure_ascii=False)) 209 | -------------------------------------------------------------------------------- /prompt_openai.py: -------------------------------------------------------------------------------- 1 | # prompt_openai.py 2 | # parsing with openai-guided prompts 3 | from tkinter import E 4 | import config 5 | import random 6 | import json 7 | from prompt_util import checkResponse, original_post_prompt, findTopic 8 | import openai 9 | import os 10 | from concurrent import futures 11 | import collections 12 | 13 | def process_single_original_qa(value): 14 | response = openai.ChatCompletion.create( 15 | model="gpt-3.5-turbo", 16 | messages=[ 17 | #{"role": "system", "content": "你是一个富有情感的善于分析他人发言的发帖助手。\ 18 | # 请提出一个问题,使用户输入的内容可以恰当回复你提出的问题。问题中禁止包含“这”这个字。\ 19 | # 如果难以提问,或者提出的问题更像是在追问用户的输入而不是让用户的输入解答问题,就请提出一个诸如“最近发生什么事?”之类的通用问题。"}, 20 | 21 | {"role": "system", "content": "Based on the Chinese user input, create a question in SIMPLIFIED CHINESE that allows the user's input to serve as an appropriate response to your question.\ 22 | If it's too difficult to come up with a question, or the user's input is too ambiguous, or you have to question base on the details of the user's input, \ 23 | please give up and just output a general Chinese question which has the similar meaning of, '最近发生什么事了?'."}, 24 | 25 | {"role": "user", "content": f"User input (in 
Chinese): {value}"}, 26 | ], 27 | max_tokens=128, 28 | temperature=0.95 29 | ) 30 | return response.choices[0]["message"]["content"].strip() 31 | 32 | def process_single_original_continue(value): 33 | response = openai.ChatCompletion.create( 34 | model="gpt-3.5-turbo", 35 | messages=[ 36 | #{"role": "system", "content": "你是一个富有情感的善于分析他人发言的发帖助手。\ 37 | # 请试着在用户输入的内容前面补充一个不是疑问句的简短上文,使得用户的输入可以恰当地衔接你的上文。\ 38 | # 你的上文可以是用户输入之前发生的事情,也可以是一个短暂的前情提要,但不允许总结或重复用户输入。\ 39 | # 注意:是让用户输入跟随你的上文,而不是你的上文去跟随用户的输入。你的回答只需要包含上文。"}, 40 | 41 | {"role": "system", "content": "Please create a EXTREMELY brief context IN SIMPLIFIED CHINESE that precedes the Chinese user input, \ 42 | imagining a scenario where the user's input can naturally be right AFTER your context to form a complete story. \ 43 | Your context should not cover the details of the user input, but rather set the stage for it. Note: \ 44 | Your context should not contain the texts that are already in the user input. Your should output the Chinese context only, without any modifier."}, 45 | {"role": "user", "content": f"User input (in Chinese): {value}"}, 46 | ], 47 | max_tokens=128, 48 | temperature=0.95 49 | ) 50 | return response.choices[0]["message"]["content"].strip() 51 | 52 | def openai_process_original(item): 53 | tweet = item["text"] 54 | id = item["id"] 55 | res = [] 56 | loadedCount = 0 57 | 58 | # check cache first 59 | loadedCount = 0 60 | if os.path.exists(f'openai_cache/{id}.txt'): 61 | try: 62 | with open(f'openai_cache/{id}.txt', 'r') as f: 63 | res = json.load(f) 64 | if len(res) >= config.OPENAI_MAX_SAMPLE: 65 | # possibility that the cache is even larger than the max sample 66 | return tweet, res 67 | else: 68 | loadedCount = len(res) 69 | except Exception as e: 70 | print(e) 71 | pass 72 | 73 | if config.OPENAI_CACHE_ONLY: 74 | return None 75 | 76 | # now we need to invoke openai to generate the rest of the samples 77 | # sample several seed questions for each tweet 78 | for j in range(loadedCount, config.OPENAI_MAX_SAMPLE): 79 | # 75% Q&A, 25% completion 80 | if random.random() < 0.75: 81 | try: 82 | res.append(process_single_original_qa(tweet)) 83 | except Exception as e: 84 | print(e) 85 | # broken 86 | return None 87 | 88 | else: 89 | try: 90 | res.append(process_single_original_continue(tweet)) 91 | except Exception as e: 92 | print(e) 93 | # broken 94 | return None 95 | 96 | # save to cache 97 | with open(f'openai_cache/{id}.txt', 'w') as f: 98 | json.dump(res, f, indent=4, ensure_ascii=False) 99 | 100 | return tweet, res 101 | 102 | 103 | 104 | def processOriginalTweet_openai(tweets, iteration_count): 105 | # the iteration count matters, as we will use it as the index to sample cached openai prompts 106 | openai.api_key = config.OPENAI_KEY 107 | 108 | # make the openai cache directory 109 | if not os.path.exists('openai_cache'): 110 | os.makedirs('openai_cache') 111 | 112 | # because it's sampling from openai, we do not need to do Q/A or completion 113 | # we better just do random sampling with a seed question for a small portion of samples 114 | 115 | final = [] 116 | sample_threshold = 0.10 if iteration_count == 0 else 0 # sample 10% of the tweets 117 | 118 | for item in tweets: 119 | # sample a random float from 0-1 to decide the ways of generation 120 | # sample_range is a probablity accumulative list 121 | # [random post, completion, Q&A, rest (direct original post)] 122 | tweet = item["text"] 123 | rr = random.random() 124 | if rr < sample_threshold: 125 | # sample a random question, and concatenate 126 | instruction = 
f"{random.choice(original_post_prompt)}" 127 | user_input = f"" 128 | if checkResponse(tweet): 129 | final.append({"instruction": instruction, "input": user_input, "output": tweet}) 130 | elif rr < sample_threshold * 2: 131 | # no instructions, unconditional generation. 132 | final.append({"instruction": "", "input": "", "output": tweet}) 133 | else: 134 | # do nothing, as we have made the sampling successfully 135 | pass 136 | 137 | # now proceed to the openai generation 138 | openai_process_list = [] 139 | for item in tweets: 140 | tweet = item["text"] 141 | if checkResponse(tweet): 142 | openai_process_list.append(item) 143 | 144 | # now invoke a threadpool to process the openai generation 145 | with futures.ThreadPoolExecutor(max_workers=config.OPENAI_THREADS) as executor: 146 | results = executor.map(openai_process_original, openai_process_list) 147 | 148 | for r in results: 149 | if r: 150 | tweet, response = r 151 | if iteration_count < len(response): 152 | final.append({"instruction": response[iteration_count], "input": "", "output": tweet}) 153 | else: 154 | # we do not have enough cached openai prompts 155 | pass 156 | 157 | return final -------------------------------------------------------------------------------- /prompt_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | import config 4 | 5 | import jieba 6 | import jieba.posseg as pseg 7 | 8 | 9 | original_post_prompt = [ 10 | "最近过得怎么样?", 11 | "你这段时间都在忙什么?", 12 | "离我们上次见面以来,有什么新鲜事吗?", 13 | "最近有什么有趣的经历吗?", 14 | "你近期的生活如何?", 15 | "距离我们上次见面已经过了一段时间了,你过得好吗?", 16 | "你最近都在做些什么?", 17 | "近期有什么重要的变化吗?", 18 | "有什么新的发展吗?", 19 | "能和我分享一下你近期的生活吗?", 20 | "你最近有什么特别的事情发生吗?", 21 | "你最近的生活有什么变化吗?", 22 | "最近有没有什么大事发生?", 23 | "过去的几个月里,你都经历了些什么?", 24 | "你过去一段时间都在忙些什么?", 25 | "有什么新鲜事要和我分享吗?", 26 | "这段时间有没有什么让你难忘的事情发生?", 27 | "最近在工作或生活方面有什么挑战吗?", 28 | "从我们上次见面到现在,你都遇到了哪些有意思的事情?", 29 | "过去几个月你都忙些什么?", 30 | "最近你都有哪些新发现?", 31 | "你近期有哪些值得分享的经历?", 32 | "跟我聊聊你最近的日常吧?", 33 | "近期有什么值得一提的事吗?", 34 | "这段时间有没有什么特别的回忆?", 35 | "最近有没有什么让你感到骄傲的事?", 36 | "你最近有什么惊喜发生吗?", 37 | "过去一段时间,你的生活中有哪些亮点?", 38 | "近来你都有哪些新的尝试?", 39 | "你最近的日子过得怎样?有什么精彩的时刻?", 40 | "这段时间有什么特别的见闻吗?", 41 | "从上次见面到现在,你有什么有趣的故事吗?", 42 | "最近有没有什么特殊的事情让你感到开心?", 43 | "你最近有什么新的梦想或目标吗?", 44 | "过去的一段时间里,你有什么难忘的经历吗?", 45 | "你最近有没有什么特别的发现或灵感?", 46 | "最近有没有什么突破性的成就?", 47 | "这段时间你都学到了哪些新知识?", 48 | "最近有没有什么值得一提的事情发生在你身边?", 49 | "最近你有什么值得庆祝的事情吗?", 50 | "从我们上次见面以来,你都有哪些新的探险?", 51 | "最近有没有什么让你感到充实的事情?", 52 | "过去一段时间里,你有什么特别的体验?", 53 | "你近期有什么令人惊讶的收获?", 54 | "跟我分享一下你最近的喜怒哀乐吧?" 55 | ] 56 | 57 | related_topic_prompt = [ 58 | "关于[AAA],你有什么想法?", 59 | "你觉得[AAA]在日常生活中扮演着什么角色?", 60 | "你曾经与[AAA]有过什么相关经历吗?", 61 | "你觉得[AAA]有哪些相关的话题值得探讨?", 62 | "关于[AAA],你有什么好的经验可以分享吗?", 63 | "你如何看待[AAA]与我们生活中的其他方面的关系?", 64 | "你觉得[AAA]有哪些值得关注的趋势?", 65 | "你认为[AAA]对于我们的日常生活有多大意义?", 66 | "你觉得[AAA]有哪些令人好奇的方面?", 67 | "你有没有关于[AAA]的有趣故事?", 68 | "你觉得[AAA]有哪些令人欣赏的方面?", 69 | "关于[AAA],你觉得有哪些有趣的讨论点?", 70 | "关于[AAA],你有什么新的见解吗?", 71 | "你觉得[AAA]是如何影响你生活的?", 72 | "你曾经有过关于[AAA]的有趣经历吗?", 73 | "关于[AAA],有哪些热门话题值得关注?", 74 | "你有没有关于[AAA]的特别记忆?", 75 | "你觉得[AAA]与你生活中的其他元素有何联系?", 76 | "你觉得[AAA]有哪些令人惊讶的方面?", 77 | "你有没有听过关于[AAA]的趣闻轶事?", 78 | "关于[AAA],你觉得有哪些引人入胜的话题?", 79 | "你对[AAA]有哪些深刻的印象?", 80 | "你有没有关于[AAA]的奇闻异事?", 81 | "关于[AAA],你有什么特别的见解吗?", 82 | "你有没有关于[AAA]的特别喜好?", 83 | "你对[AAA]有哪些独到的见解?", 84 | "你觉得[AAA]如何影响我们的思考方式?", 85 | "你有没有关于[AAA]的趣味事例?", 86 | "你觉得[AAA]有哪些令人叹为观止的特点?", 87 | "关于[AAA],你有什么富有启发性的想法?", 88 | "你有没有关于[AAA]的奇特经历?", 89 | "你觉得[AAA]有哪些吸引人的特质?", 90 | "关于[AAA],你有什么令人兴奋的见解?" 
91 | ] 92 | 93 | def cut_sent(text): 94 | sub_sentences = re.split(r'([\。|\!|\?|\;|\,|\n])', text) 95 | sub_sentences = [s1 + s2 for s1, s2 in zip(sub_sentences[0::2], sub_sentences[1::2])] + ([sub_sentences[-1]] if len(sub_sentences) % 2 == 1 else []) 96 | return [s.replace("\n", "") for s in sub_sentences if (s.strip() != "" and s.strip() != "(media)" and s.strip() != "(link)")] 97 | 98 | 99 | def findTopic(md): 100 | substring = cut_sent(md) 101 | if len(substring) > 1: 102 | rr = random.randint(0, len(substring)-1) 103 | # choose a random substring, maximum length is 5 104 | topic = substring[rr] 105 | # cut off the last word, if it's a chinese punctuation 106 | if topic[-1] in [",", "。", "!", "?", ";"]: 107 | topic = topic[:-1] 108 | 109 | # use jieba to do the word segmentation and pos tagging 110 | tokens = [word for word,flag in pseg.cut(topic) if 'n' in flag] 111 | if len(tokens) > 0: 112 | topic = random.choice(tokens) 113 | else: 114 | # it's just it. 115 | topic = topic 116 | 117 | instruction = random.choice(related_topic_prompt).replace("[AAA]", topic) 118 | return instruction 119 | else: 120 | return None 121 | 122 | 123 | def checkResponse(response): 124 | # check if the user's response is too short. Filter it out. 125 | if len(response.replace("\n","").replace(" ", "").replace("(media)","").replace("(link)", "")) < config.RESPONSE_THRESH: 126 | return False 127 | return True -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # int8 2 | bitsandbytes==0.38.1 3 | accelerate==0.17.1 4 | 5 | # chatglm 6 | protobuf>=3.19.5,<3.20.1 7 | transformers==4.27.1 8 | icetk 9 | cpm_kernels==1.0.11 10 | torch>=1.13.1 11 | 12 | # 13 | datasets==2.10.1 14 | git+https://github.com/huggingface/peft.git # 最新版本 >=0.3.0.dev0 15 | 16 | # Chinese tokenization 17 | jieba 18 | 19 | # Advanced twitter scraping (beta) 20 | selenium 21 | beautifulsoup4 22 | webdriver-manager 23 | 24 | # parsing and sampling (should be enabled by default) 25 | numpy 26 | openai -------------------------------------------------------------------------------- /scrape_twitter.py: -------------------------------------------------------------------------------- 1 | # Warning: research purposes only. This code is not intended for production use. Use at your own risk. 2 | # Scraping Twitter may violate the Terms of Service. 
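#
# How the scraper below works (summary of this file):
#   - For each in-reply-to tweet id, a thread-local Chrome instance (mobile user agent,
#     images disabled) opens mobile.twitter.com/.../status/<id> and presses PAGE_UP a few
#     times so the parent tweets above the target get loaded.
#   - The target tweet is located via its "bookmark" button; the tweetText of every article
#     above it is collected, in order, as the conversation context.
#   - Results are cached as JSON under tweet_cache/<id>.txt, so reruns only hit Twitter for
#     uncached ids; with config.SCRAPE_CACHE_ONLY only cached results are returned.
#   - config.PROXY injects an authenticated proxy through a generated Chrome extension, and
#     config.PROCESSES sets the ThreadPoolExecutor size.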
3 | 4 | import multiprocessing 5 | import re 6 | import threading 7 | import zipfile 8 | from selenium import webdriver 9 | from webdriver_manager.chrome import ChromeDriverManager 10 | from bs4 import BeautifulSoup 11 | import time 12 | import concurrent.futures 13 | import json 14 | import os 15 | import config 16 | 17 | from selenium.webdriver.common.by import By 18 | from selenium.webdriver.support.ui import WebDriverWait 19 | from selenium.webdriver.support import expected_conditions as EC 20 | from selenium.webdriver.common.proxy import Proxy, ProxyType 21 | from selenium.webdriver.common.keys import Keys 22 | 23 | class document_is_ready(object): 24 | def __call__(self, driver): 25 | ready_state = driver.execute_script("return document.readyState") 26 | return ready_state == "complete" 27 | 28 | def get_tweet_info(tweet): 29 | 30 | # check out the cache first, 31 | # if hit, return the cached result 32 | tweet_id = tweet['id'] 33 | 34 | try: 35 | with open(f'tweet_cache/{tweet_id}.txt', 'r') as f: 36 | return {'id': tweet_id, 'text': tweet['text'], 'context': json.load(f)} 37 | except: 38 | pass 39 | 40 | if config.SCRAPE_CACHE_ONLY: 41 | return None 42 | 43 | driver = get_driver() 44 | 45 | try: 46 | url = f"https://mobile.twitter.com/ljsabc/status/{tweet_id}" 47 | driver.get(url) 48 | WebDriverWait(driver, config.SCRAPE_TIMEOUT).until(EC.presence_of_element_located((By.XPATH, '//*[@data-testid="bookmark"]'))) 49 | 50 | # Wait for DOM to be ready 51 | wait = WebDriverWait(driver, config.SCRAPE_TIMEOUT) 52 | wait.until(document_is_ready()) 53 | 54 | body_element = driver.find_element(By.TAG_NAME, "body") 55 | 56 | 57 | for j in range(2): 58 | # looks like we need to scroll up a few times to get the full context 59 | # 2 times for scroll up should be okay for most of the tweets 60 | # if you tweet longer you may try a longer range 61 | for j in range(3): 62 | body_element.send_keys(Keys.PAGE_UP) 63 | time.sleep(0.15) 64 | # introduce a small delay to let the page load 65 | # the delays can be hidden by using more threads 66 | time.sleep(1.0) 67 | 68 | soup = BeautifulSoup(driver.page_source, 'html.parser') 69 | articles = soup.find_all('article') 70 | 71 | if articles is None: 72 | # weird, probably a deleted tweet 73 | # or, possibly web driver failed to load the page 74 | # at this moment, we just return None and let the caller handle it 75 | return None 76 | 77 | target_index = -1 78 | for i, article in enumerate(articles): 79 | if article.find(lambda tag: tag.get('data-testid') == 'bookmark'): 80 | target_index = i 81 | break 82 | 83 | # without scroll to top, the tweet context may not be complete 84 | # but as we are only targetting a small context range, it should not hurt. 
85 | print(f"id: {tweet_id}, located {target_index + 1} of {len(articles)} tweets.") 86 | 87 | results = [] 88 | if target_index >= 0: 89 | for i, article in enumerate(articles): 90 | if i < target_index: 91 | #f True: 92 | target = article.find('div', {'data-testid': 'tweetText'}) 93 | if target: 94 | tweet_text = target.get_text() 95 | print(i, tweet_text) 96 | results.append(tweet_text) 97 | 98 | # dump the results to a file 99 | with open(f'tweet_cache/{tweet_id}.txt', 'w') as f: 100 | json.dump(results, f, indent=4, ensure_ascii=False) 101 | 102 | print(f"Original tweet: {tweet['text']}\n\n") 103 | 104 | return {'id': tweet_id, 'text': tweet['text'], 'context': results} 105 | except Exception as e: 106 | print(e) 107 | #return {'id': tweet_id, 'text': tweet['text'], 'context': None} 108 | return None 109 | 110 | 111 | 112 | threadLocal = threading.local() 113 | 114 | def get_driver(): 115 | driver = getattr(threadLocal, 'driver', None) 116 | if driver is None: 117 | mobile_user_agent = 'Mozilla/5.0 (Linux; Android 9; Pixel 3 Build/PQ3A.190705.001) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/75.0.3770.101 Mobile Safari/537.36' 118 | chrome_options = webdriver.ChromeOptions() 119 | chrome_options.add_argument(f'user-agent={mobile_user_agent}') 120 | prefs = {"profile.managed_default_content_settings.images": 2} 121 | chrome_options.add_experimental_option("prefs", prefs) 122 | chrome_options.add_argument(f"user-data-dir=chrome_profile/{threading.current_thread().ident}") 123 | 124 | # The window is place on purpose to avoid "error request". 125 | 126 | if config.PROXY: 127 | #proxy or proxy pool, it will be really useful 128 | pluginfile = 'proxy_auth_plugin.zip' 129 | manifest_json = """ 130 | { 131 | "version": "1.0.0", 132 | "manifest_version": 2, 133 | "name": "Chrome Proxy", 134 | "permissions": [ 135 | "proxy", 136 | "tabs", 137 | "unlimitedStorage", 138 | "storage", 139 | "", 140 | "webRequest", 141 | "webRequestBlocking" 142 | ], 143 | "background": { 144 | "scripts": ["background.js"] 145 | }, 146 | "minimum_chrome_version":"22.0.0" 147 | } 148 | """ 149 | 150 | background_js = """ 151 | var config = { 152 | mode: "fixed_servers", 153 | rules: { 154 | singleProxy: { 155 | scheme: "http", 156 | host: "%s", 157 | port: parseInt(%s) 158 | }, 159 | bypassList: ["localhost"] 160 | } 161 | }; 162 | 163 | chrome.proxy.settings.set({value: config, scope: "regular"}, function() {}); 164 | 165 | function callbackFn(details) { 166 | return { 167 | authCredentials: { 168 | username: "%s", 169 | password: "%s" 170 | } 171 | }; 172 | } 173 | 174 | chrome.webRequest.onAuthRequired.addListener( 175 | callbackFn, 176 | {urls: [""]}, 177 | ['blocking'] 178 | );""" % (config.PROXY_ADDR, config.PROXY_PORT, config.PROXY_USER, config.PROXY_PASSWD) 179 | 180 | 181 | with zipfile.ZipFile(pluginfile, 'w') as zp: 182 | zp.writestr("manifest.json", manifest_json) 183 | zp.writestr("background.js", background_js) 184 | chrome_options.add_extension(pluginfile) 185 | driver = webdriver.Chrome(executable_path='./chromedriver', options=chrome_options) 186 | else: 187 | driver = webdriver.Chrome(executable_path='./chromedriver', options=chrome_options) 188 | driver.set_window_size(340, 695) 189 | setattr(threadLocal, 'driver', driver) 190 | 191 | return driver 192 | 193 | def process_tweet_ids(tweets): 194 | # create a cache folder to store the scraped tweets 195 | if not os.path.exists('tweet_cache'): 196 | os.mkdir('tweet_cache') 197 | 198 | if not os.path.exists('chrome_profile'): 199 | 
os.mkdir('chrome_profile') 200 | else: 201 | # remove the dir, then create a new one 202 | # this is to avoid the cache issue 203 | import shutil 204 | shutil.rmtree('chrome_profile') 205 | os.mkdir('chrome_profile') 206 | 207 | tweet_data = [] 208 | pool = concurrent.futures.ThreadPoolExecutor(max_workers=config.PROCESSES) 209 | for result in pool.map(get_tweet_info, tweets): 210 | if result is not None: 211 | tweet_data.append(result) 212 | 213 | return tweet_data 214 | 215 | if __name__ == '__main__': 216 | # Replace with a list of tweet IDs you want to fetch information for 217 | tweet_ids = ['1639865895374958592', '1639689372059725825', '1639488084705439745', '1639639351897522176'] 218 | tweet_data = process_tweet_ids(tweet_ids) 219 | 220 | for data in tweet_data: 221 | print(data) -------------------------------------------------------------------------------- /tokenize_dataset_rows.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from tqdm import tqdm 4 | 5 | import datasets 6 | import transformers 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--json_path", type=str, default="data/alpaca_data.jsonl") 11 | parser.add_argument("--save_path", type=str, default="data/alpaca") 12 | parser.add_argument("--max_seq_length", type=int, default=384) 13 | parser.add_argument("--skip_overlength", type=bool, default=False) 14 | args = parser.parse_args() 15 | 16 | model_name = "THUDM/chatglm-6b" 17 | tokenizer = transformers.AutoTokenizer.from_pretrained( 18 | model_name, trust_remote_code=True) 19 | config = transformers.AutoConfig.from_pretrained( 20 | model_name, trust_remote_code=True, device_map='auto') 21 | 22 | def format_example(example: dict) -> dict: 23 | context = f"Instruction: {example['instruction']}\n" 24 | if example.get("input"): 25 | context += f"Input: {example['input']}\n" 26 | context += "Answer: " 27 | target = example["output"] 28 | return {"context": context, "target": target} 29 | 30 | 31 | def preprocess(tokenizer, config, example, max_seq_length): 32 | example = format_example(example) 33 | prompt = example["context"] 34 | target = example["target"] 35 | prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True) 36 | target_ids = tokenizer.encode( 37 | target, 38 | max_length=max_seq_length, 39 | truncation=True, 40 | add_special_tokens=False) 41 | input_ids = prompt_ids + target_ids + [config.eos_token_id] 42 | return {"input_ids": input_ids, "seq_len": len(prompt_ids)} 43 | 44 | 45 | def read_jsonl(path, max_seq_length, skip_overlength=False): 46 | model_name = "THUDM/chatglm-6b" 47 | tokenizer = transformers.AutoTokenizer.from_pretrained( 48 | model_name, trust_remote_code=True) 49 | config = transformers.AutoConfig.from_pretrained( 50 | model_name, trust_remote_code=True, device_map='auto') 51 | with open(path, "r") as f: 52 | for line in tqdm(f.readlines()): 53 | example = json.loads(line) 54 | feature = preprocess(tokenizer, config, example, max_seq_length) 55 | if skip_overlength and len(feature["input_ids"]) > max_seq_length: 56 | continue 57 | feature["input_ids"] = feature["input_ids"][:max_seq_length] 58 | yield feature 59 | 60 | def parse(element): 61 | feature = preprocess(tokenizer, config, element, args.max_seq_length) 62 | feature["input_ids"] = feature["input_ids"][:args.max_seq_length] 63 | return feature 64 | 65 | 66 | def main(): 67 | dataset = datasets.load_dataset("json", data_files=args.json_path) 68 | train_data = 
dataset["train"].shuffle().map(parse, num_proc=4) 69 | train_data.save_to_disk(args.save_path) 70 | 71 | #dataset = datasets.Dataset.from_generator( 72 | # lambda: read_jsonl(args.jsonl_path, args.max_seq_length, args.skip_overlength), 73 | #) 74 | #dataset.save_to_disk(args.save_path) 75 | 76 | # poorly written generator, should better mapped, I guess. 77 | # it ignores the updates in the same jsonl file 78 | #dataset.cleanup_cache_files() 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /tweets_sample.md: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "想让父母认同计算机是生产力工具这个观点还真是难啊……", 4 | "input": "", 5 | "output": "这个就尽量死心吧…\n我娘和我住了几个月看我每天对着电脑摔东西就知道电脑真的可以用来干活儿了…" 6 | }, 7 | { 8 | "instruction": "然后发觉所谓的AI其实就是一套很简单的自动控制系统 https://twitter.com/thirtydb/status/1511244138112765957…\nAI=自动化机床专业", 9 | "input": "", 10 | "output": "其实没错的!" 11 | }, 12 | { 13 | "instruction": "o365的在线多人协作是真的好用…\n而且现在还有自动补全和语法检查了,非常顺", 14 | "input": "", 15 | "output": "另外这个时候3:2屏幕的优势就出来了" 16 | }, 17 | { 18 | "instruction": "最近有什么有趣的经历吗?", 19 | "input": "", 20 | "output": "我觉得我得在acknowledgement里面写上感谢waifu2x" 21 | }, 22 | { 23 | "instruction": "", 24 | "input": "", 25 | "output": "zij sounds like some math notations\n\nDijkstra" 26 | }, 27 | { 28 | "instruction": "跟我聊聊你最近的日常吧?", 29 | "input": "", 30 | "output": "满脑子里都是那个YEAH\n\n另外这曲子有十几个Remix,最早听的版本果然还是BO2在酒吧里面DJ放的那个版本,在最后credit roll里面也有完整版重放。 \n\n(media)" 31 | }, 32 | { 33 | "instruction": "最近过得怎么样?", 34 | "input": "", 35 | "output": "想写个tensorboard的iOS客户端之类的东西…" 36 | }, 37 | { 38 | "instruction": "这都是东方的莫名含蓄造成的结果啊(趴\n是,感觉很奇怪", 39 | "input": "", 40 | "output": "主要是禁止早恋但是又不跟进教育讲清楚为啥不能早恋……" 41 | }, 42 | { 43 | "instruction": "过去的几个月里,你都经历了些什么?", 44 | "input": "", 45 | "output": "吹蛤爷我不反对,但是吹也要按照基本法啊\n要避免跟瞎吹的人比爬山 比拔河\n他姿势不如你 但是身体比你好啊" 46 | }, 47 | { 48 | "instruction": "嗯你说的对呀 https://t.co/DmixTHi2sf", 49 | "input": "", 50 | "output": "可是要怎么改呢…(沉思" 51 | }, 52 | { 53 | "instruction": "直到懂车之后还是不喜欢F1的主要原因是\n我觉得这东西开着不像车=、=", 54 | "input": "", 55 | "output": "问个类似的问题,你觉得LMP1像车么……" 56 | }, 57 | { 58 | "instruction": "你近期有什么令人惊讶的收获?", 59 | "input": "", 60 | "output": "天下果然没有这样的好事儿 \n\n(media)" 61 | }, 62 | { 63 | "instruction": "", 64 | "input": "", 65 | "output": "复兴-建国 \n>RT" 66 | }, 67 | { 68 | "instruction": "120Hz是文明的啊…", 69 | "input": "", 70 | "output": "完了我已经回不去了" 71 | }, 72 | { 73 | "instruction": "映画「君の名は。」中国語吹き替え版予告編 https://youtu.be/0Qt5OecQk-I via @YouTube\n这是中国上映的那个吹替版?怎么感觉CV好业余。", 74 | "input": "", 75 | "output": "就是好业余啊……" 76 | }, 77 | { 78 | "instruction": "关于video,你觉得有哪些引人入胜的话题?", 79 | "input": "", 80 | "output": "Lets save your fucking planet earth \n\n(video)" 81 | }, 82 | { 83 | "instruction": "跟我分享一下你最近的喜怒哀乐吧?", 84 | "input": "", 85 | "output": "Everything is just so shitty that I cannot get rid of these smells" 86 | }, 87 | { 88 | "instruction": "离我们上次见面以来,有什么新鲜事吗?", 89 | "input": "", 90 | "output": "我foing掉了10个(" 91 | }, 92 | { 93 | "instruction": "你看到三个拳头的直接B了就好千万别留情面\n其实什么叫三个全头", 94 | "input": "", 95 | "output": "你看他粉丝的昵称" 96 | }, 97 | { 98 | "instruction": "下面的4599咋回事", 99 | "input": "", 100 | "output": "4600U的价格啊" 101 | }, 102 | { 103 | "instruction": "关于头文字,你觉得有哪些有趣的讨论点?", 104 | "input": "", 105 | "output": "(link)\n头文字D\n\n// 可以看秋名山开车 开心" 106 | }, 107 | { 108 | "instruction": "有iPad 用户留言\n\n“为什么我原来没发现”\n\n因为文件管理器加入的比较晚\n\n顺便光记住IP地址没用啊\n\n要确保你局域网内IP地址不是随意分配的啊", 109 | "input": "", 110 | "output": 
"一般第一次之后就不会变了…\n不过直接搞个nas不香嘛" 111 | }, 112 | { 113 | "instruction": "有什么新鲜事要和我分享吗?", 114 | "input": "", 115 | "output": "还是孕妇装出门QAQ \n\n(media)" 116 | }, 117 | { 118 | "instruction": "你最近的日子过得怎样?有什么精彩的时刻?", 119 | "input": "", 120 | "output": "我要是有辆车就好了" 121 | }, 122 | { 123 | "instruction": "请不要在戏院里食用 小小酥\n为什么你可以发语音!\n现在还有吗这个 感觉都好久远了\n有 真的有\n还有那个 兵卒一口脆 不就是圆块儿的方便面嘛", 124 | "input": "", 125 | "output": "那个也还有…\n我有一次看见了都快哭了\n\n不过买了一袋尝尝 觉得没有以前好吃了" 126 | }, 127 | { 128 | "instruction": "碰气!\n你不说我都快忘记他叫什么了\n就记得是碰气碰气", 129 | "input": "", 130 | "output": "我的台语都是你们教的!\n\n业务员问我:你是北京人么!\n我:是一个比北京还要靠北的地方\n韩先生:哈哈哈哈哈\n\n(" 131 | }, 132 | { 133 | "instruction": "我真的不会拼misc全称……\n可耻", 134 | "input": "", 135 | "output": "我打misc它也给我拼写检查……变成了misc.\n(" 136 | }, 137 | { 138 | "instruction": "这DPI太温柔了……", 139 | "input": "", 140 | "output": "不对,是一样的……" 141 | }, 142 | { 143 | "instruction": "", 144 | "input": "", 145 | "output": "spin大师" 146 | }, 147 | { 148 | "instruction": "离我们上次见面以来,有什么新鲜事吗?", 149 | "input": "", 150 | "output": "我已经欠了十数位“回来了一定去你家”的约定了" 151 | }, 152 | { 153 | "instruction": "索尼这个a系列单反样片app似乎是支持HDR的\n看起来真是太带感了\n\n好了有了一个新的买的理由了", 154 | "input": "", 155 | "output": "这个可以有!(赞许 \n\n(media)" 156 | }, 157 | { 158 | "instruction": "", 159 | "input": "", 160 | "output": "所以还是作者心绪的投射吧。也是我自己心灵的投射。\n\n14岁的时候读不懂《18岁出门远行》\n等到24岁的时候,就只有一声叹息了。" 161 | }, 162 | { 163 | "instruction": "草还真有啊\n藍店\n反正我没去过\n淡定淡定,重點還是價格合理營商老實", 164 | "input": "", 165 | "output": "我去旁边阿布泰它不香么(w \n\n(video)" 166 | }, 167 | { 168 | "instruction": "Apollo Pro Plus(100应该是)\n\nCPU是铜矿,比我家的370赛扬1.0(图拉丁)早一代\n\n我家那台Apollo Pro 133+图拉丁三年前还能boot,估计现在应该也能……\n370最经典的还是300A可以超到450那一批,直接把旗舰奔腾2给收割了", 169 | "input": "", 170 | "output": "那时候的赛扬体质都超级好……" 171 | }, 172 | { 173 | "instruction": "关于照片,你有什么特别的见解吗?", 174 | "input": "", 175 | "output": "欸喵 好期待w" 176 | }, 177 | { 178 | "instruction": "看着比PI都牛逼,不知道以为是ECCV的AC呢…\n真当了ac还有空水zhihu吗", 179 | "input": "", 180 | "output": "倒是的确没见过哪个ac吹自己年级第一的" 181 | }, 182 | { 183 | "instruction": "我去这个防水眼线笔好厉害\n现在都没有花\n\n补一下粉底又是条好汉(", 184 | "input": "", 185 | "output": "是呢 而且还有点流汗什么的w" 186 | }, 187 | { 188 | "instruction": "你对机箱有哪些深刻的印象?", 189 | "input": "", 190 | "output": "算了,与其折腾没有保修的4090 Blower,不如直接上4块4090 AIO算了,反正两个固定在机箱上两个支出来一样用" 191 | }, 192 | { 193 | "instruction": "设计不能么?", 194 | "input": "", 195 | "output": "懂了\n…也就是说我过关还要过条街啊" 196 | }, 197 | { 198 | "instruction": "你们一起玩了么!w", 199 | "input": "", 200 | "output": "哇啊啊啊羡慕死了" 201 | }, 202 | { 203 | "instruction": "我妈每天都在想:你天天发自曝就不怕别人笑话你?", 204 | "input": "", 205 | "output": "反正我从小到大被歧视最多的就tm家里了 必须服" 206 | }, 207 | { 208 | "instruction": "跟你讲,ARM64笔记本,除了做工没有MacBook 12“好之外(你要是等也有三爽的Galaxy Book S),其他几乎爽翻天\n\nCPU虽然捞但是8个蛋加在一起不比Core M差,最关键是永远都不会throttle\n满载续航起码三个小时,上网看报10+小时起步\n用来吃饭干活儿的工具都有Native版本(至少对我来说全了)\n还有快充\n有什么推荐的型号,我找类似的替代品找很久了。", 209 | "input": "", 210 | "output": "便宜的话Huawei的Matebook E就很不错\n如果对价格不在意那就等Galaxy Book S吧\n\n不过要注意的是写代码上网看视频都行,但是打游戏或者开PS可能就算了吧(" 211 | }, 212 | { 213 | "instruction": "你觉得gitp有哪些令人欣赏的方面?", 214 | "input": "", 215 | "output": "$ gitp ull\n$ fuck" 216 | }, 217 | { 218 | "instruction": "爽!好穿!好看!\n\n*我在奥特莱斯买的", 219 | "input": "", 220 | "output": "等我打扮好看了我给你返图" 221 | }, 222 | { 223 | "instruction": "“不行 就两天”\n佛说", 224 | "input": "", 225 | "output": "三天之内撒了你" 226 | }, 227 | { 228 | "instruction": "大部分人的话都是左边会大一点的 没毛病\n诶诶是这样么……\n但是(", 229 | "input": "", 230 | "output": "但是我就是大很多啊!(绝望\n\n(link)\n// 众人的反应" 231 | }, 232 | { 233 | "instruction": "关于卡,你有什么富有启发性的想法?", 234 | "input": "", 235 | "output": 
"之前记得有一个科学研究是说喜欢的音乐类型可能和听众的学力水平有相关性\n\n我在想会不会有一种可能是只有家境、教育到了一种水平才会接触到某种音乐类型?\n我父母就永远理解不了电子乐的verse,他们觉得是卡碟了……所以他们也不会推荐给我听" 236 | }, 237 | { 238 | "instruction": "啥情況啊……?\n哎……我老闆總是教育我,“你得熟悉自己不喜歡做的事情”,然而這麽下去我就也不喜歡我自己了大概", 239 | "input": "", 240 | "output": "他跟我性格其實有點像的,都是技術流派的人,結果做了professor每天開會帶學生做的都是不喜歡的事情,我真不知道他是畫餅還是自己真的接受這種生活了……\n\n反正我不行" 241 | }, 242 | { 243 | "instruction": "", 244 | "input": "", 245 | "output": "榨干.jpg居然还有续集的么!" 246 | }, 247 | { 248 | "instruction": "这段时间有没有什么特别的回忆?", 249 | "input": "", 250 | "output": "白天俩猫活蹦乱跳\n夜里就是俩死猪" 251 | }, 252 | { 253 | "instruction": "你们女装癖是不是都爱养猫啊\n\n// 这心态真的太能理解了", 254 | "input": "", 255 | "output": "二手玫瑰的吉他手\n(link)" 256 | }, 257 | { 258 | "instruction": "靠这烂尾比liar game还…\nliar game剧场版还知道玩家和庄家对赌剧情,这烂尾的什么玩意儿\n幕后黑手是谁,神崎直他爸和中国黑客啊", 259 | "input": "", 260 | "output": "动物世界我能给到7.5\nliar game一二剧场版能给到8.5\n这个最多7\n\n抢椅子游戏好玩啊 然而换女主了不是我最爱的Erika了呜呜呜" 261 | }, 262 | { 263 | "instruction": "这东西超级棒啊!w\n他家咸的小饼咸的很到位…辣的也很到位\n以前补课还有考研的时候经常吃…", 264 | "input": "", 265 | "output": "嗯!我上中学就分为农心派和格力高派wwww" 266 | }, 267 | { 268 | "instruction": "", 269 | "input": "", 270 | "output": "下一个问题\n我手台在哪儿" 271 | }, 272 | { 273 | "instruction": "垃圾学校一辈子都没有地铁\n南门在北门的北面(", 274 | "input": "", 275 | "output": "好像是每栋楼的门都标出来了" 276 | }, 277 | { 278 | "instruction": "卧槽又有喜闻乐见的男娘题材", 279 | "input": "", 280 | "output": "重案六组4神了!\n还有白学剧情的! \n\n(media)\n\n(media)\n\n(media)\n\n(media)" 281 | }, 282 | { 283 | "instruction": "40% 人选择了 O,意味着他们不愿意在 PS5 上玩 PS4 游戏。\n30% 人选择了 X,意味着他们不愿意 PS4 游戏运行在 PS5 上。\n\n综上,索尼罪恶滔天!", 284 | "input": "", 285 | "output": "你这个投票赢麻了" 286 | }, 287 | { 288 | "instruction": "饼姐姐生快!", 289 | "input": "", 290 | "output": "我也爱你和狗子!" 291 | }, 292 | { 293 | "instruction": "", 294 | "input": "", 295 | "output": "小时候就说你跟别人不一样你只要努力长大以后就一定能成为栋梁之材只手遮天\n后来我发现只手遮天的人有千千万,不差我一个" 296 | }, 297 | { 298 | "instruction": "https://zh.wikipedia.org/wiki/%E6%B7%B1%E5%9C%B3%E5%B7%B4%E5%A3%AB%E8%B7%AF%E7%BA%BF%E5%88%97%E8%A1%A8…\n自从东莞公汽成功收编了以后现行开设了不少跨市公交,比如长安、虎门公汽进松岗,光明去黄江,雁田去凤岗等等,还有两元的龙岗到惠州公交\n大湾区(不含香港)也许有一天真的会来\n据了解,近年来,我市积极融入粤港澳大湾区交通一体化布局,自2019年1月,东莞巴士公司开通首条跨市公交线路——876路(后更名“莞75路”,起讫站为国贸中心公交站和新塘白江地铁站)起,截至目前,共开通14条跨市线路,加快了我市融入粤港澳大湾区及深莞惠一体化的步伐。", 299 | "input": "", 300 | "output": "其中,莞深跨市线路7条均接驳深圳地铁,莞穗跨市线路2条,莞惠跨市线路5条,覆盖虎门高铁站、长安北站、东莞火车站、国贸中心首末站、凤岗车站等市内大型交通枢纽。" 301 | }, 302 | { 303 | "instruction": "键盘摔在地上,然后PC就休眠了,是撞到了休眠的快捷键?这个还有快捷键?", 304 | "input": "", 305 | "output": "有的呀 多媒体键w" 306 | }, 307 | { 308 | "instruction": "斑点猫!尾巴也是斑点,腿也是斑点!\n刚刚把腿看成腰了 还找了半天!", 309 | "input": "", 310 | "output": "反例:细长的尾巴条纹的腿 \n\n(media)" 311 | }, 312 | { 313 | "instruction": "我爱你啊我的母校\n这啥(", 314 | "input": "", 315 | "output": "texlive,可以理解为每年的tex工具合集" 316 | }, 317 | { 318 | "instruction": "好作文都是些大人眼中的好作文\n大人是不会欣赏这种乱七八糟的兴趣的,孩子也不会么\n话说回来这篇可真是四不像\n叙述毫无亮点,读起来颇像蒋公日记\n议论连小学生水平都没有,可笑啊", 319 | "input": "", 320 | "output": "比较无趣的是政治挂帅,观点不重要,重要的是是不是迎合当时的主流思想" 321 | }, 322 | { 323 | "instruction": "有什么新的发展吗?", 324 | "input": "", 325 | "output": "哎对呀,steam deck+Xbox Remote\n几乎除了索尼独占直接通吃了啊\n\n(老任还得自己买就是了)" 326 | }, 327 | { 328 | "instruction": "当时Anandtech评测了15W和25W的1065G7,看到25W的成绩我都快哭了,大喊出了一定买\n\n结果美帝真香想直接给你一个35W的1065G7尝尝……\n小新pro是可以35w的10710,1065g7有35w tdp的版本吗?", 329 | "input": "", 330 | "output": "PL1解了想多少多少啊" 331 | }, 332 | { 333 | "instruction": "现在这个社会太可怕啦\n我不能接受.gif https://twitter.com/MikaAkagi/status/1063223977282158592…", 334 | "input": "", 335 | "output": "另外你真的有鸡鸡诶!(失望" 336 | }, 337 | { 338 | "instruction": "你最近有什么特别的事情发生吗?", 339 | "input": "", 340 | "output": "逐渐找回自己 \n\n(media)\n\n(media)" 341 | }, 342 | { 
343 | "instruction": "千 元 大 钞\n恕 不 找 续\n你要换港币么", 344 | "input": "", 345 | "output": "换啊 我各种需要rmb啊" 346 | }, 347 | { 348 | "instruction": "你这段时间都在忙什么?", 349 | "input": "", 350 | "output": "【RV770背后的故事 - AMD小核心之路】The RV770 Story: Documenting ATI's Road to Success\nvia:(link), 地址:(link)\n\n// 恍如隔世,RV770都有快10年了\n// 还记得用4850挖矿的日子么?" 351 | }, 352 | { 353 | "instruction": "你近期的生活如何?", 354 | "input": "", 355 | "output": "卓悦汇那个卖贡茶的男的xnmbyy,越说那个机器扫二维码扫不出来我可以直接扫,他就要按回去,他还一副理直气壮的样子,搞得我差点没跟他打起来" 356 | }, 357 | { 358 | "instruction": "大家好我也是女装子\n\n#不要脸", 359 | "input": "", 360 | "output": "嘛 无所谓呀。" 361 | }, 362 | { 363 | "instruction": "卧槽我想用的正则居然一遍就写对了 nice", 364 | "input": "", 365 | "output": "常用的网上都有现成的啦\n有些奇形怪状的就要自己写了" 366 | }, 367 | { 368 | "instruction": "发现有些美国人也开始把“龙”念成“多拉弓”了,不是“拽根”吗?\n这个界线好像已经很模糊了:\n对,“竜”是没翅膀的。\n这俩是一个字…\n简化字\n\nhttps://en.m.wiktionary.org/wiki/竜\n不过繁体龙似乎还是人名汉字,所以和花泽香菜一样还是旧字形", 369 | "input": "", 370 | "output": "不过年轻人也开始慢慢用新字形了,叫tatsuya写作竜也的越来越多了(" 371 | }, 372 | { 373 | "instruction": "受不了了带着队友们往前冲,", 374 | "input": "", 375 | "output": "死了之后发现家里还有一大块没涂就先涂一下我队友似乎觉得没有我一样很牛逼就继续死冲对面那当然是要g了啊" 376 | } 377 | ] -------------------------------------------------------------------------------- /twitter-parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | twitter-archive-parser - Python code to parse a Twitter archive and output in various ways 4 | Copyright (C) 2022 Tim Hutton - https://github.com/timhutton/twitter-archive-parser 5 | 6 | This program is free software: you can redistribute it and/or modify 7 | it under the terms of the GNU General Public License as published by 8 | the Free Software Foundation, either version 3 of the License, or 9 | (at your option) any later version. 10 | 11 | This program is distributed in the hope that it will be useful, 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | GNU General Public License for more details. 15 | 16 | You should have received a copy of the GNU General Public License 17 | along with this program. If not, see . 18 | """ 19 | 20 | from collections import defaultdict 21 | from dis import Instruction 22 | from typing import Optional 23 | from urllib.parse import urlparse 24 | import datetime 25 | import glob 26 | import importlib 27 | import json 28 | import logging 29 | import os 30 | import re 31 | import shutil 32 | import subprocess 33 | import sys 34 | import time 35 | 36 | from prompt import write_json 37 | # hot-loaded if needed, see import_module(): 38 | # imagesize 39 | # requests 40 | 41 | 42 | # Print a compile-time error in Python < 3.6. This line does nothing in Python 3.6+ but is reported to the user 43 | # as an error (because it is the first line that fails to compile) in older versions. 44 | f' Error: This script requires Python 3.6 or later. Use `python --version` to check your version.' 45 | 46 | 47 | class UserData: 48 | def __init__(self, user_id: str, handle: str): 49 | if user_id is None: 50 | raise ValueError('ID "None" is not allowed in UserData.') 51 | self.user_id = user_id 52 | if handle is None: 53 | raise ValueError('handle "None" is not allowed in UserData.') 54 | self.handle = handle 55 | 56 | 57 | class PathConfig: 58 | """ 59 | Helper class containing constants for various directories and files. 60 | 61 | The script will only add / change / delete content in its own directories, which start with `parser-`. 
62 | Files within `parser-output` are the end result that the user is probably interested in. 63 | Files within `parser-cache` are temporary working files, which improve the efficiency if you run 64 | this script multiple times. They can safely be removed without harming the consistency of the 65 | files within `parser-output`. 66 | """ 67 | def __init__(self, dir_archive): 68 | self.dir_archive = dir_archive 69 | self.dir_input_data = os.path.join(dir_archive, 'data') 70 | self.file_account_js = os.path.join(self.dir_input_data, 'account.js') 71 | 72 | # check if user is in correct folder 73 | if not os.path.isfile(self.file_account_js): 74 | print(f'Error: Failed to load {self.file_account_js}. ') 75 | exit() 76 | 77 | self.dir_input_media = find_dir_input_media(self.dir_input_data) 78 | self.dir_output = os.path.join(self.dir_archive, 'parser-output') 79 | self.dir_output_media = os.path.join(self.dir_output, 'media') 80 | self.dir_output_cache = os.path.join(self.dir_archive, 'parser-cache') 81 | self.file_output_following = os.path.join(self.dir_output, 'following.txt') 82 | self.file_output_followers = os.path.join(self.dir_output, 'followers.txt') 83 | self.file_download_log = os.path.join(self.dir_output_media, 'download_log.txt') 84 | self.file_tweet_icon = os.path.join(self.dir_output_media, 'tweet.ico') 85 | self.files_input_tweets = find_files_input_tweets(self.dir_input_data) 86 | 87 | # structured like an actual tweet output file, can be used to compute relative urls to a media file 88 | self.example_file_output_tweets = self.create_path_for_file_output_tweets(year=2020, month=12) 89 | 90 | def create_path_for_file_output_tweets(self, year, month, format="html", kind="tweets") -> str: 91 | """Builds the path for a tweet-archive file based on some properties.""" 92 | # Previously the filename was f'{dt.year}-{dt.month:02}-01-Tweet-Archive-{dt.year}-{dt.month:02}' 93 | return os.path.join(self.dir_output, f"{kind}-{format}", f"{year:04}", f"{year:04}-{month:02}-01-{kind}.{format}") 94 | 95 | def create_path_for_file_output_dms(self, name: str, index: Optional[int]=None, format: str="html", kind: str="DMs") -> str: 96 | """Builds the path for a dm-archive file based on some properties.""" 97 | index_suffix = "" 98 | if (index): 99 | index_suffix = f"-part{index:03}" 100 | return os.path.join(self.dir_output, kind, f"{kind}-{name}{index_suffix}.{format}") 101 | 102 | def create_path_for_file_output_single(self, format: str, kind: str)->str: 103 | """Builds the path for a single output file which, i.e. one that is not part of a larger group or sequence.""" 104 | return os.path.join(self.dir_output, f"{kind}.{format}") 105 | 106 | 107 | def get_consent(prompt: str, default_to_yes: bool = False): 108 | """Asks the user for consent, using the given prompt. Accepts various versions of yes/no, or 109 | an empty answer to accept the default. The default is 'no' unless default_to_yes is passed as 110 | True. The default will be indicated automatically. 
For unacceptable answers, the user will 111 | be asked again.""" 112 | if default_to_yes: 113 | suffix = " [Y/n]" 114 | default_answer = "yes" 115 | else: 116 | suffix = " [y/N]" 117 | default_answer = "no" 118 | 119 | # a temporary hack to make the script work in a non-interactive environment 120 | return False 121 | 122 | while True: 123 | user_input = input(prompt + suffix) 124 | if user_input == "": 125 | print (f"Your empty response was assumed to mean '{default_answer}' (the default for this question).") 126 | return default_to_yes 127 | if user_input.lower() in ('y', 'yes'): 128 | return True 129 | if user_input.lower() in ('n', 'no'): 130 | return False 131 | print (f"Sorry, did not understand. Please answer with y, n, yes, no, or press enter to accept " 132 | f"the default (which is '{default_answer}' in this case, as indicated by the uppercase " 133 | f"'{default_answer.upper()[0]}'.)") 134 | 135 | 136 | def import_module(module): 137 | """Imports a module specified by a string. Example: requests = import_module('requests')""" 138 | try: 139 | return importlib.import_module(module) 140 | except ImportError: 141 | print(f'\nError: This script uses the "{module}" module which is not installed.\n') 142 | if not get_consent('OK to install using pip?'): 143 | exit() 144 | subprocess.run([sys.executable, '-m', 'pip', 'install', module], check=True) 145 | return importlib.import_module(module) 146 | 147 | 148 | def open_and_mkdirs(path_file): 149 | """Opens a file for writing. If the parent directory does not exist yet, it is created first.""" 150 | mkdirs_for_file(path_file) 151 | return open(path_file, 'w', encoding='utf-8') 152 | 153 | 154 | def mkdirs_for_file(path_file): 155 | """Creates the parent directory of the given file, if it does not exist yet.""" 156 | path_dir = os.path.split(path_file)[0] 157 | os.makedirs(path_dir, exist_ok=True) 158 | 159 | 160 | def rel_url(media_path, document_path): 161 | """Computes the relative URL needed to link from `document_path` to `media_path`. 162 | Assumes that `document_path` points to a file (e.g. 
`.md` or `.html`), not a directory.""" 163 | return os.path.relpath(media_path, os.path.split(document_path)[0]).replace("\\", "/") 164 | 165 | 166 | def get_twitter_api_guest_token(session, bearer_token): 167 | """Returns a Twitter API guest token for the current session.""" 168 | guest_token_response = session.post("https://api.twitter.com/1.1/guest/activate.json", 169 | headers={'authorization': f'Bearer {bearer_token}'}, 170 | timeout=2, 171 | ) 172 | guest_token = json.loads(guest_token_response.content)['guest_token'] 173 | if not guest_token: 174 | raise Exception(f"Failed to retrieve guest token") 175 | return guest_token 176 | 177 | 178 | def get_twitter_users(session, bearer_token, guest_token, user_ids): 179 | """Asks Twitter for all metadata associated with user_ids.""" 180 | users = {} 181 | while user_ids: 182 | max_batch = 100 183 | user_id_batch = user_ids[:max_batch] 184 | user_ids = user_ids[max_batch:] 185 | user_id_list = ",".join(user_id_batch) 186 | query_url = f"https://api.twitter.com/1.1/users/lookup.json?user_id={user_id_list}" 187 | response = session.get(query_url, 188 | headers={'authorization': f'Bearer {bearer_token}', 'x-guest-token': guest_token}, 189 | timeout=2, 190 | ) 191 | if not response.status_code == 200: 192 | raise Exception(f'Failed to get user handle: {response}') 193 | response_json = json.loads(response.content) 194 | for user in response_json: 195 | users[user["id_str"]] = user 196 | return users 197 | 198 | 199 | def lookup_users(user_ids, users): 200 | """Fill the users dictionary with data from Twitter""" 201 | # Filter out any users already known 202 | filtered_user_ids = [id for id in user_ids if id not in users] 203 | if not filtered_user_ids: 204 | # Don't bother opening a session if there's nothing to get 205 | return 206 | # Account metadata observed at ~2.1KB on average. 207 | estimated_size = int(2.1 * len(filtered_user_ids)) 208 | print(f'{len(filtered_user_ids)} users are unknown.') 209 | if not get_consent(f'Download user data from Twitter (approx {estimated_size:,} KB)?'): 210 | return 211 | 212 | requests = import_module('requests') 213 | try: 214 | with requests.Session() as session: 215 | bearer_token = 'AAAAAAAAAAAAAAAAAAAAANRILgAAAAAAnNwIzUejRCOuH5E6I8xnZz4puTs%3D1Zv7ttfk8LF81IUq16cHjhLTvJu4FA33AGWWjCpTnA' 216 | guest_token = get_twitter_api_guest_token(session, bearer_token) 217 | retrieved_users = get_twitter_users(session, bearer_token, guest_token, filtered_user_ids) 218 | for user_id, user in retrieved_users.items(): 219 | if user["screen_name"] is not None: 220 | users[user_id] = UserData(user_id=user_id, handle=user["screen_name"]) 221 | print() # empty line for better readability of output 222 | except Exception as err: 223 | print(f'Failed to download user data: {err}') 224 | 225 | 226 | def read_json_from_js_file(filename): 227 | """Reads the contents of a Twitter-produced .js file into a dictionary.""" 228 | print(f'Parsing {filename}...') 229 | with open(filename, 'r', encoding='utf8') as f: 230 | data = f.readlines() 231 | # if the JSON has no real content, it can happen that the file is only one line long. 232 | # in this case, return an empty dict to avoid errors while trying to read non-existing lines. 
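    # For reference: a Twitter archive data file is JavaScript rather than JSON. It typically
    # begins with an assignment such as
    #     window.YTD.tweets.part0 = [ { "tweet": { ... } }, ... ]
    # so replacing that first line with a bare '[' (done below) is enough to make the remainder
    # parseable as JSON.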
233 | if len(data) <= 1: 234 | return {} 235 | # convert js file to JSON: replace first line with just '[', squash lines into a single string 236 | prefix = '[' 237 | if '{' in data[0]: 238 | prefix += ' {' 239 | data = prefix + ''.join(data[1:]) 240 | # parse the resulting JSON and return as a dict 241 | return json.loads(data) 242 | 243 | 244 | def extract_username(paths: PathConfig): 245 | """Returns the user's Twitter username from account.js.""" 246 | account = read_json_from_js_file(paths.file_account_js) 247 | return account[0]['account']['username'] 248 | 249 | 250 | def escape_markdown(input_text: str) -> str: 251 | """ 252 | Escape markdown control characters from input text so that the text will not break in rendered markdown. 253 | (Only use on unformatted text parts that do not yet have any markdown control characters added on purpose!) 254 | """ 255 | # disable escapes permanently 256 | return input_text 257 | 258 | characters_to_escape: str = r"\_*[]()~`>#+-=|{}.!" 259 | output_text: str = '' 260 | for char in input_text: 261 | if char in characters_to_escape: 262 | # add backslash before control char 263 | output_text = output_text + "\\" + char 264 | elif char == '\n': 265 | # add double space before line break 266 | output_text = output_text + " " + char 267 | else: 268 | output_text = output_text + char 269 | return output_text 270 | 271 | def collect_tweet_references(tweet): 272 | if 'tweet' in tweet.keys(): 273 | tweet = tweet['tweet'] 274 | tweet_ids = set() 275 | in_reply_to = False 276 | quote = False 277 | retweet = False 278 | # Collect quoted tweets 279 | if 'entities' in tweet and 'urls' in tweet['entities']: 280 | for url in tweet['entities']['urls']: 281 | if 'url' in url and 'expanded_url' in url: 282 | expanded_url = url['expanded_url'] 283 | matches = re.match(r'^https://twitter.com/([0-9A-Za-z_]*)/status/(\d+)$', expanded_url) 284 | if (matches): 285 | quote = True 286 | 287 | # Collect previous tweets in conversation 288 | if 'in_reply_to_status_id_str' in tweet: 289 | in_reply_to = True 290 | 291 | # Collect RT retweets 292 | if 'full_text' in tweet and tweet['full_text'].startswith('RT @'): 293 | retweet = True 294 | 295 | # TODO: really parse the tweet ids into texts 296 | # needs to create 3 sets for each type. 297 | return in_reply_to, quote, retweet 298 | 299 | 300 | def convert_tweet(tweet, username, media_sources, users: dict, paths: PathConfig): 301 | """Converts a JSON-format tweet. Returns tuple of timestamp, markdown and HTML.""" 302 | if 'tweet' in tweet.keys(): 303 | tweet = tweet['tweet'] 304 | 305 | if not 'created_at' in tweet: 306 | # seems to be a bug in archive: 307 | return None, None, None, None, None, None, None 308 | 309 | timestamp_str = tweet['created_at'] 310 | timestamp = int(round(datetime.datetime.strptime(timestamp_str, '%a %b %d %X %z %Y').timestamp())) 311 | # Example: Tue Mar 19 14:05:17 +0000 2019 312 | body_markdown = tweet['full_text'] 313 | body_html = tweet['full_text'] 314 | tweet_id_str = tweet['id_str'] 315 | 316 | # for old tweets before embedded t.co redirects were added, ensure the links are 317 | # added to the urls entities list so that we can build correct links later on. 
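    # For example, a very old tweet whose full_text contains a bare "http://example.org/page"
    # link (example.org is just a placeholder here) gets a synthetic entry appended to
    # tweet['entities']['urls'], so the t.co replacement loop further down rewrites it to
    # "(link)" exactly like a modern tweet's links.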
318 | if 'entities' in tweet and 'media' not in tweet['entities'] and len(tweet['entities'].get("urls", [])) == 0: 319 | for word in tweet['full_text'].split(): 320 | try: 321 | url = urlparse(word) 322 | except ValueError: 323 | pass # don't crash when trying to parse something that looks like a URL but actually isn't 324 | else: 325 | if url.scheme != '' and url.netloc != '' and not word.endswith('\u2026'): 326 | # Shorten links similar to twitter 327 | netloc_short = url.netloc[4:] if url.netloc.startswith("www.") else url.netloc 328 | path_short = url.path if len(url.path + '?' + url.query) < 15 \ 329 | else (url.path + '?' + url.query)[:15] + '\u2026' 330 | tweet['entities']['urls'].append({ 331 | 'url': word, 332 | 'expanded_url': word, 333 | 'display_url': netloc_short + path_short, 334 | 'indices': [tweet['full_text'].index(word), tweet['full_text'].index(word) + len(word)], 335 | }) 336 | # replace t.co URLs with their original versions 337 | if 'entities' in tweet and 'urls' in tweet['entities']: 338 | for url in tweet['entities']['urls']: 339 | if 'url' in url and 'expanded_url' in url: 340 | expanded_url = url['expanded_url'] 341 | #body_markdown = body_markdown.replace(url['url'], expanded_url) 342 | body_markdown = body_markdown.replace(url['url'], '(link)') 343 | expanded_url_html = f'{expanded_url}' 344 | body_html = body_html.replace(url['url'], expanded_url_html) 345 | # if the tweet is a reply, construct a header that links the names 346 | # of the accounts being replied to the tweet being replied to 347 | header_markdown = '' 348 | header_html = '' 349 | if 'in_reply_to_status_id' in tweet: 350 | # match and remove all occurrences of '@username ' at the start of the body 351 | replying_to = re.match(r'^(@[0-9A-Za-z_]* )*', body_markdown)[0] 352 | if replying_to: 353 | body_markdown = body_markdown[len(replying_to):] 354 | body_html = body_html[len(replying_to):] 355 | else: 356 | # no '@username ' in the body: we're replying to self 357 | replying_to = f'@{username}' 358 | names = replying_to.split() 359 | # some old tweets lack 'in_reply_to_screen_name': use it if present, otherwise fall back to names[0] 360 | in_reply_to_screen_name = tweet['in_reply_to_screen_name'] if 'in_reply_to_screen_name' in tweet else names[0] 361 | # create a list of names of the form '@name1, @name2 and @name3' - or just '@name1' if there is only one name 362 | name_list = ', '.join(names[:-1]) + (f' and {names[-1]}' if len(names) > 1 else names[0]) 363 | in_reply_to_status_id = tweet['in_reply_to_status_id'] 364 | replying_to_url = f'https://twitter.com/{in_reply_to_screen_name}/status/{in_reply_to_status_id}' 365 | header_markdown += f'Replying to [{escape_markdown(name_list)}]({replying_to_url})\n\n' 366 | header_html += f'Replying to {name_list}
' 367 | # escape tweet body for markdown rendering: 368 | # no need to escape here as it will be tokenized again. 369 | body_markdown = escape_markdown(body_markdown) 370 | #body_markdown = body_markdown.strip().replace("\n", r"\n") 371 | # replace image URLs with image links to local files 372 | if 'entities' in tweet and 'media' in tweet['entities'] and 'extended_entities' in tweet \ 373 | and 'media' in tweet['extended_entities']: 374 | original_url = tweet['entities']['media'][0]['url'] 375 | markdown = '' 376 | html = '' 377 | for media in tweet['extended_entities']['media']: 378 | if 'url' in media and 'media_url' in media: 379 | original_expanded_url = media['media_url'] 380 | original_filename = os.path.split(original_expanded_url)[1] 381 | archive_media_filename = tweet_id_str + '-' + original_filename 382 | archive_media_path = os.path.join(paths.dir_input_media, archive_media_filename) 383 | file_output_media = os.path.join(paths.dir_output_media, archive_media_filename) 384 | media_url = rel_url(file_output_media, paths.example_file_output_tweets) 385 | markdown += '' if not markdown and body_markdown == escape_markdown(original_url) else '\n\n' 386 | #markdown += '' if not markdown and body_markdown == original_url else '\n\n' 387 | html += '' if not html and body_html == original_url else '
' 388 | if os.path.isfile(archive_media_path): 389 | # Found a matching image, use this one 390 | if not os.path.isfile(file_output_media): 391 | shutil.copy(archive_media_path, file_output_media) 392 | #markdown += f'![]({media_url})' 393 | # ignore media at this moment 394 | markdown += '(media)' 395 | html += f'' 396 | # Save the online location of the best-quality version of this file, for later upgrading if wanted 397 | best_quality_url = f'https://pbs.twimg.com/media/{original_filename}:orig' 398 | media_sources.append( 399 | (os.path.join(paths.dir_output_media, archive_media_filename), best_quality_url) 400 | ) 401 | else: 402 | # Is there any other file that includes the tweet_id in its filename? 403 | archive_media_paths = glob.glob(os.path.join(paths.dir_input_media, tweet_id_str + '*')) 404 | if len(archive_media_paths) > 0: 405 | for archive_media_path in archive_media_paths: 406 | archive_media_filename = os.path.split(archive_media_path)[-1] 407 | file_output_media = os.path.join(paths.dir_output_media, archive_media_filename) 408 | media_url = rel_url(file_output_media, paths.example_file_output_tweets) 409 | if not os.path.isfile(file_output_media): 410 | shutil.copy(archive_media_path, file_output_media) 411 | #markdown += f'\n' 413 | markdown += '(video)' 414 | html += f'\n' 416 | # Save the online location of the best-quality version of this file, 417 | # for later upgrading if wanted 418 | if 'video_info' in media and 'variants' in media['video_info']: 419 | best_quality_url = '' 420 | best_bitrate = -1 # some valid videos are marked with bitrate=0 in the JSON 421 | for variant in media['video_info']['variants']: 422 | if 'bitrate' in variant: 423 | bitrate = int(variant['bitrate']) 424 | if bitrate > best_bitrate: 425 | best_quality_url = variant['url'] 426 | best_bitrate = bitrate 427 | if best_bitrate == -1: 428 | print(f"Warning No URL found for {original_url} {original_expanded_url} " 429 | f"{archive_media_path} {media_url}") 430 | print(f"JSON: {tweet}") 431 | else: 432 | media_sources.append( 433 | (os.path.join(paths.dir_output_media, archive_media_filename), 434 | best_quality_url) 435 | ) 436 | else: 437 | print(f'Warning: missing local file: {archive_media_path}. Using original link instead: ' 438 | f'{original_url} (expands to {original_expanded_url})') 439 | #markdown += f'![]({original_url})' 440 | markdown += "(media)" 441 | html += f'{original_url}' 442 | body_markdown = body_markdown.replace(escape_markdown(original_url), markdown) 443 | #body_markdown = body_markdown.replace(original_url, markdown) 444 | body_html = body_html.replace(original_url, html) 445 | # make the body a quote 446 | body_markdown = body_markdown 447 | body_html = '

' + '
\n'.join(body_html.splitlines()) + '
' 448 | # append the original Twitter URL as a link 449 | original_tweet_url = f'https://twitter.com/{username}/status/{tweet_id_str}' 450 | icon_url = rel_url(paths.file_tweet_icon, paths.example_file_output_tweets) 451 | 452 | # we do not like header because it's slightly confusing 453 | # TODO: consider if we do need images in this markdown 454 | # timestamp is not needed as it's for corpus 455 | body_markdown = body_markdown #+ f'\n\n ' \ 456 | #f'[{timestamp_str}]({original_tweet_url})' 457 | body_html = header_html + body_html + f' {timestamp_str}

' 459 | # extract user_id:handle connections 460 | if 'in_reply_to_user_id' in tweet and 'in_reply_to_screen_name' in tweet and \ 461 | tweet['in_reply_to_screen_name'] is not None: 462 | reply_to_id = tweet['in_reply_to_user_id'] 463 | if int(reply_to_id) >= 0: # some ids are -1, not sure why 464 | handle = tweet['in_reply_to_screen_name'] 465 | users[reply_to_id] = UserData(user_id=reply_to_id, handle=handle) 466 | if 'entities' in tweet and 'user_mentions' in tweet['entities'] and tweet['entities']['user_mentions'] is not None: 467 | for mention in tweet['entities']['user_mentions']: 468 | if mention is not None and 'id' in mention and 'screen_name' in mention: 469 | mentioned_id = mention['id'] 470 | if int(mentioned_id) >= 0: # some ids are -1, not sure why 471 | handle = mention['screen_name'] 472 | if handle is not None: 473 | users[mentioned_id] = UserData(user_id=mentioned_id, handle=handle) 474 | 475 | # extract the type of tweet 476 | in_reply_to, quote, retweet = collect_tweet_references(tweet) 477 | 478 | return timestamp, tweet_id_str, body_markdown, body_html, in_reply_to, quote, retweet 479 | 480 | 481 | def find_files_input_tweets(dir_path_input_data): 482 | """Identify the tweet archive's file and folder names - 483 | they change slightly depending on the archive size it seems.""" 484 | input_tweets_file_templates = ['tweet.js', 'tweets.js', 'tweets-part*.js'] 485 | files_paths_input_tweets = [] 486 | for input_tweets_file_template in input_tweets_file_templates: 487 | files_paths_input_tweets += glob.glob(os.path.join(dir_path_input_data, input_tweets_file_template)) 488 | if len(files_paths_input_tweets)==0: 489 | print(f'Error: no files matching {input_tweets_file_templates} in {dir_path_input_data}') 490 | exit() 491 | return files_paths_input_tweets 492 | 493 | 494 | def find_dir_input_media(dir_path_input_data): 495 | input_media_dir_templates = ['tweet_media', 'tweets_media'] 496 | input_media_dirs = [] 497 | for input_media_dir_template in input_media_dir_templates: 498 | input_media_dirs += glob.glob(os.path.join(dir_path_input_data, input_media_dir_template)) 499 | if len(input_media_dirs) == 0: 500 | print(f'Error: no folders matching {input_media_dir_templates} in {dir_path_input_data}') 501 | exit() 502 | if len(input_media_dirs) > 1: 503 | print(f'Error: multiple folders matching {input_media_dir_templates} in {dir_path_input_data}') 504 | exit() 505 | return input_media_dirs[0] 506 | 507 | 508 | def download_file_if_larger(url, filename, index, count, sleep_time): 509 | """Attempts to download from the specified URL. Overwrites file if larger. 510 | Returns whether the file is now known to be the largest available, and the number of bytes downloaded. 511 | """ 512 | requests = import_module('requests') 513 | imagesize = import_module('imagesize') 514 | 515 | pref = f'{index:3d}/{count:3d} {filename}: ' 516 | # Sleep briefly, in an attempt to minimize the possibility of trigging some auto-cutoff mechanism 517 | if index > 1: 518 | print(f'{pref}Sleeping...', end='\r') 519 | time.sleep(sleep_time) 520 | # Request the URL (in stream mode so that we can conditionally abort depending on the headers) 521 | print(f'{pref}Requesting headers for {url}...', end='\r') 522 | byte_size_before = os.path.getsize(filename) 523 | try: 524 | with requests.get(url, stream=True, timeout=2) as res: 525 | if not res.status_code == 200: 526 | # Try to get content of response as `res.text`. 527 | # For twitter.com, this will be empty in most (all?) cases. 
528 | # It is successfully tested with error responses from other domains. 529 | raise Exception(f'Download failed with status "{res.status_code} {res.reason}". ' 530 | f'Response content: "{res.text}"') 531 | byte_size_after = int(res.headers['content-length']) 532 | if byte_size_after != byte_size_before: 533 | # Proceed with the full download 534 | tmp_filename = filename+'.tmp' 535 | print(f'{pref}Downloading {url}... ', end='\r') 536 | with open(tmp_filename,'wb') as f: 537 | shutil.copyfileobj(res.raw, f) 538 | post = f'{byte_size_after/2**20:.1f}MB downloaded' 539 | width_before, height_before = imagesize.get(filename) 540 | width_after, height_after = imagesize.get(tmp_filename) 541 | pixels_before, pixels_after = width_before * height_before, width_after * height_after 542 | pixels_percentage_increase = 100.0 * (pixels_after - pixels_before) / pixels_before 543 | 544 | if width_before == -1 and height_before == -1 and width_after == -1 and height_after == -1: 545 | # could not check size of both versions, probably a video or unsupported image format 546 | os.replace(tmp_filename, filename) 547 | bytes_percentage_increase = 100.0 * (byte_size_after - byte_size_before) / byte_size_before 548 | logging.info(f'{pref}SUCCESS. New version is {bytes_percentage_increase:3.0f}% ' 549 | f'larger in bytes (pixel comparison not possible). {post}') 550 | return True, byte_size_after 551 | elif width_before == -1 or height_before == -1 or width_after == -1 or height_after == -1: 552 | # could not check size of one version, this should not happen (corrupted download?) 553 | logging.info(f'{pref}SKIPPED. Pixel size comparison inconclusive: ' 554 | f'{width_before}*{height_before}px vs. {width_after}*{height_after}px. {post}') 555 | return False, byte_size_after 556 | elif pixels_after >= pixels_before: 557 | os.replace(tmp_filename, filename) 558 | bytes_percentage_increase = 100.0 * (byte_size_after - byte_size_before) / byte_size_before 559 | if bytes_percentage_increase >= 0: 560 | logging.info(f'{pref}SUCCESS. New version is {bytes_percentage_increase:3.0f}% larger in bytes ' 561 | f'and {pixels_percentage_increase:3.0f}% larger in pixels. {post}') 562 | else: 563 | logging.info(f'{pref}SUCCESS. New version is actually {-bytes_percentage_increase:3.0f}% ' 564 | f'smaller in bytes but {pixels_percentage_increase:3.0f}% ' 565 | f'larger in pixels. {post}') 566 | return True, byte_size_after 567 | else: 568 | logging.info(f'{pref}SKIPPED. Online version has {-pixels_percentage_increase:3.0f}% ' 569 | f'smaller pixel size. {post}') 570 | return True, byte_size_after 571 | else: 572 | logging.info(f'{pref}SKIPPED. Online version is same byte size, assuming same content. Not downloaded.') 573 | return True, 0 574 | except Exception as err: 575 | logging.error(f"{pref}FAIL. Media couldn't be retrieved from {url} because of exception: {err}") 576 | return False, 0 577 | 578 | 579 | def download_larger_media(media_sources, paths: PathConfig): 580 | """Uses (filename, URL) tuples in media_sources to download files from remote storage. 581 | Aborts downloads if the remote file is the same size or smaller than the existing local version. 582 | Retries the failed downloads several times, with increasing pauses between each to avoid being blocked. 
583 | """ 584 | # Log to file as well as the console 585 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(message)s') 586 | mkdirs_for_file(paths.file_download_log) 587 | logfile_handler = logging.FileHandler(filename=paths.file_download_log, mode='w') 588 | logfile_handler.setLevel(logging.INFO) 589 | logging.getLogger().addHandler(logfile_handler) 590 | # Download new versions 591 | start_time = time.time() 592 | total_bytes_downloaded = 0 593 | sleep_time = 0.25 594 | remaining_tries = 5 595 | while remaining_tries > 0: 596 | number_of_files = len(media_sources) 597 | success_count = 0 598 | retries = [] 599 | for index, (local_media_path, media_url) in enumerate(media_sources): 600 | success, bytes_downloaded = download_file_if_larger( 601 | media_url, local_media_path, index + 1, number_of_files, sleep_time 602 | ) 603 | if success: 604 | success_count += 1 605 | else: 606 | retries.append((local_media_path, media_url)) 607 | total_bytes_downloaded += bytes_downloaded 608 | 609 | # show % done and estimated remaining time: 610 | time_elapsed: float = time.time() - start_time 611 | estimated_time_per_file: float = time_elapsed / (index + 1) 612 | estimated_time_remaining: datetime.datetime = \ 613 | datetime.datetime.fromtimestamp( 614 | (number_of_files - (index + 1)) * estimated_time_per_file, 615 | tz=datetime.timezone.utc 616 | ) 617 | if estimated_time_remaining.hour >= 1: 618 | time_remaining_string: str = \ 619 | f"{estimated_time_remaining.hour} hour{'' if estimated_time_remaining.hour == 1 else 's'} " \ 620 | f"{estimated_time_remaining.minute} minute{'' if estimated_time_remaining.minute == 1 else 's'}" 621 | elif estimated_time_remaining.minute >= 1: 622 | time_remaining_string: str = \ 623 | f"{estimated_time_remaining.minute} minute{'' if estimated_time_remaining.minute == 1 else 's'} " \ 624 | f"{estimated_time_remaining.second} second{'' if estimated_time_remaining.second == 1 else 's'}" 625 | else: 626 | time_remaining_string: str = \ 627 | f"{estimated_time_remaining.second} second{'' if estimated_time_remaining.second == 1 else 's'}" 628 | 629 | if index + 1 == number_of_files: 630 | print(' 100 % done.') 631 | else: 632 | print(f' {(100*(index+1)/number_of_files):.1f} % done, about {time_remaining_string} remaining...') 633 | 634 | media_sources = retries 635 | remaining_tries -= 1 636 | sleep_time += 2 637 | logging.info(f'\n{success_count} of {number_of_files} tested media files ' 638 | f'are known to be the best-quality available.\n') 639 | if len(retries) == 0: 640 | break 641 | if remaining_tries > 0: 642 | print(f'----------------------\n\nRetrying the ones that failed, with a longer sleep. ' 643 | f'{remaining_tries} tries remaining.\n') 644 | end_time = time.time() 645 | 646 | logging.info(f'Total downloaded: {total_bytes_downloaded/2**20:.1f}MB = {total_bytes_downloaded/2**30:.2f}GB') 647 | logging.info(f'Time taken: {end_time-start_time:.0f}s') 648 | print(f'Wrote log to {paths.file_download_log}') 649 | 650 | 651 | def parse_tweets(username, users, html_template, paths: PathConfig, lang): 652 | """Read tweets from paths.files_input_tweets, write to *.md and *.html. 653 | Copy the media used to paths.dir_output_media. 654 | Collect user_id:user_handle mappings for later use, in 'users'. 655 | Returns the mapping from media filename to best-quality URL. 
656 | """ 657 | tweets = [] 658 | media_sources = [] 659 | for tweets_js_filename in paths.files_input_tweets: 660 | jsons = read_json_from_js_file(tweets_js_filename) 661 | for tweet in jsons: 662 | result = convert_tweet(tweet, username, media_sources, users, paths) 663 | if result[0]: 664 | # check valid tweet by timestamp 665 | tweets.append(result) 666 | tweets.sort(key=lambda tup: tup[0]) # oldest first 667 | 668 | # Group tweets by month 669 | grouped_tweets = defaultdict(list) 670 | for timestamp, ids, md, html, in_reply_to, quote, retweet in tweets: 671 | # Use a (markdown) filename that can be imported into Jekyll: YYYY-MM-DD-your-title-here.md 672 | dt = datetime.datetime.fromtimestamp(timestamp) 673 | grouped_tweets[(dt.year, dt.month)].append((ids, md, in_reply_to, quote, retweet)) 674 | 675 | final_md = [] 676 | for (year, month), content in grouped_tweets.items(): 677 | # Write into *.md files 678 | for id, md, in_reply_to, quote, retweet in content: 679 | final_md.append((id, md, in_reply_to, quote, retweet)) 680 | #md_path = paths.create_path_for_file_output_tweets(year, month, format="md") 681 | md_path = "tweets.md" 682 | 683 | write_json(md_path, final_md, lang) 684 | 685 | 686 | # Write into *.html files 687 | # no need, as we gonna output markdowns only 688 | 689 | print(f'Wrote {len(tweets)} tweets to *.md and *.html, ' 690 | f'with images and video embedded from {paths.dir_output_media}') 691 | 692 | return media_sources 693 | 694 | 695 | def collect_user_ids_from_followings(paths) -> list: 696 | """ 697 | Collect all user ids that appear in the followings archive data. 698 | (For use in bulk online lookup from Twitter.) 699 | """ 700 | # read JSON file from archive 701 | following_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'following.js')) 702 | # collect all user ids in a list 703 | following_ids = [] 704 | for follow in following_json: 705 | if 'following' in follow and 'accountId' in follow['following']: 706 | following_ids.append(follow['following']['accountId']) 707 | return following_ids 708 | 709 | 710 | def parse_followings(users, user_id_url_template, paths: PathConfig): 711 | """Parse paths.dir_input_data/following.js, write to paths.file_output_following. 712 | """ 713 | following = [] 714 | following_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'following.js')) 715 | following_ids = [] 716 | for follow in following_json: 717 | if 'following' in follow and 'accountId' in follow['following']: 718 | following_ids.append(follow['following']['accountId']) 719 | for following_id in following_ids: 720 | handle = users[following_id].handle if following_id in users else '~unknown~handle~' 721 | following.append(handle + ' ' + user_id_url_template.format(following_id)) 722 | following.sort() 723 | following_output_path = paths.create_path_for_file_output_single(format="txt", kind="following") 724 | with open_and_mkdirs(following_output_path) as f: 725 | f.write('\n'.join(following)) 726 | print(f"Wrote {len(following)} accounts to {following_output_path}") 727 | 728 | 729 | def collect_user_ids_from_followers(paths) -> list: 730 | """ 731 | Collect all user ids that appear in the followers archive data. 732 | (For use in bulk online lookup from Twitter.) 
733 | """ 734 | # read JSON file from archive 735 | follower_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'follower.js')) 736 | # collect all user ids in a list 737 | follower_ids = [] 738 | for follower in follower_json: 739 | if 'follower' in follower and 'accountId' in follower['follower']: 740 | follower_ids.append(follower['follower']['accountId']) 741 | return follower_ids 742 | 743 | 744 | def parse_followers(users, user_id_url_template, paths: PathConfig): 745 | """Parse paths.dir_input_data/followers.js, write to paths.file_output_followers. 746 | """ 747 | followers = [] 748 | follower_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'follower.js')) 749 | follower_ids = [] 750 | for follower in follower_json: 751 | if 'follower' in follower and 'accountId' in follower['follower']: 752 | follower_ids.append(follower['follower']['accountId']) 753 | for follower_id in follower_ids: 754 | handle = users[follower_id].handle if follower_id in users else '~unknown~handle~' 755 | followers.append(handle + ' ' + user_id_url_template.format(follower_id)) 756 | followers.sort() 757 | followers_output_path = paths.create_path_for_file_output_single(format="txt", kind="followers") 758 | with open_and_mkdirs(followers_output_path) as f: 759 | f.write('\n'.join(followers)) 760 | print(f"Wrote {len(followers)} accounts to {followers_output_path}") 761 | 762 | 763 | def chunks(lst: list, n: int): 764 | """Yield successive n-sized chunks from lst.""" 765 | for i in range(0, len(lst), n): 766 | yield lst[i:i + n] 767 | 768 | 769 | def collect_user_ids_from_direct_messages(paths) -> list: 770 | """ 771 | Collect all user ids that appear in the direct messages archive data. 772 | (For use in bulk online lookup from Twitter.) 773 | """ 774 | # read JSON file from archive 775 | dms_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'direct-messages.js')) 776 | # collect all user ids in a set 777 | dms_user_ids = set() 778 | for conversation in dms_json: 779 | if 'dmConversation' in conversation and 'conversationId' in conversation['dmConversation']: 780 | dm_conversation = conversation['dmConversation'] 781 | conversation_id = dm_conversation['conversationId'] 782 | user1_id, user2_id = conversation_id.split('-') 783 | dms_user_ids.add(user1_id) 784 | dms_user_ids.add(user2_id) 785 | return list(dms_user_ids) 786 | 787 | 788 | def parse_direct_messages(username, users, user_id_url_template, paths: PathConfig): 789 | """Parse paths.dir_input_data/direct-messages.js, write to one markdown file per conversation. 
790 | """ 791 | # read JSON file 792 | dms_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'direct-messages.js')) 793 | 794 | # Parse the DMs and store the messages in a dict 795 | conversations_messages = defaultdict(list) 796 | for conversation in dms_json: 797 | if 'dmConversation' in conversation and 'conversationId' in conversation['dmConversation']: 798 | dm_conversation = conversation['dmConversation'] 799 | conversation_id = dm_conversation['conversationId'] 800 | user1_id, user2_id = conversation_id.split('-') 801 | messages = [] 802 | if 'messages' in dm_conversation: 803 | for message in dm_conversation['messages']: 804 | if 'messageCreate' in message: 805 | message_create = message['messageCreate'] 806 | if all(tag in message_create for tag in ['senderId', 'recipientId', 'text', 'createdAt']): 807 | from_id = message_create['senderId'] 808 | to_id = message_create['recipientId'] 809 | body = message_create['text'] 810 | # replace t.co URLs with their original versions 811 | if 'urls' in message_create and len(message_create['urls']) > 0: 812 | for url in message_create['urls']: 813 | if 'url' in url and 'expanded' in url: 814 | expanded_url = url['expanded'] 815 | body = body.replace(url['url'], expanded_url) 816 | # escape message body for markdown rendering: 817 | body_markdown = escape_markdown(body) 818 | # replace image URLs with image links to local files 819 | if 'mediaUrls' in message_create \ 820 | and len(message_create['mediaUrls']) == 1 \ 821 | and 'urls' in message_create: 822 | original_expanded_url = message_create['urls'][0]['expanded'] 823 | message_id = message_create['id'] 824 | media_hash_and_type = message_create['mediaUrls'][0].split('/')[-1] 825 | media_id = message_create['mediaUrls'][0].split('/')[-2] 826 | archive_media_filename = f'{message_id}-{media_hash_and_type}' 827 | new_url = os.path.join(paths.dir_output_media, archive_media_filename) 828 | archive_media_path = \ 829 | os.path.join(paths.dir_input_data, 'direct_messages_media', archive_media_filename) 830 | if os.path.isfile(archive_media_path): 831 | # found a matching image, use this one 832 | if not os.path.isfile(new_url): 833 | shutil.copy(archive_media_path, new_url) 834 | image_markdown = f'\n![]({new_url})\n' 835 | body_markdown = body_markdown.replace( 836 | escape_markdown(original_expanded_url), image_markdown 837 | ) 838 | 839 | # Save the online location of the best-quality version of this file, 840 | # for later upgrading if wanted 841 | best_quality_url = \ 842 | f'https://ton.twitter.com/i//ton/data/dm/' \ 843 | f'{message_id}/{media_id}/{media_hash_and_type}' 844 | # there is no ':orig' here, the url without any suffix has the original size 845 | 846 | # TODO: a cookie (and a 'Referer: https://twitter.com' header) 847 | # is needed to retrieve it, so the url might be useless anyway... 848 | 849 | # WARNING: Do not uncomment the statement below until the cookie problem is solved! 
850 | # media_sources.append( 851 | # ( 852 | # os.path.join(output_media_folder_name, archive_media_filename), 853 | # best_quality_url 854 | # ) 855 | # ) 856 | 857 | else: 858 | archive_media_paths = glob.glob( 859 | os.path.join(paths.dir_input_data, 'direct_messages_media', message_id + '*')) 860 | if len(archive_media_paths) > 0: 861 | for archive_media_path in archive_media_paths: 862 | archive_media_filename = os.path.split(archive_media_path)[-1] 863 | media_url = os.path.join(paths.dir_output_media, archive_media_filename) 864 | if not os.path.isfile(media_url): 865 | shutil.copy(archive_media_path, media_url) 866 | video_markdown = f'\n\n' 868 | body_markdown = body_markdown.replace( 869 | escape_markdown(original_expanded_url), video_markdown 870 | ) 871 | 872 | # TODO: maybe also save the online location of the best-quality version for videos? 873 | # (see above) 874 | 875 | else: 876 | print(f'Warning: missing local file: {archive_media_path}. ' 877 | f'Using original link instead: {original_expanded_url})') 878 | 879 | created_at = message_create['createdAt'] # example: 2022-01-27T15:58:52.744Z 880 | timestamp = \ 881 | int(round(datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp())) 882 | 883 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \ 884 | else user_id_url_template.format(from_id) 885 | to_handle = escape_markdown(users[to_id].handle) if to_id in users \ 886 | else user_id_url_template.format(to_id) 887 | 888 | # make the body a quote 889 | body_markdown = '> ' + '\n> '.join(body_markdown.splitlines()) 890 | message_markdown = f'{from_handle} -> {to_handle}: ({created_at}) \n\n' \ 891 | f'{body_markdown}' 892 | messages.append((timestamp, message_markdown)) 893 | 894 | # find identifier for the conversation 895 | other_user_id = user2_id if (user1_id in users and users[user1_id].handle == username) else user1_id 896 | 897 | # collect messages per identifying user in conversations_messages dict 898 | conversations_messages[other_user_id].extend(messages) 899 | 900 | # output as one file per conversation (or part of long conversation) 901 | num_written_messages = 0 902 | num_written_files = 0 903 | for other_user_id, messages in conversations_messages.items(): 904 | # sort messages by timestamp 905 | messages.sort(key=lambda tup: tup[0]) 906 | 907 | other_user_name = escape_markdown(users[other_user_id].handle) if other_user_id in users \ 908 | else user_id_url_template.format(other_user_id) 909 | 910 | other_user_short_name: str = users[other_user_id].handle if other_user_id in users else other_user_id 911 | 912 | escaped_username = escape_markdown(username) 913 | 914 | # if there are more than 1000 messages, the conversation was split up in the twitter archive. 
915 | # following this standard, also split up longer conversations in the output files: 916 | 917 | if len(messages) > 1000: 918 | for chunk_index, chunk in enumerate(chunks(messages, 1000)): 919 | markdown = '' 920 | markdown += f'### Conversation between {escaped_username} and {other_user_name}, ' \ 921 | f'part {chunk_index+1}: ###\n\n----\n\n' 922 | markdown += '\n\n----\n\n'.join(md for _, md in chunk) 923 | conversation_output_path = paths.create_path_for_file_output_dms(name=other_user_short_name, index=(chunk_index + 1), format="md") 924 | 925 | # write part to a markdown file 926 | with open_and_mkdirs(conversation_output_path) as f: 927 | f.write(markdown) 928 | print(f'Wrote {len(chunk)} messages to {conversation_output_path}') 929 | num_written_files += 1 930 | 931 | else: 932 | markdown = '' 933 | markdown += f'### Conversation between {escaped_username} and {other_user_name}: ###\n\n----\n\n' 934 | markdown += '\n\n----\n\n'.join(md for _, md in messages) 935 | conversation_output_path = paths.create_path_for_file_output_dms(name=other_user_short_name, format="md") 936 | 937 | with open_and_mkdirs(conversation_output_path) as f: 938 | f.write(markdown) 939 | print(f'Wrote {len(messages)} messages to {conversation_output_path}') 940 | num_written_files += 1 941 | 942 | num_written_messages += len(messages) 943 | 944 | print(f"\nWrote {len(conversations_messages)} direct message conversations " 945 | f"({num_written_messages} total messages) to {num_written_files} markdown files\n") 946 | 947 | 948 | def make_conversation_name_safe_for_filename(conversation_name: str) -> str: 949 | """ 950 | Remove/replace characters that could be unsafe in filenames 951 | """ 952 | forbidden_chars = \ 953 | ['"', "'", '*', '/', '\\', ':', '<', '>', '?', '|', '!', '@', ';', ',', '=', '.', '\n', '\r', '\t'] 954 | new_conversation_name = '' 955 | for char in conversation_name: 956 | if char in forbidden_chars: 957 | new_conversation_name = new_conversation_name + '_' 958 | elif char.isspace(): 959 | # replace spaces with underscores 960 | new_conversation_name = new_conversation_name + '_' 961 | elif char == 0x7F or (0x1F >= ord(char) >= 0x00): 962 | # 0x00 - 0x1F and 0x7F are also forbidden, just discard them 963 | continue 964 | else: 965 | new_conversation_name = new_conversation_name + char 966 | 967 | return new_conversation_name 968 | 969 | 970 | def find_group_dm_conversation_participant_ids(conversation: dict) -> set: 971 | """ 972 | Find IDs of all participating Users in a group direct message conversation 973 | """ 974 | group_user_ids = set() 975 | if 'dmConversation' in conversation and 'conversationId' in conversation['dmConversation']: 976 | dm_conversation = conversation['dmConversation'] 977 | if 'messages' in dm_conversation: 978 | for message in dm_conversation['messages']: 979 | if 'messageCreate' in message: 980 | group_user_ids.add(message['messageCreate']['senderId']) 981 | elif 'joinConversation' in message: 982 | group_user_ids.add(message['joinConversation']['initiatingUserId']) 983 | for participant_id in message['joinConversation']['participantsSnapshot']: 984 | group_user_ids.add(participant_id) 985 | elif "participantsJoin" in message: 986 | group_user_ids.add(message['participantsJoin']['initiatingUserId']) 987 | for participant_id in message['participantsJoin']['userIds']: 988 | group_user_ids.add(participant_id) 989 | return group_user_ids 990 | 991 | 992 | def collect_user_ids_from_group_direct_messages(paths) -> list: 993 | """ 994 | Collect all user ids that 
appear in the group direct messages archive data. 995 | (For use in bulk online lookup from Twitter.) 996 | """ 997 | # read JSON file from archive 998 | group_dms_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'direct-messages-group.js')) 999 | # collect all user ids in a set 1000 | group_dms_user_ids = set() 1001 | for conversation in group_dms_json: 1002 | participants = find_group_dm_conversation_participant_ids(conversation) 1003 | for participant_id in participants: 1004 | group_dms_user_ids.add(participant_id) 1005 | return list(group_dms_user_ids) 1006 | 1007 | 1008 | def parse_group_direct_messages(username, users, user_id_url_template, paths): 1009 | """Parse data_folder/direct-messages-group.js, write to one markdown file per conversation. 1010 | """ 1011 | # read JSON file from archive 1012 | group_dms_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'direct-messages-group.js')) 1013 | 1014 | # Parse the group DMs, store messages and metadata in a dict 1015 | group_conversations_messages = defaultdict(list) 1016 | group_conversations_metadata = defaultdict(dict) 1017 | for conversation in group_dms_json: 1018 | if 'dmConversation' in conversation and 'conversationId' in conversation['dmConversation']: 1019 | dm_conversation = conversation['dmConversation'] 1020 | conversation_id = dm_conversation['conversationId'] 1021 | participants = find_group_dm_conversation_participant_ids(conversation) 1022 | participant_names = [] 1023 | for participant_id in participants: 1024 | if participant_id in users: 1025 | participant_names.append(users[participant_id].handle) 1026 | else: 1027 | participant_names.append(user_id_url_template.format(participant_id)) 1028 | 1029 | # save names in metadata 1030 | group_conversations_metadata[conversation_id]['participants'] = participants 1031 | group_conversations_metadata[conversation_id]['participant_names'] = participant_names 1032 | group_conversations_metadata[conversation_id]['conversation_names'] = [(0, conversation_id)] 1033 | group_conversations_metadata[conversation_id]['participant_message_count'] = defaultdict(int) 1034 | for participant_id in participants: 1035 | # init every participant's message count with 0, so that users with no activity are not ignored 1036 | group_conversations_metadata[conversation_id]['participant_message_count'][participant_id] = 0 1037 | messages = [] 1038 | if 'messages' in dm_conversation: 1039 | for message in dm_conversation['messages']: 1040 | if 'messageCreate' in message: 1041 | message_create = message['messageCreate'] 1042 | if all(tag in message_create for tag in ['senderId', 'text', 'createdAt']): 1043 | from_id = message_create['senderId'] 1044 | # count how many messages this user has sent to the group 1045 | group_conversations_metadata[conversation_id]['participant_message_count'][from_id] += 1 1046 | body = message_create['text'] 1047 | # replace t.co URLs with their original versions 1048 | if 'urls' in message_create: 1049 | for url in message_create['urls']: 1050 | if 'url' in url and 'expanded' in url: 1051 | expanded_url = url['expanded'] 1052 | body = body.replace(url['url'], expanded_url) 1053 | # escape message body for markdown rendering: 1054 | body_markdown = escape_markdown(body) 1055 | # replace image URLs with image links to local files 1056 | if 'mediaUrls' in message_create \ 1057 | and len(message_create['mediaUrls']) == 1 \ 1058 | and 'urls' in message_create: 1059 | original_expanded_url = message_create['urls'][0]['expanded'] 1060 | 
message_id = message_create['id'] 1061 | media_hash_and_type = message_create['mediaUrls'][0].split('/')[-1] 1062 | media_id = message_create['mediaUrls'][0].split('/')[-2] 1063 | archive_media_filename = f'{message_id}-{media_hash_and_type}' 1064 | new_url = os.path.join(paths.dir_output_media, archive_media_filename) 1065 | archive_media_path = \ 1066 | os.path.join(paths.dir_input_data, 'direct_messages_group_media', 1067 | archive_media_filename) 1068 | if os.path.isfile(archive_media_path): 1069 | # found a matching image, use this one 1070 | if not os.path.isfile(new_url): 1071 | shutil.copy(archive_media_path, new_url) 1072 | image_markdown = f'\n![]({new_url})\n' 1073 | body_markdown = body_markdown.replace( 1074 | escape_markdown(original_expanded_url), image_markdown 1075 | ) 1076 | 1077 | # Save the online location of the best-quality version of this file, 1078 | # for later upgrading if wanted 1079 | best_quality_url = \ 1080 | f'https://ton.twitter.com/i//ton/data/dm/' \ 1081 | f'{message_id}/{media_id}/{media_hash_and_type}' 1082 | # there is no ':orig' here, the url without any suffix has the original size 1083 | 1084 | # TODO: a cookie (and a 'Referer: https://twitter.com' header) 1085 | # is needed to retrieve it, so the url might be useless anyway... 1086 | 1087 | # WARNING: Do not uncomment the statement below until the cookie problem is solved! 1088 | # media_sources.append( 1089 | # ( 1090 | # os.path.join(output_media_folder_name, archive_media_filename), 1091 | # best_quality_url 1092 | # ) 1093 | # ) 1094 | 1095 | else: 1096 | archive_media_paths = glob.glob( 1097 | os.path.join(paths.dir_input_data, 'direct_messages_group_media', 1098 | message_id + '*')) 1099 | if len(archive_media_paths) > 0: 1100 | for archive_media_path in archive_media_paths: 1101 | archive_media_filename = os.path.split(archive_media_path)[-1] 1102 | media_url = os.path.join(paths.dir_output_media, 1103 | archive_media_filename) 1104 | if not os.path.isfile(media_url): 1105 | shutil.copy(archive_media_path, media_url) 1106 | video_markdown = f'\n<video controls><source src="{media_url}">Your browser ' \ 1107 | f'does not support the video tag.</video>\n' 1108 | body_markdown = body_markdown.replace( 1109 | escape_markdown(original_expanded_url), video_markdown 1110 | ) 1111 | 1112 | # TODO: maybe also save the online location of the best-quality version for videos? 1113 | # (see above) 1114 | 1115 | else: 1116 | print(f'Warning: missing local file: {archive_media_path}. 
' 1117 | f'Using original link instead: {original_expanded_url})') 1118 | created_at = message_create['createdAt'] # example: 2022-01-27T15:58:52.744Z 1119 | timestamp = int(round( 1120 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp() 1121 | )) 1122 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \ 1123 | else user_id_url_template.format(from_id) 1124 | # make the body a quote 1125 | body_markdown = '> ' + '\n> '.join(body_markdown.splitlines()) 1126 | message_markdown = f'{from_handle}: ({created_at})\n\n' \ 1127 | f'{body_markdown}' 1128 | messages.append((timestamp, message_markdown)) 1129 | elif "conversationNameUpdate" in message: 1130 | conversation_name_update = message['conversationNameUpdate'] 1131 | if all(tag in conversation_name_update for tag in ['initiatingUserId', 'name', 'createdAt']): 1132 | from_id = conversation_name_update['initiatingUserId'] 1133 | body_markdown = f"_changed group name to: {escape_markdown(conversation_name_update['name'])}_" 1134 | created_at = conversation_name_update['createdAt'] # example: 2022-01-27T15:58:52.744Z 1135 | timestamp = int(round( 1136 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp() 1137 | )) 1138 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \ 1139 | else user_id_url_template.format(from_id) 1140 | message_markdown = f'{from_handle}: ({created_at})\n\n{body_markdown}' 1141 | messages.append((timestamp, message_markdown)) 1142 | # save metadata about name change: 1143 | group_conversations_metadata[conversation_id]['conversation_names'].append( 1144 | (timestamp, conversation_name_update['name']) 1145 | ) 1146 | elif "joinConversation" in message: 1147 | join_conversation = message['joinConversation'] 1148 | if all(tag in join_conversation for tag in ['initiatingUserId', 'createdAt']): 1149 | from_id = join_conversation['initiatingUserId'] 1150 | created_at = join_conversation['createdAt'] # example: 2022-01-27T15:58:52.744Z 1151 | timestamp = int(round( 1152 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp() 1153 | )) 1154 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \ 1155 | else user_id_url_template.format(from_id) 1156 | escaped_username = escape_markdown(username) 1157 | body_markdown = f'_{from_handle} added {escaped_username} to the group_' 1158 | message_markdown = f'{from_handle}: ({created_at})\n\n{body_markdown}' 1159 | messages.append((timestamp, message_markdown)) 1160 | elif "participantsJoin" in message: 1161 | participants_join = message['participantsJoin'] 1162 | if all(tag in participants_join for tag in ['initiatingUserId', 'userIds', 'createdAt']): 1163 | from_id = participants_join['initiatingUserId'] 1164 | created_at = participants_join['createdAt'] # example: 2022-01-27T15:58:52.744Z 1165 | timestamp = int(round( 1166 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp() 1167 | )) 1168 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \ 1169 | else user_id_url_template.format(from_id) 1170 | joined_ids = participants_join['userIds'] 1171 | joined_handles = [escape_markdown(users[joined_id].handle) if joined_id in users 1172 | else user_id_url_template.format(joined_id) for joined_id in joined_ids] 1173 | name_list = ', '.join(joined_handles[:-1]) + \ 1174 | (f' and {joined_handles[-1]}' if len(joined_handles) > 1 else 1175 | joined_handles[0]) 1176 | body_markdown = f'_{from_handle} added {name_list} to the group_' 
1177 | message_markdown = f'{from_handle}: ({created_at})\n\n{body_markdown}' 1178 | messages.append((timestamp, message_markdown)) 1179 | elif "participantsLeave" in message: 1180 | participants_leave = message['participantsLeave'] 1181 | if all(tag in participants_leave for tag in ['userIds', 'createdAt']): 1182 | created_at = participants_leave['createdAt'] # example: 2022-01-27T15:58:52.744Z 1183 | timestamp = int(round( 1184 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp() 1185 | )) 1186 | left_ids = participants_leave['userIds'] 1187 | left_handles = [escape_markdown(users[left_id].handle) if left_id in users 1188 | else user_id_url_template.format(left_id) for left_id in left_ids] 1189 | name_list = ', '.join(left_handles[:-1]) + \ 1190 | (f' and {left_handles[-1]}' if len(left_handles) > 1 else 1191 | left_handles[0]) 1192 | body_markdown = f'_{name_list} left the group_' 1193 | message_markdown = f'{name_list}: ({created_at})\n\n{body_markdown}' 1194 | messages.append((timestamp, message_markdown)) 1195 | 1196 | # collect messages per conversation in group_conversations_messages dict 1197 | group_conversations_messages[conversation_id].extend(messages) 1198 | 1199 | # output as one file per conversation (or part of long conversation) 1200 | num_written_messages = 0 1201 | num_written_files = 0 1202 | for conversation_id, messages in group_conversations_messages.items(): 1203 | # sort messages by timestamp 1204 | messages.sort(key=lambda tup: tup[0]) 1205 | # create conversation name for use in filename: 1206 | # first, try to find an official name in the parsed conversation data 1207 | 1208 | # Not-so-fun fact: 1209 | # If the name was set before the archive's owner joined the group, the name is not included 1210 | # in the archive data and can't be found anywhere (except by looking it up from twitter, 1211 | # and that would probably need a cookie). So there are many groups that do actually have a name, 1212 | # but it can't be used here because we don't know it. 
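# Illustrative example (the ids and name here are made up): 'conversation_names' collects
# (timestamp, name) tuples, seeded with (0, conversation_id) and extended by every
# conversationNameUpdate, e.g. [(0, '123456789-987654321'), (1643299132, 'some group name')].
# The reverse sort by timestamp below therefore picks the most recently set group name, and the
# bare conversation id remains the fallback when no rename is recorded in the archive.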
1213 | 1214 | group_conversations_metadata[conversation_id]['conversation_names'].sort(key=lambda tup: tup[0], reverse=True) 1215 | official_name = group_conversations_metadata[conversation_id]['conversation_names'][0][1] 1216 | safe_group_name = make_conversation_name_safe_for_filename(official_name) 1217 | if len(safe_group_name) < 2: 1218 | # discard name if it's too short (because of collision risk) 1219 | group_name = conversation_id 1220 | else: 1221 | group_name = safe_group_name 1222 | 1223 | if group_name == conversation_id: 1224 | # try to make a nice list of participant handles for the conversation name 1225 | handles = [] 1226 | for participant_id, message_count in \ 1227 | group_conversations_metadata[conversation_id]['participant_message_count'].items(): 1228 | if participant_id in users: 1229 | participant_handle = users[participant_id].handle 1230 | if participant_handle != username: 1231 | handles.append((participant_handle, message_count)) 1232 | # sort alphabetically by handle first, for a more deterministic order 1233 | handles.sort(key=lambda tup: tup[0]) 1234 | # sort so that the most active users are at the start of the list 1235 | handles.sort(key=lambda tup: tup[1], reverse=True) 1236 | if len(handles) == 1: 1237 | group_name = \ 1238 | f'{handles[0][0]}_and_{len(group_conversations_metadata[conversation_id]["participants"]) - 1}_more' 1239 | elif len(handles) == 2 and len(group_conversations_metadata[conversation_id]["participants"]) == 3: 1240 | group_name = f'{handles[0][0]}_and_{handles[1][0]}_and_{username}' 1241 | elif len(handles) >= 2: 1242 | group_name = \ 1243 | f'{handles[0][0]}_and_{handles[1][0]}_and' \ 1244 | f'_{len(group_conversations_metadata[conversation_id]["participants"]) - 2}_more' 1245 | else: 1246 | # just use the conversation id 1247 | group_name = conversation_id 1248 | 1249 | # create a list of names of the form '@name1, @name2 and @name3' 1250 | # to use as a headline in the output file 1251 | escaped_participant_names = [ 1252 | escape_markdown(participant_name) 1253 | for participant_name in group_conversations_metadata[conversation_id]['participant_names'] 1254 | ] 1255 | name_list = ', '.join(escaped_participant_names[:-1]) + \ 1256 | (f' and {escaped_participant_names[-1]}' 1257 | if len(escaped_participant_names) > 1 1258 | else escaped_participant_names[0]) 1259 | 1260 | if len(messages) > 1000: 1261 | for chunk_index, chunk in enumerate(chunks(messages, 1000)): 1262 | markdown = '' 1263 | markdown += f'## {official_name} ##\n\n' 1264 | markdown += f'### Group conversation between {name_list}, part {chunk_index + 1}: ###\n\n----\n\n' 1265 | markdown += '\n\n----\n\n'.join(md for _, md in chunk) 1266 | conversation_output_filename = paths.create_path_for_file_output_dms( 1267 | name=group_name, format="md", kind="DMs-Group", index=chunk_index + 1 1268 | ) 1269 | 1270 | # write part to a markdown file 1271 | with open_and_mkdirs(conversation_output_filename) as f: 1272 | f.write(markdown) 1273 | print(f'Wrote {len(chunk)} messages to {conversation_output_filename}') 1274 | num_written_files += 1 1275 | else: 1276 | markdown = '' 1277 | markdown += f'## {official_name} ##\n\n' 1278 | markdown += f'### Group conversation between {name_list}: ###\n\n----\n\n' 1279 | markdown += '\n\n----\n\n'.join(md for _, md in messages) 1280 | conversation_output_filename = \ 1281 | paths.create_path_for_file_output_dms(name=group_name, format="md", kind="DMs-Group") 1282 | 1283 | with open_and_mkdirs(conversation_output_filename) as f: 1284 | 
f.write(markdown) 1285 | print(f'Wrote {len(messages)} messages to {conversation_output_filename}') 1286 | num_written_files += 1 1287 | 1288 | num_written_messages += len(messages) 1289 | 1290 | print(f"\nWrote {len(group_conversations_messages)} direct message group conversations " 1291 | f"({num_written_messages} total messages) to {num_written_files} markdown files") 1292 | 1293 | 1294 | def migrate_old_output(paths: PathConfig): 1295 | """If present, moves media and cache files from the archive root to the new locations in 1296 | `paths.dir_output_media` and `paths.dir_output_cache`. Then deletes old output files 1297 | (md, html, txt) from the archive root, if the user consents.""" 1298 | 1299 | # Create new folders, so we can potentially use them to move files there 1300 | os.makedirs(paths.dir_output_media, exist_ok=True) 1301 | os.makedirs(paths.dir_output_cache, exist_ok=True) 1302 | 1303 | # Move files that we can re-use: 1304 | if os.path.exists(os.path.join(paths.dir_archive, "media")): 1305 | files_to_move = glob.glob(os.path.join(paths.dir_archive, "media", "*")) 1306 | if len(files_to_move) > 0: 1307 | print(f"Moving {len(files_to_move)} files from 'media' to '{paths.dir_output_media}'") 1308 | for file_path_to_move in files_to_move: 1309 | file_name_to_move = os.path.split(file_path_to_move)[1] 1310 | os.rename(file_path_to_move, os.path.join(paths.dir_output_media, file_name_to_move)) 1311 | os.rmdir(os.path.join(paths.dir_archive, "media")) 1312 | 1313 | known_tweets_old_path = os.path.join(paths.dir_archive, "known_tweets.json") 1314 | known_tweets_new_path = os.path.join(paths.dir_output_cache, "known_tweets.json") 1315 | if os.path.exists(known_tweets_old_path): 1316 | os.rename(known_tweets_old_path, known_tweets_new_path) 1317 | 1318 | # Delete files that would be overwritten anyway (if user consents): 1319 | output_globs = [ 1320 | "TweetArchive.html", 1321 | "*Tweet-Archive*.html", 1322 | "*Tweet-Archive*.md", 1323 | "DMs-Archive-*.html", 1324 | "DMs-Archive-*.md", 1325 | "DMs-Group-Archive-*.html", 1326 | "DMs-Group-Archive-*.md", 1327 | "followers.txt", 1328 | "following.txt", 1329 | ] 1330 | files_to_delete = [] 1331 | 1332 | for output_glob in output_globs: 1333 | files_to_delete += glob.glob(os.path.join(paths.dir_archive, output_glob)) 1334 | 1335 | # TODO maybe remove those files only after the new ones have been generated? This way, the user would never 1336 | # end up with less output than before. On the other hand, they might end up with old *and* new versions 1337 | # of the output, if the script crashes before it reaches the code to delete the old version. 1338 | if len(files_to_delete) > 0: 1339 | print(f"\nThere are {len(files_to_delete)} files in the root of the archive,") 1340 | print("which were probably generated from an older version of this script.") 1341 | print("Since then, the directory layout of twitter-archive-parser has changed") 1342 | print("and these files are generated into the sub-directory 'parser-output' or") 1343 | print("various sub-sub-directories therein. These are the affected files:") 1344 | 1345 | for file_to_delete in files_to_delete: 1346 | print(file_to_delete) 1347 | 1348 | user_input = input('\nOK delete these files? (If the the directory layout would not have changed, they would be overwritten anyway) [y/N]') 1349 | if user_input.lower() in ('y', 'yes'): 1350 | for file_to_delete in files_to_delete: 1351 | os.remove(file_to_delete) 1352 | print(f"Files have been deleted. 
New versions of these files will be generated into 'parser-output' soon.") 1353 | 1354 | 1355 | def is_archive(path): 1356 | """Return true if there is a Twitter archive at the given path""" 1357 | return os.path.isfile(os.path.join(path, 'data', 'account.js')) 1358 | 1359 | 1360 | def find_archive(): 1361 | """ 1362 | Search for the archive 1363 | 1. First try the working directory. 1364 | 2. Then try the script directory. 1365 | 3. Finally prompt the user. 1366 | """ 1367 | if is_archive('.'): 1368 | return '.' 1369 | script_dir = os.path.dirname(__file__) 1370 | if script_dir != os.getcwd(): 1371 | if is_archive(script_dir): 1372 | return script_dir 1373 | print('Archive not found in working directory or script directory.\n' 1374 | 'Please enter the path of your Twitter archive, or just press Enter to exit.\n' 1375 | 'On most operating systems, you can also try to drag and drop your archive folder ' 1376 | 'into the terminal window, and it will paste its path automatically.\n') 1377 | # Give the user as many attempts as they need. 1378 | while True: 1379 | input_path = input('Archive path: ') 1380 | if not input_path: 1381 | exit() 1382 | if is_archive(input_path): 1383 | return input_path 1384 | print(f'Archive not found at {input_path}') 1385 | 1386 | 1387 | def main(lang): 1388 | archive_path = find_archive() 1389 | paths = PathConfig(dir_archive=archive_path) 1390 | 1391 | # Extract the archive owner's username from data/account.js 1392 | username = extract_username(paths) 1393 | 1394 | user_id_url_template = 'https://twitter.com/i/user/{}' 1395 | 1396 | html_template = """\ 1397 | 1398 | 1399 | 1400 | 1401 | 1402 | 1404 | Your Twitter archive! 1405 | 1406 | 1407 |

<h1>Your twitter archive</h1>
1408 | <main>
1409 | {}
1410 | </main>
1411 | 1412 | """ 1413 | 1414 | users = {} 1415 | 1416 | migrate_old_output(paths) 1417 | 1418 | # Make a folder to copy the images and videos into. 1419 | os.makedirs(paths.dir_output_media, exist_ok=True) 1420 | if not os.path.isfile(paths.file_tweet_icon): 1421 | shutil.copy('assets/images/favicon.ico', paths.file_tweet_icon) 1422 | 1423 | media_sources = parse_tweets(username, users, html_template, paths, lang) 1424 | 1425 | following_ids = collect_user_ids_from_followings(paths) 1426 | print(f'found {len(following_ids)} user IDs in followings.') 1427 | follower_ids = collect_user_ids_from_followers(paths) 1428 | print(f'found {len(follower_ids)} user IDs in followers.') 1429 | dms_user_ids = collect_user_ids_from_direct_messages(paths) 1430 | print(f'found {len(dms_user_ids)} user IDs in direct messages.') 1431 | group_dms_user_ids = collect_user_ids_from_group_direct_messages(paths) 1432 | print(f'found {len(group_dms_user_ids)} user IDs in group direct messages.') 1433 | 1434 | # bulk lookup for user handles from followers, followings, direct messages and group direct messages 1435 | collected_user_ids_without_followers = list( 1436 | set(following_ids).union(set(dms_user_ids)).union(set(group_dms_user_ids)) 1437 | ) 1438 | collected_user_ids_only_in_followers: set = set(follower_ids).difference(set(collected_user_ids_without_followers)) 1439 | collected_user_ids: list = list(set(collected_user_ids_without_followers) 1440 | .union(collected_user_ids_only_in_followers)) 1441 | 1442 | print(f'\nfound {len(collected_user_ids)} user IDs overall.') 1443 | 1444 | # give the user a choice if followers should be included in the lookup 1445 | # (but only in case they make up a large amount): 1446 | unknown_collected_user_ids: set = set(collected_user_ids).difference(users.keys()) 1447 | unknown_follower_user_ids: set = unknown_collected_user_ids.intersection(collected_user_ids_only_in_followers) 1448 | if len(unknown_follower_user_ids) > 5000: 1449 | # Account metadata observed at ~2.1KB on average. 1450 | estimated_follower_lookup_size = int(2.1 * len(unknown_follower_user_ids)) 1451 | # we can look up at least 3000 users per minute. 1452 | estimated_max_follower_lookup_time_in_minutes = len(unknown_follower_user_ids) / 3000 1453 | print( 1454 | f'For some user IDs, the @handle is not included in the archive data. ' 1455 | f'Unknown user handles can be looked up online.' 1456 | f'{len(unknown_follower_user_ids)} of {len(unknown_collected_user_ids)} total ' 1457 | f'user IDs with unknown handles are from your followers. Online lookup would be ' 1458 | f'about {estimated_follower_lookup_size:,} KB smaller and up to ' 1459 | f'{estimated_max_follower_lookup_time_in_minutes:.1f} minutes faster without them.\n' 1460 | ) 1461 | 1462 | if not get_consent(f'Do you want to include handles of your followers ' 1463 | f'in the online lookup of user handles anyway?', default_to_yes=True): 1464 | collected_user_ids = collected_user_ids_without_followers 1465 | 1466 | lookup_users(collected_user_ids, users) 1467 | 1468 | parse_followings(users, user_id_url_template, paths) 1469 | parse_followers(users, user_id_url_template, paths) 1470 | #parse_direct_messages(username, users, user_id_url_template, paths) 1471 | #parse_group_direct_messages(username, users, user_id_url_template, paths) 1472 | 1473 | # Download larger images, if the user agrees 1474 | print(f"\nThe archive doesn't contain the original-size images. 
We can attempt to download them from twimg.com.") 1475 | print(f'Please be aware that this script may download a lot of data, which will cost you money if you are') 1476 | print(f'paying for bandwidth. Please be aware that the servers might block these requests if they are too') 1477 | print(f'frequent. This script may not work if your account is protected. You may want to set it to public') 1478 | print(f'before starting the download.\n') 1479 | 1480 | if get_consent('OK to start downloading?'): 1481 | download_larger_media(media_sources, paths) 1482 | print('In case you set your account to public before initiating the download, ' 1483 | 'do not forget to protect it again.') 1484 | 1485 | 1486 | if __name__ == "__main__": 1487 | main(lang=sys.argv[1] if len(sys.argv) > 1 else 'zh_hans') 1488 | --------------------------------------------------------------------------------
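A minimal usage sketch for twitter-parser.py, assuming it is invoked directly as the __main__ guard above suggests; the optional first command-line argument selects the output language and defaults to 'zh_hans':

    # Run from the folder containing the unpacked Twitter archive (data/account.js must exist),
    # or let the script prompt for the archive path interactively.
    python twitter-parser.py            # default output language: zh_hans
    python twitter-parser.py zh_hans    # same, with the language passed explicitly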