\"\"\"\n",
206 | "\n",
207 | "with gr.Blocks() as demo:\n",
208 | " gr.HTML(title)\n",
209 | " state = gr.State()\n",
210 | " with gr.Row():\n",
211 | " with gr.Column(scale=2):\n",
212 | " temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.95, label=\"Temperature\",\n",
213 | " info=\"温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。\")\n",
214 | " top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.99, label=\"Top-p\",\n",
215 | " info=\"top-p参数,只输出前p>top-p的文字,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。\")\n",
216 | " #code = gr.Textbox(label=\"temp_output\", info=\"解码器输出\")\n",
217 | " #top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label=\"Top k\",\n",
218 | " # info=\"top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。\")\n",
219 | " \n",
220 | " with gr.Column(scale=3):\n",
221 | " chatbot = gr.Chatbot(label=\"聊天框\", info=\"\")\n",
222 | " msg = gr.Textbox(label=\"输入框\", placeholder=\"最近过得怎么样?\",\n",
223 | " info=\"输入你的内容,按[Enter]发送。也可以什么都不填写生成随机数据。\")\n",
224 | " clear = gr.Button(\"清除聊天\")\n",
225 | "\n",
226 | " msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])\n",
227 | " clear.click(lambda: None, None, chatbot, queue=False)\n",
228 | " gr.HTML(footer)\n",
229 | "\n",
230 | "demo.queue()\n",
231 | "demo.launch(debug=False)\n"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": null,
237 | "metadata": {
238 | "id": "GsS8ZYH5ficF"
239 | },
240 | "outputs": [],
241 | "source": [
242 | "import gradio as gr\n",
243 | "gr.Interface(\n",
244 | " fn=evaluate,\n",
245 | " inputs=[\n",
246 | " gr.components.Textbox(\n",
247 | " lines=2, label=\"问题\", placeholder=\"最近过得怎么样?\",\n",
248 | " info=\"可以在这里输入你的问题。也可以什么都不填写生成随机数据。\"\n",
249 | " ),\n",
250 | " #gr.components.Textbox(lines=2, label=\"Input\", placeholder=\"none\"),\n",
251 | " gr.components.Slider(minimum=0, maximum=1.1, value=1.0, label=\"Temperature\",\n",
252 | " info=\"温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。\"),\n",
253 | " gr.components.Slider(minimum=0.5, maximum=1.0, value=0.99, label=\"Top p\",\n",
254 | " info=\"top-p参数,只输出前p>top-p的文字,建议不要修改。\"),\n",
255 | " gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label=\"Top k\",\n",
256 | " info=\"top-k参数,下一个输出的文字会从top-k个文字中进行选择,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。\"),\n",
257 | " ],\n",
258 | " outputs=[\n",
259 | " gr.inputs.Textbox(\n",
260 | " lines=5,\n",
261 | " label=\"Output\",\n",
262 | " )\n",
263 | " ],\n",
264 | " title=\"李萌萌(Alter Ego)\",\n",
265 | " description=\"这是一个通过ChatGLM模型训练的李萌萌的数字分身,你可以在问题栏目填入内容,或者什么都不填,来观察李萌萌到底会说些什么。\",\n",
266 | ").launch()"
267 | ]
268 | }
269 | ],
270 | "metadata": {
271 | "accelerator": "GPU",
272 | "colab": {
273 | "machine_shape": "hm",
274 | "provenance": []
275 | },
276 | "gpuClass": "standard",
277 | "kernelspec": {
278 | "display_name": "Python 3",
279 | "name": "python3"
280 | },
281 | "language_info": {
282 | "name": "python"
283 | }
284 | },
285 | "nbformat": 4,
286 | "nbformat_minor": 0
287 | }
288 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Miaomiao Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fujisaki
2 |
3 | Dedicated to creating your own Fujisaki Chihiro. The project is in its early stages.
4 |
5 | An ongoing (fast prototyping) project to create your own doppelgänger based on your Twitter archive and LoRA models.
6 |
7 | ## What is this project?
8 |
9 | Ten years ago, when I was playing a certain game called Danganronpa, there was an adorable… well, she later died, but she went on, as a digital life form, to help the future protagonists through their crisis. If I die, will anyone still remember me?
10 |
11 | The project is currently based on [ChatGLM+LoRa](https://github.com/mymusise/ChatGLM-Tuning/) and focuses on generating Chinese content for now. Hopefully it can also help those doing prompt engineering and content generation with Chinese corpora.
12 |
13 | Inspired by Fujisaki Chihiro (i.e., Alter Ego). I thought it would be a fun project, as I really cannot predict my death in the future, and I would like to see how my doppelgänger would react to the world. Thanks to the super-strong LLM and LoRA for making it happen by allowing fine-tuning on small corpora.
14 |
15 | ## How does it work?
16 |
17 | Unzip your Twitter archive into the project root directory and it will be processed automatically. The scripts use each tweet's text, timestamp, quotes, and replies to build an instruction-style JSON dataset, which is then used to fine-tune a Chinese language model so it can write new tweets or hold context-aware conversations.
18 |
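Each record in the generated dataset is a simple instruction triple; for example, this entry taken verbatim from `tweets_sample.md`:

```json
{
    "instruction": "最近过得怎么样?",
    "input": "",
    "output": "想写个tensorboard的iOS客户端之类的东西…"
}
```
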
19 | Text generation is currently the project's strong suit; Q&A is average, though it can recognize some questions within the topics the tweets cover. Dialogue is also somewhat limited because tweets are mostly declarative, but since ChatGLM itself has decent conversational ability, we can still trust it to produce reasonably natural replies.
20 |
21 | Generating high-quality training data is the ultimate goal of this project. The current implementation lives in `prompt.py`, and OpenAI is already used to generate more valuable augmented context.
22 |
23 | ## Demo:
24 |
25 | HuggingFace Hub (a slower online version, runs directly): [](https://huggingface.co/spaces/ljsabc/Fujisaki)
26 |
27 | Colab demo (requires a GPU): [](https://colab.research.google.com/github/ljsabc/Fujisaki/blob/main/Fujisaki_GPU_Stream.ipynb)
28 |
29 | Change `peft_path` in the Colab notebook to your own trained model (a local folder or a HuggingFace repo) to reproduce your own fine-tuned version.
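
A minimal sketch of what this amounts to, mirroring `infer.py` (`peft_path` here stands for your adapter folder or repo id):

```python
import torch
from transformers import AutoModel, AutoTokenizer
from peft import PeftModel

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
peft_path = "output/"  # or a HuggingFace repo id
model = PeftModel.from_pretrained(model, peft_path, torch_dtype=torch.float16)
```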
30 |
31 |
32 | 
33 |
34 | Compared with retrieval- or Q&A-oriented projects, this one leans much more on sampling: similar contexts can produce very different replies. Hopefully this way of generating offers more interactivity.
35 |
36 | ## To-do List
37 |
38 | - [x] Modify the twitter-parser to convert your Twitter archive into an instruction dataset
39 | - [x] Categorize in-reply-to and quoted tweets for better conditional generation
40 | - [x] LoRA finetuning with multiple GPUs
41 | - [x] Basic prompt engineering for original posts
42 | - [x] Hyperparameter tuning (incl. LoRA rank, batch size, learning rate, etc.)
43 | - [x] Allow in-reply-to tweets to be downloaded.
44 | - [x] Advanced prompt engineering from OpenAI
45 | - [x] Colab notebook for easy deployment (I believe this code can surely run on a T4, as we expect much shorter token sequences)
46 | - [ ] Download quoted tweets.
47 | - [ ] Support other datasets (e.g., Reddit, Weibo; future plan, let's discuss in #2.)
48 |
49 | ## Environment and installation
50 |
51 | A `conda` environment is recommended. Install the dependencies:
52 |
53 | ```pip install -r requirements.txt```
54 |
55 | Sometimes you may also need to install cudatoolkit:
56 |
57 | ```conda install cudatoolkit=11.3```
58 |
59 | If you are using a 4090, install version 11.8.
60 |
61 | This helps when bitsandbytes cannot detect the CUDA runtime. Once the environment is ready, don't forget to rename `config_example.py` to `config.py` and adjust its parameters if necessary.
62 |
63 | ## Data requirements
64 |
65 | Unzip your tweet archive into the project root directory and it will be processed automatically. After unzipping you should see the file `Your archive.html` in the project root. Then run `twitter-parser.py` to parse your tweet archive and generate an instruction-style JSON dataset.
66 |
67 | Likewise, you can refer to `tweets_sample.md` to build your own dataset, or wait for project updates.
68 |
69 | ## Training
70 |
71 | The current model is based on [ChatGLM+LoRa](https://github.com/mymusise/ChatGLM-Tuning/), with processing quite similar to [Luotuo](https://github.com/LC1332/Chinese-alpaca-lora).
72 |
73 | ### Generating training data
74 |
75 | First, run
76 |
77 | python twitter-parser.py
78 |
79 | to process the tweet archive. After a short wait you will find a `tweets.md` file in the project root, containing every tweet in your archive along with its metadata. To protect your privacy, please do not publish this file.
80 |
81 | After generating the data, we call ChatGLM's `tokenizer` to build the tokenized dataset, which takes some time. The original upstream version over-caches the same generator so the data never refreshes; I rewrote it as a single-file version.
82 |
83 | python ./tokenize_dataset_rows.py --json_path ./tweets.md --save_path tweets.tokens --max_seq_length 240
84 |
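Under the hood each example is tokenized roughly like this (a condensed view of `preprocess` in `tokenize_dataset_rows.py`):

```python
# prompt and target are encoded separately, then joined with the EOS token;
# seq_len records where the target starts so the collator can mask the prompt.
prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
target_ids = tokenizer.encode(target, max_length=max_seq_length,
                              truncation=True, add_special_tokens=False)
feature = {"input_ids": prompt_ids + target_ids + [config.eos_token_id],
           "seq_len": len(prompt_ids)}
```
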
85 | (Optional) 240 tokens is used because most of my tweets, together with the instruction, stay under 240 tokens. If your tweets run longer, generate the jsonl first, run `length.py`, and raise `max_seq_length` according to the reported statistics.
86 |
87 | python3 ./cover_alpaca2jsonl.py --data_path tweets.md --save_path tweets.jsonl
88 | python length.py
89 |
90 | ### Optional: build a more interactive dataset from reply tweets (beta)
91 |
92 | Because the tweet archive does not store context, we can scrape the tweets being replied to in order to recover the original context, enabling much stronger dialogue. This takes some time, but only needs to run once.
93 | Unfortunately, scraping websites is always painful. You may need a proxy pool, plus the `chromedriver` binary placed in the project root (you should know your way around [selenium](https://selenium-python.readthedocs.io/)). You may also need a machine powerful enough to scrape replies across many threads. If you meet these conditions, edit `config.py` to change `PARSE_REPLIES = False` to `PARSE_REPLIES = True`, then run `twitter-parser.py`. This part is already implemented.
94 |
95 | This method cannot fetch tweets, yours or other people's, that have already been deleted. Fetching quoted tweets is still under consideration.
96 |
97 | ### Optional: use OpenAI to generate questions from your tweets
98 |
99 | Call the OpenAI API to generate a question, or a short lead-in, for each of your original tweets, which makes Q&A flow more smoothly. This part is already implemented; you need an OpenAI API key. It is somewhat pricey: budget $2-3 per 10,000 tweets. Fill in `config.py` to configure it, and see `prompt_openai.py` for the implementation.
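
At its core this is one chat-completion call per tweet (abridged from `prompt_openai.py`):

```python
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "Based on the Chinese user input, create a question in SIMPLIFIED CHINESE ..."},
        {"role": "user", "content": f"User input (in Chinese): {value}"},
    ],
    max_tokens=128,
    temperature=0.95,
)
question = response.choices[0]["message"]["content"].strip()
```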
100 |
101 | ### Fine-tuning
102 |
103 | Next, call `finetune.py` to train the model. Depending on how many GPUs you have, you can directly run
104 |
105 | WORLD_SIZE=4 CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
106 | --master_port=1234 \
107 | finetune.py \
108 | --dataset_path tweets.tokens \
109 | --lora_rank 8 \
110 | --per_device_train_batch_size 8 \
111 | --gradient_accumulation_steps 1 \
112 |     --num_train_epochs 2 \
113 | --save_steps 2000 \
114 | --save_total_limit 2 \
115 | --learning_rate 6e-4 \
116 | --remove_unused_columns false \
117 | --logging_steps 50 \
118 | --output_dir output \
119 | --ddp_find_unused_parameters false \
120 | --warmup_steps 50
121 |
122 | for multi-GPU training. If multi-GPU training errors out, simply running it again may fix it; it's a minor caching bug.
123 |
124 | Single-GPU training:
125 |
126 | python finetune.py \
127 | --dataset_path tweets.tokens \
128 | --lora_rank 8 \
129 | --per_device_train_batch_size 8 \
130 | --gradient_accumulation_steps 1 \
131 |     --num_train_epochs 2 \
132 | --save_steps 2000 \
133 | --save_total_limit 2 \
134 | --learning_rate 2e-4 \
135 | --remove_unused_columns false \
136 | --logging_steps 50 \
137 | --output_dir output \
138 | --warmup_steps 100
139 |
140 | The current parameters are very similar to [ChatGLM+LoRa](https://github.com/mymusise/ChatGLM-Tuning/), though you can scale the learning rate with the number of GPUs. The default learning rate is `2e-4` (per 8 samples; if the loss spikes, lower it a bit), so test and tune it against your batch size and GPU capability. Since upstream added gradient checkpointing, even a 3090 can run a batch size of 8, which is a great experience. The LoRA rank can be tuned to the model quality you want: the default of 8 is sufficient, but you can raise it to 12 or even higher; in my tests, pushing `lora_rank` to 16 lifts the results a clear step, at the cost of slightly longer training and inference time, though not by much.
141 |
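The adapter itself is configured in `finetune.py` roughly as follows (`r` is what `--lora_rank` sets):

```python
from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,             # --lora_rank; 16 gave a visible step up in my tests
    lora_alpha=32,
    lora_dropout=0.1,
)
```
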
142 | The trained model is saved under the `output` folder; you will find the corresponding model files in `output/`.
143 |
144 | ## Inference (and interaction)
145 |
146 | Run `infer.py` to chat. You can type any question (since the data is tweets, declarative inputs work better than questions), and even with no input at all it will generate a tweet that sounds a lot like me.
147 |
148 | ```python3 ./infer.py ```
149 |
150 | You can tune top-p, top-k, and temperature in this file to generate more diverse samples; adjust them according to the results from the gradio demo.
151 |
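These knobs live in the `GenerationConfig` inside `infer.py`:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(
    temperature=0.9,   # higher = more varied output, riskier grammar
    top_p=0.975,
    #top_k=150,        # uncomment to add top-k filtering
    num_beams=1,
    do_sample=True,
)
```
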
152 | ## Benchmark
153 |
154 | On a single A100, training a dataset of 75,000 tweets with a maximum length of 240 takes about 3 hours per epoch. It takes roughly 2-3 epochs to reach the best state.
155 |
156 | ## Credits
157 |
158 | This project is based on the following projects:
159 |
160 | 27182812/ChatGLM-chinese-insturct
161 | timhutton/twitter-archive-parser
162 | LC1332/Chinese-alpaca-lora (Donated❤️)
163 |
164 | Inspired by the following projects:
165 |
166 | tloen/alpaca-lora
167 | HuggingFace: KBlueLeaf/guanaco-7B-lora-embed
168 | (potentially) twint-fork
169 |
--------------------------------------------------------------------------------
/config_example.py:
--------------------------------------------------------------------------------
1 | # Default parameters; they rarely need changing
2 | # Limit the minimal length of your response
3 | RESPONSE_THRESH = 6
4 | # the augmentation factor can be reduced if you enable the parsing of replies
5 | AUGMENTATION_FACTOR_ORIGINAL = 2
6 | # better just 1, as we are using full context
7 | AUGMENTATION_FACTOR_REPLIES = 1
8 |
9 | # Reply delimiter, useful to distinguish the speakers in the context.
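# (prompt.py joins scraped reply contexts as REPLY_DELIMITER.join(context))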
10 | REPLY_DELIMITER = "||"
11 |
12 | # The features below are optional; enable them as needed
13 | # More interestingly, let's ask OpenAI to make a question for your tweets
14 | ENABLE_OPENAI = False
15 | OPENAI_KEY = ''
16 | OPENAI_MAX_SAMPLE = 2 # don't make it larger than the augmentation factor
17 | OPENAI_THREADS = 128 # Be careful when you try this! It can be expensive.
18 | OPENAI_CACHE_ONLY = False
19 |
20 | # Parsing replies is not recommended, as it will take a lot of time
21 | # But it can increase the overall quality significantly
22 | PARSE_REPLIES = False
23 | SCRAPE_TIMEOUT = 20
24 | SCRAPE_CACHE_ONLY = False
25 | PROXY = False
26 | PROXY_ADDR = ""
27 | PROXY_PORT = ""
28 | PROXY_USER = ""
29 | PROXY_PASSWD = ""
30 |
31 | # number of processes to use for selenium
32 | # based on experience, 2-3x your CPU core count works best. Use it wisely.
33 | PROCESSES = 8
--------------------------------------------------------------------------------
/cover_alpaca2jsonl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from tqdm import tqdm
4 |
5 |
6 | def format_example(example: dict) -> dict:
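    # Flatten one instruction record into a (context, target) text pair;
    # "Answer: " marks where the model's reply is expected to begin.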
7 | context = f"Instruction: {example['instruction']}\n"
8 | if example.get("input"):
9 | context += f"Input: {example['input']}\n"
10 | context += "Answer: "
11 | target = example["output"]
12 | return {"context": context, "target": target}
13 |
14 |
15 | def main():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--data_path", type=str, default="data/alpaca_data.json")
18 | parser.add_argument("--save_path", type=str, default="data/alpaca_data.jsonl")
19 |
20 | args = parser.parse_args()
21 | with open(args.data_path) as f:
22 | examples = json.load(f)
23 |
24 | with open(args.save_path, 'w') as f:
25 | for example in tqdm(examples, desc="formatting.."):
26 | f.write(json.dumps(format_example(example)) + '\n')
27 |
28 |
29 | if __name__ == "__main__":
30 | main()
31 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | from transformers import TrainingArguments
2 | from transformers import Trainer, HfArgumentParser
3 | from transformers import AutoTokenizer, AutoModel
4 | import torch
5 | import torch.nn as nn
6 | from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training
7 | from dataclasses import dataclass, field
8 | import datasets
9 | import os
10 |
11 |
12 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="658202d")
13 |
14 |
15 | @dataclass
16 | class FinetuneArguments:
17 | dataset_path: str = field(default="data/alpaca")
18 | model_path: str = field(default="output")
19 | lora_rank: int = field(default=8)
20 |
21 |
22 | class CastOutputToFloat(nn.Sequential):
23 | def forward(self, x):
24 | return super().forward(x).to(torch.float32)
25 |
26 |
27 | def data_collator(features: list) -> dict:
28 | len_ids = [len(feature["input_ids"]) for feature in features]
29 | longest = max(len_ids)
30 | input_ids = []
31 | labels_list = []
32 | for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
33 | ids = feature["input_ids"]
34 | seq_len = feature["seq_len"]
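        # Labels: -100 masks the prompt (everything before the target) and the
        # right-padding, so the loss is computed only on the target tokens.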
35 | labels = (
36 | [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l)
37 | )
38 | ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
39 | _ids = torch.LongTensor(ids)
40 | labels_list.append(torch.LongTensor(labels))
41 | input_ids.append(_ids)
42 | input_ids = torch.stack(input_ids)
43 | labels = torch.stack(labels_list)
44 | return {
45 | "input_ids": input_ids,
46 | "labels": labels,
47 | }
48 |
49 |
50 | class ModifiedTrainer(Trainer):
51 | def compute_loss(self, model, inputs, return_outputs=False):
52 | return model(
53 | input_ids=inputs["input_ids"],
54 | labels=inputs["labels"],
55 | ).loss
56 |
57 | def save_model(self, output_dir=None, _internal_call=False):
58 | from transformers.trainer import TRAINING_ARGS_NAME
59 |
60 | os.makedirs(output_dir, exist_ok=True)
61 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
62 | saved_params = {
63 | k: v.to("cpu") for k, v in self.model.named_parameters() if v.requires_grad
64 | }
65 | torch.save(saved_params, os.path.join(output_dir, "adapter_model.bin"))
66 |
67 |
68 | def main():
69 | finetune_args, training_args = HfArgumentParser(
70 | (FinetuneArguments, TrainingArguments)
71 | ).parse_args_into_dataclasses()
72 |
73 |
74 | device_map = "auto"
75 | world_size = int(os.environ.get("WORLD_SIZE", 1))
76 | ddp = world_size != 1
77 | if ddp:
78 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
79 |
80 | # init model
81 | model = AutoModel.from_pretrained(
82 | "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map=device_map,revision="658202d"
83 | )
84 |
85 |     # prepare_model_for_int8_training does exactly the following (kept for reference):
86 | model = prepare_model_for_int8_training(model)
87 | #model.gradient_checkpointing_enable()
88 | #model.enable_input_require_grads()
89 | model.is_parallelizable = True
90 | model.model_parallel = True
91 | #model.lm_head = CastOutputToFloat(model.lm_head)
92 | model.config.use_cache = (
93 | False # silence the warnings. Please re-enable for inference!
94 | )
95 |
96 | #print(model)
97 |
98 | # setup peft
99 | peft_config = LoraConfig(
100 | task_type=TaskType.CAUSAL_LM,
101 | inference_mode=False,
102 | r=finetune_args.lora_rank,
103 | lora_alpha=32,
104 | lora_dropout=0.1,
105 | )
106 | model = get_peft_model(model, peft_config)
107 |
108 | # load dataset
109 | dataset = datasets.load_from_disk(finetune_args.dataset_path)
110 | #print(f"\n{len(dataset)=}\n")
111 |
112 | # start train
113 | trainer = ModifiedTrainer(
114 | model=model,
115 | train_dataset=dataset,
116 | args=training_args,
117 | data_collator=data_collator,
118 | )
119 | trainer.train()
120 | # save model
121 | #save_tunable_parameters(
122 | # model, os.path.join(training_args.output_dir, "chatglm-lora.pt")
123 | #)
124 |
125 | # don't just save model, save configs also.
126 | model.save_pretrained(training_args.output_dir)
127 |
128 |
129 | if __name__ == "__main__":
130 | main()
131 |
--------------------------------------------------------------------------------
/finetune_adv.py:
--------------------------------------------------------------------------------
1 | from transformers import TrainingArguments
2 | from transformers import Trainer, HfArgumentParser
3 | from transformers import AutoTokenizer, AutoModel
4 | import torch
5 | import torch.nn as nn
6 | from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training
7 | from dataclasses import dataclass, field
8 | import datasets
9 | import os
10 |
11 |
12 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="658202d")
13 |
14 |
15 | @dataclass
16 | class FinetuneArguments:
17 | dataset_path: str = field(default="data/alpaca")
18 | model_path: str = field(default="output")
19 | lora_rank: int = field(default=8)
20 |
21 |
22 | class CastOutputToFloat(nn.Sequential):
23 | def forward(self, x):
24 | return super().forward(x).to(torch.float32)
25 |
26 |
27 | def data_collator(features: list) -> dict:
28 | len_ids = [len(feature["input_ids"]) for feature in features]
29 | longest = max(len_ids)
30 | input_ids = []
31 | labels_list = []
32 | for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
33 | ids = feature["input_ids"]
34 | seq_len = feature["seq_len"]
35 | labels = (
36 | [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l)
37 | )
38 | ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
39 | _ids = torch.LongTensor(ids)
40 | labels_list.append(torch.LongTensor(labels))
41 | input_ids.append(_ids)
42 | input_ids = torch.stack(input_ids)
43 | labels = torch.stack(labels_list)
44 | return {
45 | "input_ids": input_ids,
46 | "labels": labels,
47 | }
48 |
49 |
50 | class ModifiedTrainer(Trainer):
51 | def compute_loss(self, model, inputs, return_outputs=False):
52 | return model(
53 | input_ids=inputs["input_ids"],
54 | labels=inputs["labels"],
55 | ).loss
56 |
57 | def save_model(self, output_dir=None, _internal_call=False):
58 | from transformers.trainer import TRAINING_ARGS_NAME
59 |
60 | os.makedirs(output_dir, exist_ok=True)
61 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
62 | saved_params = {
63 | k: v.to("cpu") for k, v in self.model.named_parameters() if v.requires_grad
64 | }
65 | torch.save(saved_params, os.path.join(output_dir, "adapter_model.bin"))
66 |
67 |
68 | def main():
69 | finetune_args, training_args = HfArgumentParser(
70 | (FinetuneArguments, TrainingArguments)
71 | ).parse_args_into_dataclasses()
72 |
73 |
74 | device_map = "auto"
75 | world_size = int(os.environ.get("WORLD_SIZE", 1))
76 | ddp = world_size != 1
77 | if ddp:
78 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
79 |
80 | # init model
81 | model = AutoModel.from_pretrained(
82 | "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map=device_map,revision="658202d"
83 | )
84 |
85 |     # prepare_model_for_int8_training does exactly the following (kept for reference):
86 | model = prepare_model_for_int8_training(model)
87 | #model.gradient_checkpointing_enable()
88 | #model.enable_input_require_grads()
89 | model.is_parallelizable = True
90 | model.model_parallel = True
91 | #model.lm_head = CastOutputToFloat(model.lm_head)
92 | model.config.use_cache = (
93 | False # silence the warnings. Please re-enable for inference!
94 | )
95 |
96 | # setup peft
97 | peft_config = LoraConfig(
98 | task_type=TaskType.CAUSAL_LM,
99 | inference_mode=False,
100 | r=finetune_args.lora_rank,
101 | lora_alpha=32,
102 | lora_dropout=0.1,
103 | target_modules=['query_key_value', 'lm_head','word_embeddings']
104 | )
105 | model = get_peft_model(model, peft_config)
106 | print(model)
107 |
108 | # load dataset
109 | dataset = datasets.load_from_disk(finetune_args.dataset_path)
110 | #print(f"\n{len(dataset)=}\n")
111 |
112 | # start train
113 | trainer = ModifiedTrainer(
114 | model=model,
115 | train_dataset=dataset,
116 | args=training_args,
117 | data_collator=data_collator,
118 | )
119 | trainer.train()
120 | # save model
121 | #save_tunable_parameters(
122 | # model, os.path.join(training_args.output_dir, "chatglm-lora.pt")
123 | #)
124 |
125 | # don't just save model, save configs also.
126 | model.save_pretrained(training_args.output_dir)
127 |
128 |
129 | if __name__ == "__main__":
130 | main()
131 |
--------------------------------------------------------------------------------
/github_assets/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ljsabc/Fujisaki/a4a70eba2bbd95266440b3a9ac94b7a29df96004/github_assets/banner.png
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 |
4 | from transformers import AutoTokenizer, GenerationConfig, AutoModel
5 |
6 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
7 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="658202d").cuda().half()
8 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="658202d")
9 |
10 | from peft import get_peft_model, LoraConfig, TaskType, PeftModel
11 |
12 | peft_path = sys.argv[1] if len(sys.argv) > 1 else "output/"
13 | model = PeftModel.from_pretrained(
14 | model,
15 | peft_path,
16 | torch_dtype=torch.float16,
17 | )
18 | print(model)
19 |
20 | # TODO: check if full precision is necessary
21 | torch.set_default_tensor_type(torch.cuda.FloatTensor)
22 | model.eval()
23 |
24 | generation_config = GenerationConfig(
25 | temperature=0.9,
26 | top_p=0.975,
27 | #top_k=150,
28 | #repetition_penalty=1.1,
29 | num_beams=1,
30 | do_sample=True,
31 | )
32 |
33 | with torch.no_grad():
34 | while True:
35 | context = input(">")
36 |         input_text = f"Instruction: {context}\nAnswer: "  # match the training format from tokenize_dataset_rows.py
37 | ids = tokenizer([input_text], return_tensors="pt")
38 | inputs = ids.to("cuda")
39 | #input_ids = torch.LongTensor([ids]).cuda()
40 | out = model.generate(
41 | **inputs,
42 | max_length=224,
43 | generation_config=generation_config
44 |
45 | )
46 | out = out.tolist()[0]
47 | #print(out)
48 | decoder_output = tokenizer.decode(out)
49 | #print(decoder_output)
50 | out_text = "Chihiro:" + decoder_output.split("Answer: ")[1]
51 | print(out_text)
52 |
--------------------------------------------------------------------------------
/length.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import transformers
3 | from tqdm import tqdm
4 | import json
5 |
6 | def preprocess(tokenizer, config, example, max_seq_length):
7 | prompt = example["context"]
8 | target = example["target"]
9 | prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
10 | target_ids = tokenizer.encode(
11 | target,
12 | max_length=max_seq_length,
13 | truncation=True,
14 | add_special_tokens=False)
15 | input_ids = prompt_ids + target_ids + [config.eos_token_id]
16 | #print(prompt+target, input_ids)
17 | return {"input_ids": input_ids, "seq_len": len(prompt_ids)}
18 |
19 | model_name = "THUDM/chatglm-6b"
20 | tokenizer = transformers.AutoTokenizer.from_pretrained(
21 | model_name, trust_remote_code=True)
22 | config = transformers.AutoConfig.from_pretrained(
23 | model_name, trust_remote_code=True, device_map='auto')
24 |
25 | import numpy as np
26 |
27 | length = []
28 | with open("tweets.jsonl", "r") as f:
29 | for line in tqdm(f.readlines()):
30 | example = json.loads(line)
31 | feature = preprocess(tokenizer, config, example, 4096)
32 | length.append(len(feature["input_ids"]))
33 |
34 | l = np.array(length)
35 | for q in [0.9, 0.95, 0.97, 0.99, 0.999]:
36 | print(f"{q}: {np.quantile(l, q)}")
37 |
38 |
--------------------------------------------------------------------------------
/prompt.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import re
4 |
5 | import numpy as np
6 | import collections
7 |
8 | from scrape_twitter import process_tweet_ids
9 | import config
10 |
11 | from prompt_openai import processOriginalTweet_openai
12 |
13 | from prompt_util import original_post_prompt, findTopic, cut_sent, checkResponse
14 |
15 |
16 | def processOriginalTweet(tweets):
17 |
18 | # It is about the original tweet
19 | if config.PARSE_REPLIES:
20 | # if we parse the replies, we will have more data to sample from
21 | # we do not need to do the completion, and the Q&A part can be inferred from the in-reply-to
22 | # of the original posts, 40% are unconditional (with questions), 10% are completion, 20% are Q&A, 30% are unconditional (with no prompts)
23 | sample_range = [0.4, 0.5, 0.7, 1]
24 | else:
25 | sample_range = [0.35, 0.5, 0.95, 1]
26 |
27 |
28 | final = []
29 | for item in tweets:
30 | # sample a random float from 0-1 to decide the ways of generation
31 |         # sample_range is a cumulative probability list
32 | # [random post, completion, Q&A, rest (direct original post)]
33 | tweet = item["text"]
34 | rr = random.random()
35 | if rr < sample_range[0]:
36 | # sample a random question, and concatenate
37 | instruction = f"{random.choice(original_post_prompt)}"
38 | user_input = f""
39 | if checkResponse(tweet):
40 | final.append({"instruction": instruction, "input": user_input, "output": tweet})
41 | elif rr < sample_range[1]:
42 | # given a truncated tweet, ask for completion
43 | substring = cut_sent(tweet)
44 | if len(substring) > 1:
45 | user_input = f""
46 | rr = random.randint(1, len(substring)-1)
47 | instruction = "".join(substring[0:rr])
48 | if checkResponse(instruction):
49 | final.append({"instruction": instruction, "input": user_input, "output": "".join(substring[rr:])})
50 | else:
51 | instruction = f"{random.choice(original_post_prompt)}"
52 | user_input = f""
53 | final.append({"instruction": instruction, "input": user_input, "output": tweet})
54 | elif rr < sample_range[2]:
55 | # QA like
56 | # ask for a topic, the topic is mainly based on a substring of this tweet
57 | instruction = findTopic(tweet)
58 | if instruction is not None:
59 | user_input = f""
60 | if checkResponse(tweet):
61 | final.append({"instruction": instruction, "input": user_input, "output": tweet})
62 | else:
63 | #if cannot find a topic
64 | instruction = f"{random.choice(original_post_prompt)}"
65 | user_input = f""
66 | if checkResponse(tweet):
67 | final.append({"instruction": instruction, "input": user_input, "output": tweet})
68 |
69 | else:
70 | # no instructions, unconditional generation.
71 | final.append({"instruction": "", "input": "", "output": tweet})
72 |
73 | return final
74 |
75 | def processReplyTweets(tweets):
76 |
77 | final = []
78 | context_count = []
79 |
80 | for index, t in enumerate(tweets):
81 | #print(index, t)
82 | tweet_id = t['id']
83 | tweet_text = t['text']
84 | context = t['context']
85 |
86 | # first, we need to check if the reply itself is interesting
87 | if not checkResponse(tweet_text):
88 | continue
89 |
90 | # then, we need to check if the context is blank
91 | if context is None:
92 | continue
93 |
94 | if len(context) == 0:
95 | continue
96 |
97 |         # next, we do a rough check on whether the context is interesting
98 | # if the context is too short, we will not use it
99 | if not checkResponse("".join(context)):
100 | continue
101 |
102 | # We believe the context is interesting
103 | # in this way, we want to sample a random context length based on the probability distribution of 1/x
104 | # the longer the context, the less likely it will be sampled
105 | # if the context's length is not long enough, we will sample again
106 | '''
107 | l = len(context)
108 | p = []
109 | for j in range(l):
110 | p.append(np.power(j+1, -config.REPLY_TEMP))
111 |
112 | # normalize the probability
113 | p = np.array([i/sum(p) for i in p], dtype=np.float32)
114 |
115 | while True:
116 | # sample a number based on p
117 | r = np.random.choice(l, 1, p=p)[0]
118 | context_text = "\n".join(context[-r:])
119 | # check if the context is interesting
120 | if checkResponse(context_text):
121 | r_count.append(r)
122 | break
123 | '''
124 |
125 | context_count.append(len(context))
126 |
127 | # now we have a context, and a reply
128 |         # now build the prompt entry from the full context
129 | #
130 | final.append({"instruction": config.REPLY_DELIMITER.join(context), "input": "", "output": tweet_text})
131 |
132 | # but we can do more, we can also augment a Q&A like discussion within the topic, if we want
133 | # give a small random chance to do this
134 | if random.random() < 0.05:
135 | instruction = findTopic(context[-1])
136 | if instruction is not None:
137 | final.append({"instruction": instruction, "input": "", "output": tweet_text})
138 |
139 | # TODO: any other sort of prompt engineering?
140 | # use a counter to see how many long contexts are used
141 | print(collections.Counter(context_count))
142 |
143 | return final
144 |
145 | def write_json(md_path, final_md, lang):
146 |
147 |     # construct an instruction dataset
148 | final = []
149 |
150 | # construct a list of tweets to be downloaded to sample the contexts
151 | context_tweets = []
152 |
153 | # original tweets
154 | original_tweets = []
155 |
156 | # firstly classify the type of each tweet
157 | for id, md, in_reply_to, quote, retweet in final_md:
158 |
159 | # content filter goes here:
160 | if md.strip() == "(media)":
161 | continue
162 |
163 | if in_reply_to and quote:
164 | # todo: process replies and quotes
165 | pass
166 | elif in_reply_to:
167 | # save them into a list; we will download them later
168 | context_tweets.append({"id": id, "text": md})
169 | elif quote:
170 | # todo: process quotes
171 | pass
172 | elif retweet:
173 | # not my tweet, simply discard them
174 | pass
175 | else:
176 | # original tweets
177 | original_tweets.append({"id": id, "text": md})
178 |
179 |
180 | # process with the original tweets
181 | for l in range(config.AUGMENTATION_FACTOR_ORIGINAL):
182 | if config.ENABLE_OPENAI:
183 | res = processOriginalTweet_openai(original_tweets, l)
184 | print(f"Batch: {l+1}, {len(res)} original tweets are added with openai.")
185 | final.extend(res)
186 | else:
187 | res = processOriginalTweet(original_tweets)
188 | print(f"Batch: {l+1}, {len(res)} original tweets are added.")
189 | final.extend(res)
190 |
191 |
192 | # process with the replies
193 | if config.PARSE_REPLIES:
194 | # Now things get even more interesting, we will scrape the tweets from the context_tweet_ids
195 | parsed_tweets = process_tweet_ids(context_tweets)
196 | print(f"Processed {len(parsed_tweets)} tweets from the context tweets.")
197 |
198 | for l in range(config.AUGMENTATION_FACTOR_REPLIES):
199 | res = processReplyTweets(parsed_tweets)
200 | print(f"Batch: {l+1}, {len(res)} replies are added.")
201 | final.extend(res)
202 |
203 |
204 |
205 | with open(md_path, "w") as f:
206 | # shuffle the dataset
207 | random.shuffle(final)
208 | f.write(json.dumps(final, indent=4, ensure_ascii=False))
209 |
--------------------------------------------------------------------------------
/prompt_openai.py:
--------------------------------------------------------------------------------
1 | # prompt_openai.py
2 | # parsing with openai-guided prompts
4 | import config
5 | import random
6 | import json
7 | from prompt_util import checkResponse, original_post_prompt, findTopic
8 | import openai
9 | import os
10 | from concurrent import futures
11 | import collections
12 |
13 | def process_single_original_qa(value):
14 | response = openai.ChatCompletion.create(
15 | model="gpt-3.5-turbo",
16 | messages=[
17 | #{"role": "system", "content": "你是一个富有情感的善于分析他人发言的发帖助手。\
18 | # 请提出一个问题,使用户输入的内容可以恰当回复你提出的问题。问题中禁止包含“这”这个字。\
19 | # 如果难以提问,或者提出的问题更像是在追问用户的输入而不是让用户的输入解答问题,就请提出一个诸如“最近发生什么事?”之类的通用问题。"},
20 |
21 | {"role": "system", "content": "Based on the Chinese user input, create a question in SIMPLIFIED CHINESE that allows the user's input to serve as an appropriate response to your question.\
22 |             If it's too difficult to come up with a question, or the user's input is too ambiguous, or you would have to ask based on the details of the user's input, \
23 |             please give up and just output a general Chinese question with a similar meaning to '最近发生什么事了?'."},
24 |
25 | {"role": "user", "content": f"User input (in Chinese): {value}"},
26 | ],
27 | max_tokens=128,
28 | temperature=0.95
29 | )
30 | return response.choices[0]["message"]["content"].strip()
31 |
32 | def process_single_original_continue(value):
33 | response = openai.ChatCompletion.create(
34 | model="gpt-3.5-turbo",
35 | messages=[
36 | #{"role": "system", "content": "你是一个富有情感的善于分析他人发言的发帖助手。\
37 | # 请试着在用户输入的内容前面补充一个不是疑问句的简短上文,使得用户的输入可以恰当地衔接你的上文。\
38 | # 你的上文可以是用户输入之前发生的事情,也可以是一个短暂的前情提要,但不允许总结或重复用户输入。\
39 | # 注意:是让用户输入跟随你的上文,而不是你的上文去跟随用户的输入。你的回答只需要包含上文。"},
40 |
41 |             {"role": "system", "content": "Please create an EXTREMELY brief context IN SIMPLIFIED CHINESE that precedes the Chinese user input, \
42 | imagining a scenario where the user's input can naturally be right AFTER your context to form a complete story. \
43 | Your context should not cover the details of the user input, but rather set the stage for it. Note: \
44 |                 Your context should not contain the texts that are already in the user input. You should output the Chinese context only, without any modifier."},
45 | {"role": "user", "content": f"User input (in Chinese): {value}"},
46 | ],
47 | max_tokens=128,
48 | temperature=0.95
49 | )
50 | return response.choices[0]["message"]["content"].strip()
51 |
52 | def openai_process_original(item):
53 | tweet = item["text"]
54 | id = item["id"]
55 | res = []
56 | loadedCount = 0
57 |
58 | # check cache first
60 | if os.path.exists(f'openai_cache/{id}.txt'):
61 | try:
62 | with open(f'openai_cache/{id}.txt', 'r') as f:
63 | res = json.load(f)
64 | if len(res) >= config.OPENAI_MAX_SAMPLE:
65 | # possibility that the cache is even larger than the max sample
66 | return tweet, res
67 | else:
68 | loadedCount = len(res)
69 | except Exception as e:
70 | print(e)
71 | pass
72 |
73 | if config.OPENAI_CACHE_ONLY:
74 | return None
75 |
76 | # now we need to invoke openai to generate the rest of the samples
77 | # sample several seed questions for each tweet
78 | for j in range(loadedCount, config.OPENAI_MAX_SAMPLE):
79 | # 75% Q&A, 25% completion
80 | if random.random() < 0.75:
81 | try:
82 | res.append(process_single_original_qa(tweet))
83 | except Exception as e:
84 | print(e)
85 | # broken
86 | return None
87 |
88 | else:
89 | try:
90 | res.append(process_single_original_continue(tweet))
91 | except Exception as e:
92 | print(e)
93 | # broken
94 | return None
95 |
96 | # save to cache
97 | with open(f'openai_cache/{id}.txt', 'w') as f:
98 | json.dump(res, f, indent=4, ensure_ascii=False)
99 |
100 | return tweet, res
101 |
102 |
103 |
104 | def processOriginalTweet_openai(tweets, iteration_count):
105 | # the iteration count matters, as we will use it as the index to sample cached openai prompts
106 | openai.api_key = config.OPENAI_KEY
107 |
108 | # make the openai cache directory
109 | if not os.path.exists('openai_cache'):
110 | os.makedirs('openai_cache')
111 |
112 | # because it's sampling from openai, we do not need to do Q/A or completion
113 | # we better just do random sampling with a seed question for a small portion of samples
114 |
115 | final = []
116 | sample_threshold = 0.10 if iteration_count == 0 else 0 # sample 10% of the tweets
117 |
118 | for item in tweets:
119 |         # sample a random float from 0-1 to decide the way of generation
120 |         # sample_threshold gates the small non-openai share of samples:
121 |         # [random seed question, unconditional, skip (left to openai)]
122 | tweet = item["text"]
123 | rr = random.random()
124 | if rr < sample_threshold:
125 | # sample a random question, and concatenate
126 | instruction = f"{random.choice(original_post_prompt)}"
127 | user_input = f""
128 | if checkResponse(tweet):
129 | final.append({"instruction": instruction, "input": user_input, "output": tweet})
130 | elif rr < sample_threshold * 2:
131 | # no instructions, unconditional generation.
132 | final.append({"instruction": "", "input": "", "output": tweet})
133 | else:
134 | # do nothing, as we have made the sampling successfully
135 | pass
136 |
137 | # now proceed to the openai generation
138 | openai_process_list = []
139 | for item in tweets:
140 | tweet = item["text"]
141 | if checkResponse(tweet):
142 | openai_process_list.append(item)
143 |
144 | # now invoke a threadpool to process the openai generation
145 | with futures.ThreadPoolExecutor(max_workers=config.OPENAI_THREADS) as executor:
146 | results = executor.map(openai_process_original, openai_process_list)
147 |
148 | for r in results:
149 | if r:
150 | tweet, response = r
151 | if iteration_count < len(response):
152 | final.append({"instruction": response[iteration_count], "input": "", "output": tweet})
153 | else:
154 | # we do not have enough cached openai prompts
155 | pass
156 |
157 | return final
--------------------------------------------------------------------------------
/prompt_util.py:
--------------------------------------------------------------------------------
1 | import re
2 | import random
3 | import config
4 |
5 | import jieba
6 | import jieba.posseg as pseg
7 |
8 |
9 | original_post_prompt = [
10 | "最近过得怎么样?",
11 | "你这段时间都在忙什么?",
12 | "离我们上次见面以来,有什么新鲜事吗?",
13 | "最近有什么有趣的经历吗?",
14 | "你近期的生活如何?",
15 | "距离我们上次见面已经过了一段时间了,你过得好吗?",
16 | "你最近都在做些什么?",
17 | "近期有什么重要的变化吗?",
18 | "有什么新的发展吗?",
19 | "能和我分享一下你近期的生活吗?",
20 | "你最近有什么特别的事情发生吗?",
21 | "你最近的生活有什么变化吗?",
22 | "最近有没有什么大事发生?",
23 | "过去的几个月里,你都经历了些什么?",
24 | "你过去一段时间都在忙些什么?",
25 | "有什么新鲜事要和我分享吗?",
26 | "这段时间有没有什么让你难忘的事情发生?",
27 | "最近在工作或生活方面有什么挑战吗?",
28 | "从我们上次见面到现在,你都遇到了哪些有意思的事情?",
29 | "过去几个月你都忙些什么?",
30 | "最近你都有哪些新发现?",
31 | "你近期有哪些值得分享的经历?",
32 | "跟我聊聊你最近的日常吧?",
33 | "近期有什么值得一提的事吗?",
34 | "这段时间有没有什么特别的回忆?",
35 | "最近有没有什么让你感到骄傲的事?",
36 | "你最近有什么惊喜发生吗?",
37 | "过去一段时间,你的生活中有哪些亮点?",
38 | "近来你都有哪些新的尝试?",
39 | "你最近的日子过得怎样?有什么精彩的时刻?",
40 | "这段时间有什么特别的见闻吗?",
41 | "从上次见面到现在,你有什么有趣的故事吗?",
42 | "最近有没有什么特殊的事情让你感到开心?",
43 | "你最近有什么新的梦想或目标吗?",
44 | "过去的一段时间里,你有什么难忘的经历吗?",
45 | "你最近有没有什么特别的发现或灵感?",
46 | "最近有没有什么突破性的成就?",
47 | "这段时间你都学到了哪些新知识?",
48 | "最近有没有什么值得一提的事情发生在你身边?",
49 | "最近你有什么值得庆祝的事情吗?",
50 | "从我们上次见面以来,你都有哪些新的探险?",
51 | "最近有没有什么让你感到充实的事情?",
52 | "过去一段时间里,你有什么特别的体验?",
53 | "你近期有什么令人惊讶的收获?",
54 | "跟我分享一下你最近的喜怒哀乐吧?"
55 | ]
56 |
57 | related_topic_prompt = [
58 | "关于[AAA],你有什么想法?",
59 | "你觉得[AAA]在日常生活中扮演着什么角色?",
60 | "你曾经与[AAA]有过什么相关经历吗?",
61 | "你觉得[AAA]有哪些相关的话题值得探讨?",
62 | "关于[AAA],你有什么好的经验可以分享吗?",
63 | "你如何看待[AAA]与我们生活中的其他方面的关系?",
64 | "你觉得[AAA]有哪些值得关注的趋势?",
65 | "你认为[AAA]对于我们的日常生活有多大意义?",
66 | "你觉得[AAA]有哪些令人好奇的方面?",
67 | "你有没有关于[AAA]的有趣故事?",
68 | "你觉得[AAA]有哪些令人欣赏的方面?",
69 | "关于[AAA],你觉得有哪些有趣的讨论点?",
70 | "关于[AAA],你有什么新的见解吗?",
71 | "你觉得[AAA]是如何影响你生活的?",
72 | "你曾经有过关于[AAA]的有趣经历吗?",
73 | "关于[AAA],有哪些热门话题值得关注?",
74 | "你有没有关于[AAA]的特别记忆?",
75 | "你觉得[AAA]与你生活中的其他元素有何联系?",
76 | "你觉得[AAA]有哪些令人惊讶的方面?",
77 | "你有没有听过关于[AAA]的趣闻轶事?",
78 | "关于[AAA],你觉得有哪些引人入胜的话题?",
79 | "你对[AAA]有哪些深刻的印象?",
80 | "你有没有关于[AAA]的奇闻异事?",
81 | "关于[AAA],你有什么特别的见解吗?",
82 | "你有没有关于[AAA]的特别喜好?",
83 | "你对[AAA]有哪些独到的见解?",
84 | "你觉得[AAA]如何影响我们的思考方式?",
85 | "你有没有关于[AAA]的趣味事例?",
86 | "你觉得[AAA]有哪些令人叹为观止的特点?",
87 | "关于[AAA],你有什么富有启发性的想法?",
88 | "你有没有关于[AAA]的奇特经历?",
89 | "你觉得[AAA]有哪些吸引人的特质?",
90 | "关于[AAA],你有什么令人兴奋的见解?"
91 | ]
92 |
93 | def cut_sent(text):
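    # Split on Chinese punctuation while keeping the delimiters (the capture
    # group), then glue each delimiter back onto its sentence; a trailing
    # fragment without punctuation is kept as-is.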
94 |     sub_sentences = re.split(r'([。!?;,\n])', text)
95 | sub_sentences = [s1 + s2 for s1, s2 in zip(sub_sentences[0::2], sub_sentences[1::2])] + ([sub_sentences[-1]] if len(sub_sentences) % 2 == 1 else [])
96 | return [s.replace("\n", "") for s in sub_sentences if (s.strip() != "" and s.strip() != "(media)" and s.strip() != "(link)")]
97 |
98 |
99 | def findTopic(md):
100 | substring = cut_sent(md)
101 | if len(substring) > 1:
102 | rr = random.randint(0, len(substring)-1)
103 |         # choose a random sub-sentence to seed the topic
104 | topic = substring[rr]
105 | # cut off the last word, if it's a chinese punctuation
106 | if topic[-1] in [",", "。", "!", "?", ";"]:
107 | topic = topic[:-1]
108 |
109 | # use jieba to do the word segmentation and pos tagging
110 | tokens = [word for word,flag in pseg.cut(topic) if 'n' in flag]
111 | if len(tokens) > 0:
112 | topic = random.choice(tokens)
113 | else:
114 |             # no noun found; keep the whole sub-sentence as the topic
115 |             pass
116 |
117 | instruction = random.choice(related_topic_prompt).replace("[AAA]", topic)
118 | return instruction
119 | else:
120 | return None
121 |
122 |
123 | def checkResponse(response):
124 | # check if the user's response is too short. Filter it out.
125 | if len(response.replace("\n","").replace(" ", "").replace("(media)","").replace("(link)", "")) < config.RESPONSE_THRESH:
126 | return False
127 | return True
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # int8
2 | bitsandbytes==0.38.1
3 | accelerate==0.17.1
4 |
5 | # chatglm
6 | protobuf>=3.19.5,<3.20.1
7 | transformers==4.27.1
8 | icetk
9 | cpm_kernels==1.0.11
10 | torch>=1.13.1
11 |
12 | #
13 | datasets==2.10.1
14 | git+https://github.com/huggingface/peft.git # latest version >=0.3.0.dev0
15 |
16 | # Chinese tokenization
17 | jieba
18 |
19 | # Advanced twitter scraping (beta)
20 | selenium
21 | beautifulsoup4
22 | webdriver-manager
23 |
24 | # parsing and sampling (should be enabled by default)
25 | numpy
26 | openai
--------------------------------------------------------------------------------
/scrape_twitter.py:
--------------------------------------------------------------------------------
1 | # Warning: research purposes only. This code is not intended for production use. Use at your own risk.
2 | # Scraping Twitter may violate the Terms of Service.
3 |
4 | import multiprocessing
5 | import re
6 | import threading
7 | import zipfile
8 | from selenium import webdriver
9 | from webdriver_manager.chrome import ChromeDriverManager
10 | from bs4 import BeautifulSoup
11 | import time
12 | import concurrent.futures
13 | import json
14 | import os
15 | import config
16 |
17 | from selenium.webdriver.common.by import By
18 | from selenium.webdriver.support.ui import WebDriverWait
19 | from selenium.webdriver.support import expected_conditions as EC
20 | from selenium.webdriver.common.proxy import Proxy, ProxyType
21 | from selenium.webdriver.common.keys import Keys
22 |
23 | class document_is_ready(object):
24 | def __call__(self, driver):
25 | ready_state = driver.execute_script("return document.readyState")
26 | return ready_state == "complete"
27 |
28 | def get_tweet_info(tweet):
29 |
30 | # check out the cache first,
31 | # if hit, return the cached result
32 | tweet_id = tweet['id']
33 |
34 | try:
35 | with open(f'tweet_cache/{tweet_id}.txt', 'r') as f:
36 | return {'id': tweet_id, 'text': tweet['text'], 'context': json.load(f)}
37 | except:
38 | pass
39 |
40 | if config.SCRAPE_CACHE_ONLY:
41 | return None
42 |
43 | driver = get_driver()
44 |
45 | try:
46 | url = f"https://mobile.twitter.com/ljsabc/status/{tweet_id}"
47 | driver.get(url)
48 | WebDriverWait(driver, config.SCRAPE_TIMEOUT).until(EC.presence_of_element_located((By.XPATH, '//*[@data-testid="bookmark"]')))
49 |
50 | # Wait for DOM to be ready
51 | wait = WebDriverWait(driver, config.SCRAPE_TIMEOUT)
52 | wait.until(document_is_ready())
53 |
54 | body_element = driver.find_element(By.TAG_NAME, "body")
55 |
56 |
57 |         for _ in range(2):
58 |             # looks like we need to scroll up a few times to get the full context
59 |             # 2 rounds of scrolling up should be okay for most of the tweets
60 |             # if your threads run longer you may try a larger range
61 |             for _ in range(3):
62 |                 body_element.send_keys(Keys.PAGE_UP)
63 |                 time.sleep(0.15)
64 | # introduce a small delay to let the page load
65 | # the delays can be hidden by using more threads
66 | time.sleep(1.0)
67 |
68 | soup = BeautifulSoup(driver.page_source, 'html.parser')
69 | articles = soup.find_all('article')
70 |
71 | if articles is None:
72 | # weird, probably a deleted tweet
73 | # or, possibly web driver failed to load the page
74 | # at this moment, we just return None and let the caller handle it
75 | return None
76 |
77 | target_index = -1
78 | for i, article in enumerate(articles):
79 | if article.find(lambda tag: tag.get('data-testid') == 'bookmark'):
80 | target_index = i
81 | break
82 |
83 | # without scroll to top, the tweet context may not be complete
84 |         # but as we are only targeting a small context range, it should not hurt.
85 | print(f"id: {tweet_id}, located {target_index + 1} of {len(articles)} tweets.")
86 |
87 | results = []
88 | if target_index >= 0:
89 | for i, article in enumerate(articles):
90 | if i < target_index:
91 |                     #if True:
92 | target = article.find('div', {'data-testid': 'tweetText'})
93 | if target:
94 | tweet_text = target.get_text()
95 | print(i, tweet_text)
96 | results.append(tweet_text)
97 |
98 | # dump the results to a file
99 | with open(f'tweet_cache/{tweet_id}.txt', 'w') as f:
100 | json.dump(results, f, indent=4, ensure_ascii=False)
101 |
102 | print(f"Original tweet: {tweet['text']}\n\n")
103 |
104 | return {'id': tweet_id, 'text': tweet['text'], 'context': results}
105 | except Exception as e:
106 | print(e)
107 | #return {'id': tweet_id, 'text': tweet['text'], 'context': None}
108 | return None
109 |
110 |
111 |
112 | threadLocal = threading.local()
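# one WebDriver instance per worker thread, created lazily by get_driver()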
113 |
114 | def get_driver():
115 | driver = getattr(threadLocal, 'driver', None)
116 | if driver is None:
117 | mobile_user_agent = 'Mozilla/5.0 (Linux; Android 9; Pixel 3 Build/PQ3A.190705.001) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/75.0.3770.101 Mobile Safari/537.36'
118 | chrome_options = webdriver.ChromeOptions()
119 | chrome_options.add_argument(f'user-agent={mobile_user_agent}')
120 | prefs = {"profile.managed_default_content_settings.images": 2}
121 | chrome_options.add_experimental_option("prefs", prefs)
122 | chrome_options.add_argument(f"user-data-dir=chrome_profile/{threading.current_thread().ident}")
123 |
124 |         # The window size is set on purpose to avoid "error request".
125 |
126 | if config.PROXY:
127 |             # a proxy or proxy pool will be really useful here
128 | pluginfile = 'proxy_auth_plugin.zip'
129 | manifest_json = """
130 | {
131 | "version": "1.0.0",
132 | "manifest_version": 2,
133 | "name": "Chrome Proxy",
134 | "permissions": [
135 | "proxy",
136 | "tabs",
137 | "unlimitedStorage",
138 | "storage",
139 | "",
140 | "webRequest",
141 | "webRequestBlocking"
142 | ],
143 | "background": {
144 | "scripts": ["background.js"]
145 | },
146 | "minimum_chrome_version":"22.0.0"
147 | }
148 | """
149 |
150 | background_js = """
151 | var config = {
152 | mode: "fixed_servers",
153 | rules: {
154 | singleProxy: {
155 | scheme: "http",
156 | host: "%s",
157 | port: parseInt(%s)
158 | },
159 | bypassList: ["localhost"]
160 | }
161 | };
162 |
163 | chrome.proxy.settings.set({value: config, scope: "regular"}, function() {});
164 |
165 | function callbackFn(details) {
166 | return {
167 | authCredentials: {
168 | username: "%s",
169 | password: "%s"
170 | }
171 | };
172 | }
173 |
174 | chrome.webRequest.onAuthRequired.addListener(
175 | callbackFn,
176 | {urls: [""]},
177 | ['blocking']
178 | );""" % (config.PROXY_ADDR, config.PROXY_PORT, config.PROXY_USER, config.PROXY_PASSWD)
179 |
180 |
181 | with zipfile.ZipFile(pluginfile, 'w') as zp:
182 | zp.writestr("manifest.json", manifest_json)
183 | zp.writestr("background.js", background_js)
184 | chrome_options.add_extension(pluginfile)
185 | driver = webdriver.Chrome(executable_path='./chromedriver', options=chrome_options)
186 | else:
187 | driver = webdriver.Chrome(executable_path='./chromedriver', options=chrome_options)
188 | driver.set_window_size(340, 695)
189 | setattr(threadLocal, 'driver', driver)
190 |
191 | return driver
192 |
193 | def process_tweet_ids(tweets):
194 | # create a cache folder to store the scraped tweets
195 | if not os.path.exists('tweet_cache'):
196 | os.mkdir('tweet_cache')
197 |
198 | if not os.path.exists('chrome_profile'):
199 | os.mkdir('chrome_profile')
200 | else:
201 | # remove the dir, then create a new one
202 | # this is to avoid the cache issue
203 | import shutil
204 | shutil.rmtree('chrome_profile')
205 | os.mkdir('chrome_profile')
206 |
207 | tweet_data = []
208 | pool = concurrent.futures.ThreadPoolExecutor(max_workers=config.PROCESSES)
209 | for result in pool.map(get_tweet_info, tweets):
210 | if result is not None:
211 | tweet_data.append(result)
212 |
213 | return tweet_data
214 |
215 | if __name__ == '__main__':
216 | # Replace with a list of tweet IDs you want to fetch information for
217 | tweet_ids = ['1639865895374958592', '1639689372059725825', '1639488084705439745', '1639639351897522176']
218 | tweet_data = process_tweet_ids(tweet_ids)
219 |
220 | for data in tweet_data:
221 | print(data)
--------------------------------------------------------------------------------
/tokenize_dataset_rows.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from tqdm import tqdm
4 |
5 | import datasets
6 | import transformers
7 |
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--json_path", type=str, default="data/alpaca_data.jsonl")
11 | parser.add_argument("--save_path", type=str, default="data/alpaca")
12 | parser.add_argument("--max_seq_length", type=int, default=384)
13 | parser.add_argument("--skip_overlength", type=bool, default=False)
14 | args = parser.parse_args()
15 |
16 | model_name = "THUDM/chatglm-6b"
17 | tokenizer = transformers.AutoTokenizer.from_pretrained(
18 | model_name, trust_remote_code=True)
19 | config = transformers.AutoConfig.from_pretrained(
20 | model_name, trust_remote_code=True, device_map='auto')
21 |
22 | def format_example(example: dict) -> dict:
23 | context = f"Instruction: {example['instruction']}\n"
24 | if example.get("input"):
25 | context += f"Input: {example['input']}\n"
26 | context += "Answer: "
27 | target = example["output"]
28 | return {"context": context, "target": target}
29 |
30 |
31 | def preprocess(tokenizer, config, example, max_seq_length):
32 | example = format_example(example)
33 | prompt = example["context"]
34 | target = example["target"]
35 | prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
36 | target_ids = tokenizer.encode(
37 | target,
38 | max_length=max_seq_length,
39 | truncation=True,
40 | add_special_tokens=False)
41 | input_ids = prompt_ids + target_ids + [config.eos_token_id]
42 | return {"input_ids": input_ids, "seq_len": len(prompt_ids)}
43 |
44 |
45 | def read_jsonl(path, max_seq_length, skip_overlength=False):
46 | model_name = "THUDM/chatglm-6b"
47 | tokenizer = transformers.AutoTokenizer.from_pretrained(
48 | model_name, trust_remote_code=True)
49 | config = transformers.AutoConfig.from_pretrained(
50 | model_name, trust_remote_code=True, device_map='auto')
51 | with open(path, "r") as f:
52 | for line in tqdm(f.readlines()):
53 | example = json.loads(line)
54 | feature = preprocess(tokenizer, config, example, max_seq_length)
55 | if skip_overlength and len(feature["input_ids"]) > max_seq_length:
56 | continue
57 | feature["input_ids"] = feature["input_ids"][:max_seq_length]
58 | yield feature
59 |
60 | def parse(element):
61 | feature = preprocess(tokenizer, config, element, args.max_seq_length)
62 | feature["input_ids"] = feature["input_ids"][:args.max_seq_length]
63 | return feature
64 |
65 |
66 | def main():
67 | dataset = datasets.load_dataset("json", data_files=args.json_path)
68 | train_data = dataset["train"].shuffle().map(parse, num_proc=4)
69 | train_data.save_to_disk(args.save_path)
70 |
71 | #dataset = datasets.Dataset.from_generator(
72 | # lambda: read_jsonl(args.jsonl_path, args.max_seq_length, args.skip_overlength),
73 | #)
74 | #dataset.save_to_disk(args.save_path)
75 |
76 | # the generator approach was poorly written; mapping is better, I guess.
77 | # it ignores the updates in the same jsonl file
78 | #dataset.cleanup_cache_files()
79 |
80 |
81 | if __name__ == "__main__":
82 | main()
83 |
--------------------------------------------------------------------------------
/tweets_sample.md:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "instruction": "想让父母认同计算机是生产力工具这个观点还真是难啊……",
4 | "input": "",
5 | "output": "这个就尽量死心吧…\n我娘和我住了几个月看我每天对着电脑摔东西就知道电脑真的可以用来干活儿了…"
6 | },
7 | {
8 | "instruction": "然后发觉所谓的AI其实就是一套很简单的自动控制系统 https://twitter.com/thirtydb/status/1511244138112765957…\nAI=自动化机床专业",
9 | "input": "",
10 | "output": "其实没错的!"
11 | },
12 | {
13 | "instruction": "o365的在线多人协作是真的好用…\n而且现在还有自动补全和语法检查了,非常顺",
14 | "input": "",
15 | "output": "另外这个时候3:2屏幕的优势就出来了"
16 | },
17 | {
18 | "instruction": "最近有什么有趣的经历吗?",
19 | "input": "",
20 | "output": "我觉得我得在acknowledgement里面写上感谢waifu2x"
21 | },
22 | {
23 | "instruction": "",
24 | "input": "",
25 | "output": "zij sounds like some math notations\n\nDijkstra"
26 | },
27 | {
28 | "instruction": "跟我聊聊你最近的日常吧?",
29 | "input": "",
30 | "output": "满脑子里都是那个YEAH\n\n另外这曲子有十几个Remix,最早听的版本果然还是BO2在酒吧里面DJ放的那个版本,在最后credit roll里面也有完整版重放。 \n\n(media)"
31 | },
32 | {
33 | "instruction": "最近过得怎么样?",
34 | "input": "",
35 | "output": "想写个tensorboard的iOS客户端之类的东西…"
36 | },
37 | {
38 | "instruction": "这都是东方的莫名含蓄造成的结果啊(趴\n是,感觉很奇怪",
39 | "input": "",
40 | "output": "主要是禁止早恋但是又不跟进教育讲清楚为啥不能早恋……"
41 | },
42 | {
43 | "instruction": "过去的几个月里,你都经历了些什么?",
44 | "input": "",
45 | "output": "吹蛤爷我不反对,但是吹也要按照基本法啊\n要避免跟瞎吹的人比爬山 比拔河\n他姿势不如你 但是身体比你好啊"
46 | },
47 | {
48 | "instruction": "嗯你说的对呀 https://t.co/DmixTHi2sf",
49 | "input": "",
50 | "output": "可是要怎么改呢…(沉思"
51 | },
52 | {
53 | "instruction": "直到懂车之后还是不喜欢F1的主要原因是\n我觉得这东西开着不像车=、=",
54 | "input": "",
55 | "output": "问个类似的问题,你觉得LMP1像车么……"
56 | },
57 | {
58 | "instruction": "你近期有什么令人惊讶的收获?",
59 | "input": "",
60 | "output": "天下果然没有这样的好事儿 \n\n(media)"
61 | },
62 | {
63 | "instruction": "",
64 | "input": "",
65 | "output": "复兴-建国 \n>RT"
66 | },
67 | {
68 | "instruction": "120Hz是文明的啊…",
69 | "input": "",
70 | "output": "完了我已经回不去了"
71 | },
72 | {
73 | "instruction": "映画「君の名は。」中国語吹き替え版予告編 https://youtu.be/0Qt5OecQk-I via @YouTube\n这是中国上映的那个吹替版?怎么感觉CV好业余。",
74 | "input": "",
75 | "output": "就是好业余啊……"
76 | },
77 | {
78 | "instruction": "关于video,你觉得有哪些引人入胜的话题?",
79 | "input": "",
80 | "output": "Lets save your fucking planet earth \n\n(video)"
81 | },
82 | {
83 | "instruction": "跟我分享一下你最近的喜怒哀乐吧?",
84 | "input": "",
85 | "output": "Everything is just so shitty that I cannot get rid of these smells"
86 | },
87 | {
88 | "instruction": "离我们上次见面以来,有什么新鲜事吗?",
89 | "input": "",
90 | "output": "我foing掉了10个("
91 | },
92 | {
93 | "instruction": "你看到三个拳头的直接B了就好千万别留情面\n其实什么叫三个全头",
94 | "input": "",
95 | "output": "你看他粉丝的昵称"
96 | },
97 | {
98 | "instruction": "下面的4599咋回事",
99 | "input": "",
100 | "output": "4600U的价格啊"
101 | },
102 | {
103 | "instruction": "关于头文字,你觉得有哪些有趣的讨论点?",
104 | "input": "",
105 | "output": "(link)\n头文字D\n\n// 可以看秋名山开车 开心"
106 | },
107 | {
108 | "instruction": "有iPad 用户留言\n\n“为什么我原来没发现”\n\n因为文件管理器加入的比较晚\n\n顺便光记住IP地址没用啊\n\n要确保你局域网内IP地址不是随意分配的啊",
109 | "input": "",
110 | "output": "一般第一次之后就不会变了…\n不过直接搞个nas不香嘛"
111 | },
112 | {
113 | "instruction": "有什么新鲜事要和我分享吗?",
114 | "input": "",
115 | "output": "还是孕妇装出门QAQ \n\n(media)"
116 | },
117 | {
118 | "instruction": "你最近的日子过得怎样?有什么精彩的时刻?",
119 | "input": "",
120 | "output": "我要是有辆车就好了"
121 | },
122 | {
123 | "instruction": "请不要在戏院里食用 小小酥\n为什么你可以发语音!\n现在还有吗这个 感觉都好久远了\n有 真的有\n还有那个 兵卒一口脆 不就是圆块儿的方便面嘛",
124 | "input": "",
125 | "output": "那个也还有…\n我有一次看见了都快哭了\n\n不过买了一袋尝尝 觉得没有以前好吃了"
126 | },
127 | {
128 | "instruction": "碰气!\n你不说我都快忘记他叫什么了\n就记得是碰气碰气",
129 | "input": "",
130 | "output": "我的台语都是你们教的!\n\n业务员问我:你是北京人么!\n我:是一个比北京还要靠北的地方\n韩先生:哈哈哈哈哈\n\n("
131 | },
132 | {
133 | "instruction": "我真的不会拼misc全称……\n可耻",
134 | "input": "",
135 | "output": "我打misc它也给我拼写检查……变成了misc.\n("
136 | },
137 | {
138 | "instruction": "这DPI太温柔了……",
139 | "input": "",
140 | "output": "不对,是一样的……"
141 | },
142 | {
143 | "instruction": "",
144 | "input": "",
145 | "output": "spin大师"
146 | },
147 | {
148 | "instruction": "离我们上次见面以来,有什么新鲜事吗?",
149 | "input": "",
150 | "output": "我已经欠了十数位“回来了一定去你家”的约定了"
151 | },
152 | {
153 | "instruction": "索尼这个a系列单反样片app似乎是支持HDR的\n看起来真是太带感了\n\n好了有了一个新的买的理由了",
154 | "input": "",
155 | "output": "这个可以有!(赞许 \n\n(media)"
156 | },
157 | {
158 | "instruction": "",
159 | "input": "",
160 | "output": "所以还是作者心绪的投射吧。也是我自己心灵的投射。\n\n14岁的时候读不懂《18岁出门远行》\n等到24岁的时候,就只有一声叹息了。"
161 | },
162 | {
163 | "instruction": "草还真有啊\n藍店\n反正我没去过\n淡定淡定,重點還是價格合理營商老實",
164 | "input": "",
165 | "output": "我去旁边阿布泰它不香么(w \n\n(video)"
166 | },
167 | {
168 | "instruction": "Apollo Pro Plus(100应该是)\n\nCPU是铜矿,比我家的370赛扬1.0(图拉丁)早一代\n\n我家那台Apollo Pro 133+图拉丁三年前还能boot,估计现在应该也能……\n370最经典的还是300A可以超到450那一批,直接把旗舰奔腾2给收割了",
169 | "input": "",
170 | "output": "那时候的赛扬体质都超级好……"
171 | },
172 | {
173 | "instruction": "关于照片,你有什么特别的见解吗?",
174 | "input": "",
175 | "output": "欸喵 好期待w"
176 | },
177 | {
178 | "instruction": "看着比PI都牛逼,不知道以为是ECCV的AC呢…\n真当了ac还有空水zhihu吗",
179 | "input": "",
180 | "output": "倒是的确没见过哪个ac吹自己年级第一的"
181 | },
182 | {
183 | "instruction": "我去这个防水眼线笔好厉害\n现在都没有花\n\n补一下粉底又是条好汉(",
184 | "input": "",
185 | "output": "是呢 而且还有点流汗什么的w"
186 | },
187 | {
188 | "instruction": "你对机箱有哪些深刻的印象?",
189 | "input": "",
190 | "output": "算了,与其折腾没有保修的4090 Blower,不如直接上4块4090 AIO算了,反正两个固定在机箱上两个支出来一样用"
191 | },
192 | {
193 | "instruction": "设计不能么?",
194 | "input": "",
195 | "output": "懂了\n…也就是说我过关还要过条街啊"
196 | },
197 | {
198 | "instruction": "你们一起玩了么!w",
199 | "input": "",
200 | "output": "哇啊啊啊羡慕死了"
201 | },
202 | {
203 | "instruction": "我妈每天都在想:你天天发自曝就不怕别人笑话你?",
204 | "input": "",
205 | "output": "反正我从小到大被歧视最多的就tm家里了 必须服"
206 | },
207 | {
208 | "instruction": "跟你讲,ARM64笔记本,除了做工没有MacBook 12“好之外(你要是等也有三爽的Galaxy Book S),其他几乎爽翻天\n\nCPU虽然捞但是8个蛋加在一起不比Core M差,最关键是永远都不会throttle\n满载续航起码三个小时,上网看报10+小时起步\n用来吃饭干活儿的工具都有Native版本(至少对我来说全了)\n还有快充\n有什么推荐的型号,我找类似的替代品找很久了。",
209 | "input": "",
210 | "output": "便宜的话Huawei的Matebook E就很不错\n如果对价格不在意那就等Galaxy Book S吧\n\n不过要注意的是写代码上网看视频都行,但是打游戏或者开PS可能就算了吧("
211 | },
212 | {
213 | "instruction": "你觉得gitp有哪些令人欣赏的方面?",
214 | "input": "",
215 | "output": "$ gitp ull\n$ fuck"
216 | },
217 | {
218 | "instruction": "爽!好穿!好看!\n\n*我在奥特莱斯买的",
219 | "input": "",
220 | "output": "等我打扮好看了我给你返图"
221 | },
222 | {
223 | "instruction": "“不行 就两天”\n佛说",
224 | "input": "",
225 | "output": "三天之内撒了你"
226 | },
227 | {
228 | "instruction": "大部分人的话都是左边会大一点的 没毛病\n诶诶是这样么……\n但是(",
229 | "input": "",
230 | "output": "但是我就是大很多啊!(绝望\n\n(link)\n// 众人的反应"
231 | },
232 | {
233 | "instruction": "关于卡,你有什么富有启发性的想法?",
234 | "input": "",
235 | "output": "之前记得有一个科学研究是说喜欢的音乐类型可能和听众的学力水平有相关性\n\n我在想会不会有一种可能是只有家境、教育到了一种水平才会接触到某种音乐类型?\n我父母就永远理解不了电子乐的verse,他们觉得是卡碟了……所以他们也不会推荐给我听"
236 | },
237 | {
238 | "instruction": "啥情況啊……?\n哎……我老闆總是教育我,“你得熟悉自己不喜歡做的事情”,然而這麽下去我就也不喜歡我自己了大概",
239 | "input": "",
240 | "output": "他跟我性格其實有點像的,都是技術流派的人,結果做了professor每天開會帶學生做的都是不喜歡的事情,我真不知道他是畫餅還是自己真的接受這種生活了……\n\n反正我不行"
241 | },
242 | {
243 | "instruction": "",
244 | "input": "",
245 | "output": "榨干.jpg居然还有续集的么!"
246 | },
247 | {
248 | "instruction": "这段时间有没有什么特别的回忆?",
249 | "input": "",
250 | "output": "白天俩猫活蹦乱跳\n夜里就是俩死猪"
251 | },
252 | {
253 | "instruction": "你们女装癖是不是都爱养猫啊\n\n// 这心态真的太能理解了",
254 | "input": "",
255 | "output": "二手玫瑰的吉他手\n(link)"
256 | },
257 | {
258 | "instruction": "靠这烂尾比liar game还…\nliar game剧场版还知道玩家和庄家对赌剧情,这烂尾的什么玩意儿\n幕后黑手是谁,神崎直他爸和中国黑客啊",
259 | "input": "",
260 | "output": "动物世界我能给到7.5\nliar game一二剧场版能给到8.5\n这个最多7\n\n抢椅子游戏好玩啊 然而换女主了不是我最爱的Erika了呜呜呜"
261 | },
262 | {
263 | "instruction": "这东西超级棒啊!w\n他家咸的小饼咸的很到位…辣的也很到位\n以前补课还有考研的时候经常吃…",
264 | "input": "",
265 | "output": "嗯!我上中学就分为农心派和格力高派wwww"
266 | },
267 | {
268 | "instruction": "",
269 | "input": "",
270 | "output": "下一个问题\n我手台在哪儿"
271 | },
272 | {
273 | "instruction": "垃圾学校一辈子都没有地铁\n南门在北门的北面(",
274 | "input": "",
275 | "output": "好像是每栋楼的门都标出来了"
276 | },
277 | {
278 | "instruction": "卧槽又有喜闻乐见的男娘题材",
279 | "input": "",
280 | "output": "重案六组4神了!\n还有白学剧情的! \n\n(media)\n\n(media)\n\n(media)\n\n(media)"
281 | },
282 | {
283 | "instruction": "40% 人选择了 O,意味着他们不愿意在 PS5 上玩 PS4 游戏。\n30% 人选择了 X,意味着他们不愿意 PS4 游戏运行在 PS5 上。\n\n综上,索尼罪恶滔天!",
284 | "input": "",
285 | "output": "你这个投票赢麻了"
286 | },
287 | {
288 | "instruction": "饼姐姐生快!",
289 | "input": "",
290 | "output": "我也爱你和狗子!"
291 | },
292 | {
293 | "instruction": "",
294 | "input": "",
295 | "output": "小时候就说你跟别人不一样你只要努力长大以后就一定能成为栋梁之材只手遮天\n后来我发现只手遮天的人有千千万,不差我一个"
296 | },
297 | {
298 | "instruction": "https://zh.wikipedia.org/wiki/%E6%B7%B1%E5%9C%B3%E5%B7%B4%E5%A3%AB%E8%B7%AF%E7%BA%BF%E5%88%97%E8%A1%A8…\n自从东莞公汽成功收编了以后现行开设了不少跨市公交,比如长安、虎门公汽进松岗,光明去黄江,雁田去凤岗等等,还有两元的龙岗到惠州公交\n大湾区(不含香港)也许有一天真的会来\n据了解,近年来,我市积极融入粤港澳大湾区交通一体化布局,自2019年1月,东莞巴士公司开通首条跨市公交线路——876路(后更名“莞75路”,起讫站为国贸中心公交站和新塘白江地铁站)起,截至目前,共开通14条跨市线路,加快了我市融入粤港澳大湾区及深莞惠一体化的步伐。",
299 | "input": "",
300 | "output": "其中,莞深跨市线路7条均接驳深圳地铁,莞穗跨市线路2条,莞惠跨市线路5条,覆盖虎门高铁站、长安北站、东莞火车站、国贸中心首末站、凤岗车站等市内大型交通枢纽。"
301 | },
302 | {
303 | "instruction": "键盘摔在地上,然后PC就休眠了,是撞到了休眠的快捷键?这个还有快捷键?",
304 | "input": "",
305 | "output": "有的呀 多媒体键w"
306 | },
307 | {
308 | "instruction": "斑点猫!尾巴也是斑点,腿也是斑点!\n刚刚把腿看成腰了 还找了半天!",
309 | "input": "",
310 | "output": "反例:细长的尾巴条纹的腿 \n\n(media)"
311 | },
312 | {
313 | "instruction": "我爱你啊我的母校\n这啥(",
314 | "input": "",
315 | "output": "texlive,可以理解为每年的tex工具合集"
316 | },
317 | {
318 | "instruction": "好作文都是些大人眼中的好作文\n大人是不会欣赏这种乱七八糟的兴趣的,孩子也不会么\n话说回来这篇可真是四不像\n叙述毫无亮点,读起来颇像蒋公日记\n议论连小学生水平都没有,可笑啊",
319 | "input": "",
320 | "output": "比较无趣的是政治挂帅,观点不重要,重要的是是不是迎合当时的主流思想"
321 | },
322 | {
323 | "instruction": "有什么新的发展吗?",
324 | "input": "",
325 | "output": "哎对呀,steam deck+Xbox Remote\n几乎除了索尼独占直接通吃了啊\n\n(老任还得自己买就是了)"
326 | },
327 | {
328 | "instruction": "当时Anandtech评测了15W和25W的1065G7,看到25W的成绩我都快哭了,大喊出了一定买\n\n结果美帝真香想直接给你一个35W的1065G7尝尝……\n小新pro是可以35w的10710,1065g7有35w tdp的版本吗?",
329 | "input": "",
330 | "output": "PL1解了想多少多少啊"
331 | },
332 | {
333 | "instruction": "现在这个社会太可怕啦\n我不能接受.gif https://twitter.com/MikaAkagi/status/1063223977282158592…",
334 | "input": "",
335 | "output": "另外你真的有鸡鸡诶!(失望"
336 | },
337 | {
338 | "instruction": "你最近有什么特别的事情发生吗?",
339 | "input": "",
340 | "output": "逐渐找回自己 \n\n(media)\n\n(media)"
341 | },
342 | {
343 | "instruction": "千 元 大 钞\n恕 不 找 续\n你要换港币么",
344 | "input": "",
345 | "output": "换啊 我各种需要rmb啊"
346 | },
347 | {
348 | "instruction": "你这段时间都在忙什么?",
349 | "input": "",
350 | "output": "【RV770背后的故事 - AMD小核心之路】The RV770 Story: Documenting ATI's Road to Success\nvia:(link), 地址:(link)\n\n// 恍如隔世,RV770都有快10年了\n// 还记得用4850挖矿的日子么?"
351 | },
352 | {
353 | "instruction": "你近期的生活如何?",
354 | "input": "",
355 | "output": "卓悦汇那个卖贡茶的男的xnmbyy,越说那个机器扫二维码扫不出来我可以直接扫,他就要按回去,他还一副理直气壮的样子,搞得我差点没跟他打起来"
356 | },
357 | {
358 | "instruction": "大家好我也是女装子\n\n#不要脸",
359 | "input": "",
360 | "output": "嘛 无所谓呀。"
361 | },
362 | {
363 | "instruction": "卧槽我想用的正则居然一遍就写对了 nice",
364 | "input": "",
365 | "output": "常用的网上都有现成的啦\n有些奇形怪状的就要自己写了"
366 | },
367 | {
368 | "instruction": "发现有些美国人也开始把“龙”念成“多拉弓”了,不是“拽根”吗?\n这个界线好像已经很模糊了:\n对,“竜”是没翅膀的。\n这俩是一个字…\n简化字\n\nhttps://en.m.wiktionary.org/wiki/竜\n不过繁体龙似乎还是人名汉字,所以和花泽香菜一样还是旧字形",
369 | "input": "",
370 | "output": "不过年轻人也开始慢慢用新字形了,叫tatsuya写作竜也的越来越多了("
371 | },
372 | {
373 | "instruction": "受不了了带着队友们往前冲,",
374 | "input": "",
375 | "output": "死了之后发现家里还有一大块没涂就先涂一下我队友似乎觉得没有我一样很牛逼就继续死冲对面那当然是要g了啊"
376 | }
377 | ]
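378 | 
379 | The records above are plain {instruction, input, output} triples, the shape that
380 | the preprocessing script loads with `datasets.load_dataset("json", ...)`; the
381 | `(media)`, `(video)` and `(link)` markers are the placeholders emitted by
382 | twitter-parser.py below. A minimal loading sketch (the filename stands in for
383 | the real data file):
384 | 
385 |     import datasets
386 | 
387 |     # the sample is a top-level JSON array of records, which the json loader accepts
388 |     ds = datasets.load_dataset("json", data_files="tweets_sample.json")["train"]
389 |     assert set(ds.column_names) == {"instruction", "input", "output"}
390 |     print(ds[0]["output"])  # the persona's reply in the first record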
--------------------------------------------------------------------------------
/twitter-parser.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | twitter-archive-parser - Python code to parse a Twitter archive and output in various ways
4 | Copyright (C) 2022 Tim Hutton - https://github.com/timhutton/twitter-archive-parser
5 |
6 | This program is free software: you can redistribute it and/or modify
7 | it under the terms of the GNU General Public License as published by
8 | the Free Software Foundation, either version 3 of the License, or
9 | (at your option) any later version.
10 |
11 | This program is distributed in the hope that it will be useful,
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | GNU General Public License for more details.
15 |
16 | You should have received a copy of the GNU General Public License
17 | along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | """
19 |
20 | from collections import defaultdict
21 |
22 | from typing import Optional
23 | from urllib.parse import urlparse
24 | import datetime
25 | import glob
26 | import importlib
27 | import json
28 | import logging
29 | import os
30 | import re
31 | import shutil
32 | import subprocess
33 | import sys
34 | import time
35 |
36 | from prompt import write_json
37 | # hot-loaded if needed, see import_module():
38 | # imagesize
39 | # requests
40 |
41 |
42 | # Print a compile-time error in Python < 3.6. This line does nothing in Python 3.6+ but is reported to the user
43 | # as an error (because it is the first line that fails to compile) in older versions.
44 | f' Error: This script requires Python 3.6 or later. Use `python --version` to check your version.'
45 |
46 |
47 | class UserData:
48 | def __init__(self, user_id: str, handle: str):
49 | if user_id is None:
50 | raise ValueError('ID "None" is not allowed in UserData.')
51 | self.user_id = user_id
52 | if handle is None:
53 | raise ValueError('handle "None" is not allowed in UserData.')
54 | self.handle = handle
55 |
56 |
57 | class PathConfig:
58 | """
59 | Helper class containing constants for various directories and files.
60 |
61 | The script will only add / change / delete content in its own directories, which start with `parser-`.
62 | Files within `parser-output` are the end result that the user is probably interested in.
63 | Files within `parser-cache` are temporary working files, which improve the efficiency if you run
64 | this script multiple times. They can safely be removed without harming the consistency of the
65 | files within `parser-output`.
66 | """
67 | def __init__(self, dir_archive):
68 | self.dir_archive = dir_archive
69 | self.dir_input_data = os.path.join(dir_archive, 'data')
70 | self.file_account_js = os.path.join(self.dir_input_data, 'account.js')
71 |
72 | # check if user is in correct folder
73 | if not os.path.isfile(self.file_account_js):
74 | print(f'Error: Failed to load {self.file_account_js}. Run this script from the root folder of your Twitter archive.')
75 | exit()
76 |
77 | self.dir_input_media = find_dir_input_media(self.dir_input_data)
78 | self.dir_output = os.path.join(self.dir_archive, 'parser-output')
79 | self.dir_output_media = os.path.join(self.dir_output, 'media')
80 | self.dir_output_cache = os.path.join(self.dir_archive, 'parser-cache')
81 | self.file_output_following = os.path.join(self.dir_output, 'following.txt')
82 | self.file_output_followers = os.path.join(self.dir_output, 'followers.txt')
83 | self.file_download_log = os.path.join(self.dir_output_media, 'download_log.txt')
84 | self.file_tweet_icon = os.path.join(self.dir_output_media, 'tweet.ico')
85 | self.files_input_tweets = find_files_input_tweets(self.dir_input_data)
86 |
87 | # structured like an actual tweet output file, can be used to compute relative urls to a media file
88 | self.example_file_output_tweets = self.create_path_for_file_output_tweets(year=2020, month=12)
89 |
90 | def create_path_for_file_output_tweets(self, year, month, format="html", kind="tweets") -> str:
91 | """Builds the path for a tweet-archive file based on some properties."""
92 | # Previously the filename was f'{dt.year}-{dt.month:02}-01-Tweet-Archive-{dt.year}-{dt.month:02}'
93 | return os.path.join(self.dir_output, f"{kind}-{format}", f"{year:04}", f"{year:04}-{month:02}-01-{kind}.{format}")
94 |
95 | def create_path_for_file_output_dms(self, name: str, index: Optional[int]=None, format: str="html", kind: str="DMs") -> str:
96 | """Builds the path for a dm-archive file based on some properties."""
97 | index_suffix = ""
98 | if (index):
99 | index_suffix = f"-part{index:03}"
100 | return os.path.join(self.dir_output, kind, f"{kind}-{name}{index_suffix}.{format}")
101 |
102 | def create_path_for_file_output_single(self, format: str, kind: str)->str:
103 | """Builds the path for a single output file which, i.e. one that is not part of a larger group or sequence."""
104 | return os.path.join(self.dir_output, f"{kind}.{format}")
105 |
106 |
107 | def get_consent(prompt: str, default_to_yes: bool = False):
108 | """Asks the user for consent, using the given prompt. Accepts various versions of yes/no, or
109 | an empty answer to accept the default. The default is 'no' unless default_to_yes is passed as
110 | True. The default will be indicated automatically. For unacceptable answers, the user will
111 | be asked again."""
112 | if default_to_yes:
113 | suffix = " [Y/n]"
114 | default_answer = "yes"
115 | else:
116 | suffix = " [y/N]"
117 | default_answer = "no"
118 |
119 | # a temporary hack to make the script work in a non-interactive environment: always answer "no"
120 | return False
121 |
122 | while True:
123 | user_input = input(prompt + suffix)
124 | if user_input == "":
125 | print (f"Your empty response was assumed to mean '{default_answer}' (the default for this question).")
126 | return default_to_yes
127 | if user_input.lower() in ('y', 'yes'):
128 | return True
129 | if user_input.lower() in ('n', 'no'):
130 | return False
131 | print (f"Sorry, did not understand. Please answer with y, n, yes, no, or press enter to accept "
132 | f"the default (which is '{default_answer}' in this case, as indicated by the uppercase "
133 | f"'{default_answer.upper()[0]}'.)")
134 |
135 |
136 | def import_module(module):
137 | """Imports a module specified by a string. Example: requests = import_module('requests')"""
138 | try:
139 | return importlib.import_module(module)
140 | except ImportError:
141 | print(f'\nError: This script uses the "{module}" module which is not installed.\n')
142 | if not get_consent('OK to install using pip?'):
143 | exit()
144 | subprocess.run([sys.executable, '-m', 'pip', 'install', module], check=True)
145 | return importlib.import_module(module)
146 |
147 |
148 | def open_and_mkdirs(path_file):
149 | """Opens a file for writing. If the parent directory does not exist yet, it is created first."""
150 | mkdirs_for_file(path_file)
151 | return open(path_file, 'w', encoding='utf-8')
152 |
153 |
154 | def mkdirs_for_file(path_file):
155 | """Creates the parent directory of the given file, if it does not exist yet."""
156 | path_dir = os.path.split(path_file)[0]
157 | os.makedirs(path_dir, exist_ok=True)
158 |
159 |
160 | def rel_url(media_path, document_path):
161 | """Computes the relative URL needed to link from `document_path` to `media_path`.
162 | Assumes that `document_path` points to a file (e.g. `.md` or `.html`), not a directory."""
163 | return os.path.relpath(media_path, os.path.split(document_path)[0]).replace("\\", "/")
164 |
165 |
166 | def get_twitter_api_guest_token(session, bearer_token):
167 | """Returns a Twitter API guest token for the current session."""
168 | guest_token_response = session.post("https://api.twitter.com/1.1/guest/activate.json",
169 | headers={'authorization': f'Bearer {bearer_token}'},
170 | timeout=2,
171 | )
172 | guest_token = json.loads(guest_token_response.content)['guest_token']
173 | if not guest_token:
174 | raise Exception(f"Failed to retrieve guest token")
175 | return guest_token
176 |
177 |
178 | def get_twitter_users(session, bearer_token, guest_token, user_ids):
179 | """Asks Twitter for all metadata associated with user_ids."""
180 | users = {}
181 | while user_ids:
182 | max_batch = 100
183 | user_id_batch = user_ids[:max_batch]
184 | user_ids = user_ids[max_batch:]
185 | user_id_list = ",".join(user_id_batch)
186 | query_url = f"https://api.twitter.com/1.1/users/lookup.json?user_id={user_id_list}"
187 | response = session.get(query_url,
188 | headers={'authorization': f'Bearer {bearer_token}', 'x-guest-token': guest_token},
189 | timeout=2,
190 | )
191 | if not response.status_code == 200:
192 | raise Exception(f'Failed to get user handle: {response}')
193 | response_json = json.loads(response.content)
194 | for user in response_json:
195 | users[user["id_str"]] = user
196 | return users
197 |
198 |
199 | def lookup_users(user_ids, users):
200 | """Fill the users dictionary with data from Twitter"""
201 | # Filter out any users already known
202 | filtered_user_ids = [id for id in user_ids if id not in users]
203 | if not filtered_user_ids:
204 | # Don't bother opening a session if there's nothing to get
205 | return
206 | # Account metadata observed at ~2.1KB on average.
207 | estimated_size = int(2.1 * len(filtered_user_ids))
208 | print(f'{len(filtered_user_ids)} users are unknown.')
209 | if not get_consent(f'Download user data from Twitter (approx {estimated_size:,} KB)?'):
210 | return
211 |
212 | requests = import_module('requests')
213 | try:
214 | with requests.Session() as session:
215 | bearer_token = 'AAAAAAAAAAAAAAAAAAAAANRILgAAAAAAnNwIzUejRCOuH5E6I8xnZz4puTs%3D1Zv7ttfk8LF81IUq16cHjhLTvJu4FA33AGWWjCpTnA'
216 | guest_token = get_twitter_api_guest_token(session, bearer_token)
217 | retrieved_users = get_twitter_users(session, bearer_token, guest_token, filtered_user_ids)
218 | for user_id, user in retrieved_users.items():
219 | if user["screen_name"] is not None:
220 | users[user_id] = UserData(user_id=user_id, handle=user["screen_name"])
221 | print() # empty line for better readability of output
222 | except Exception as err:
223 | print(f'Failed to download user data: {err}')
224 |
225 |
226 | def read_json_from_js_file(filename):
227 | """Reads the contents of a Twitter-produced .js file into a dictionary."""
228 | print(f'Parsing {filename}...')
229 | with open(filename, 'r', encoding='utf8') as f:
230 | data = f.readlines()
231 | # if the JSON has no real content, it can happen that the file is only one line long.
232 | # in this case, return an empty dict to avoid errors while trying to read non-existing lines.
233 | if len(data) <= 1:
234 | return {}
235 | # convert js file to JSON: replace first line with just '[', squash lines into a single string
236 | prefix = '['
237 | if '{' in data[0]:
238 | prefix += ' {'
239 | data = prefix + ''.join(data[1:])
240 | # parse the resulting JSON and return as a dict
241 | return json.loads(data)
242 |
243 |
244 | def extract_username(paths: PathConfig):
245 | """Returns the user's Twitter username from account.js."""
246 | account = read_json_from_js_file(paths.file_account_js)
247 | return account[0]['account']['username']
248 |
249 |
250 | def escape_markdown(input_text: str) -> str:
251 | """
252 | Escape markdown control characters from input text so that the text will not break in rendered markdown.
253 | (Only use on unformatted text parts that do not yet have any markdown control characters added on purpose!)
254 | """
255 | # disable escapes permanently
256 | return input_text
257 |
258 | characters_to_escape: str = r"\_*[]()~`>#+-=|{}.!"
259 | output_text: str = ''
260 | for char in input_text:
261 | if char in characters_to_escape:
262 | # add backslash before control char
263 | output_text = output_text + "\\" + char
264 | elif char == '\n':
265 | # add double space before line break
266 | output_text = output_text + " " + char
267 | else:
268 | output_text = output_text + char
269 | return output_text
270 |
271 | def collect_tweet_references(tweet):
272 | if 'tweet' in tweet.keys():
273 | tweet = tweet['tweet']
274 | tweet_ids = set()
275 | in_reply_to = False
276 | quote = False
277 | retweet = False
278 | # Collect quoted tweets
279 | if 'entities' in tweet and 'urls' in tweet['entities']:
280 | for url in tweet['entities']['urls']:
281 | if 'url' in url and 'expanded_url' in url:
282 | expanded_url = url['expanded_url']
283 | matches = re.match(r'^https://twitter.com/([0-9A-Za-z_]*)/status/(\d+)$', expanded_url)
284 | if (matches):
285 | quote = True
286 |
287 | # Collect previous tweets in conversation
288 | if 'in_reply_to_status_id_str' in tweet:
289 | in_reply_to = True
290 |
291 | # Collect RT retweets
292 | if 'full_text' in tweet and tweet['full_text'].startswith('RT @'):
293 | retweet = True
294 |
295 | # TODO: actually parse the referenced tweet ids into texts;
296 | # that would need three separate sets, one per reference type.
297 | return in_reply_to, quote, retweet
298 |
299 |
300 | def convert_tweet(tweet, username, media_sources, users: dict, paths: PathConfig):
301 | """Converts a JSON-format tweet. Returns tuple of timestamp, markdown and HTML."""
302 | if 'tweet' in tweet.keys():
303 | tweet = tweet['tweet']
304 |
305 | if not 'created_at' in tweet:
306 | # some tweets in the archive lack 'created_at' (seems to be an archive bug); skip them:
307 | return None, None, None, None, None, None, None
308 |
309 | timestamp_str = tweet['created_at']
310 | timestamp = int(round(datetime.datetime.strptime(timestamp_str, '%a %b %d %X %z %Y').timestamp()))
311 | # Example: Tue Mar 19 14:05:17 +0000 2019
312 | body_markdown = tweet['full_text']
313 | body_html = tweet['full_text']
314 | tweet_id_str = tweet['id_str']
315 |
316 | # for old tweets before embedded t.co redirects were added, ensure the links are
317 | # added to the urls entities list so that we can build correct links later on.
318 | if 'entities' in tweet and 'media' not in tweet['entities'] and len(tweet['entities'].get("urls", [])) == 0:
319 | for word in tweet['full_text'].split():
320 | try:
321 | url = urlparse(word)
322 | except ValueError:
323 | pass # don't crash when trying to parse something that looks like a URL but actually isn't
324 | else:
325 | if url.scheme != '' and url.netloc != '' and not word.endswith('\u2026'):
326 | # Shorten links similar to twitter
327 | netloc_short = url.netloc[4:] if url.netloc.startswith("www.") else url.netloc
328 | path_short = url.path if len(url.path + '?' + url.query) < 15 \
329 | else (url.path + '?' + url.query)[:15] + '\u2026'
330 | tweet['entities']['urls'].append({
331 | 'url': word,
332 | 'expanded_url': word,
333 | 'display_url': netloc_short + path_short,
334 | 'indices': [tweet['full_text'].index(word), tweet['full_text'].index(word) + len(word)],
335 | })
336 | # replace t.co URLs with their original versions
337 | if 'entities' in tweet and 'urls' in tweet['entities']:
338 | for url in tweet['entities']['urls']:
339 | if 'url' in url and 'expanded_url' in url:
340 | expanded_url = url['expanded_url']
341 | #body_markdown = body_markdown.replace(url['url'], expanded_url)
342 | body_markdown = body_markdown.replace(url['url'], '(link)')
343 | expanded_url_html = f'<a href="{expanded_url}">{expanded_url}</a>'
344 | body_html = body_html.replace(url['url'], expanded_url_html)
345 | # if the tweet is a reply, construct a header that links the names
346 | # of the accounts being replied to the tweet being replied to
347 | header_markdown = ''
348 | header_html = ''
349 | if 'in_reply_to_status_id' in tweet:
350 | # match and remove all occurrences of '@username ' at the start of the body
351 | replying_to = re.match(r'^(@[0-9A-Za-z_]* )*', body_markdown)[0]
352 | if replying_to:
353 | body_markdown = body_markdown[len(replying_to):]
354 | body_html = body_html[len(replying_to):]
355 | else:
356 | # no '@username ' in the body: we're replying to self
357 | replying_to = f'@{username}'
358 | names = replying_to.split()
359 | # some old tweets lack 'in_reply_to_screen_name': use it if present, otherwise fall back to names[0]
360 | in_reply_to_screen_name = tweet['in_reply_to_screen_name'] if 'in_reply_to_screen_name' in tweet else names[0]
361 | # create a list of names of the form '@name1, @name2 and @name3' - or just '@name1' if there is only one name
362 | name_list = ', '.join(names[:-1]) + (f' and {names[-1]}' if len(names) > 1 else names[0])
363 | in_reply_to_status_id = tweet['in_reply_to_status_id']
364 | replying_to_url = f'https://twitter.com/{in_reply_to_screen_name}/status/{in_reply_to_status_id}'
365 | header_markdown += f'Replying to [{escape_markdown(name_list)}]({replying_to_url})\n\n'
366 | header_html += f'Replying to <a href="{replying_to_url}">{name_list}</a><br>'
367 | # escape tweet body for markdown rendering:
368 | # no need to escape here as it will be tokenized again.
369 | body_markdown = escape_markdown(body_markdown)
370 | #body_markdown = body_markdown.strip().replace("\n", r"\n")
371 | # replace image URLs with image links to local files
372 | if 'entities' in tweet and 'media' in tweet['entities'] and 'extended_entities' in tweet \
373 | and 'media' in tweet['extended_entities']:
374 | original_url = tweet['entities']['media'][0]['url']
375 | markdown = ''
376 | html = ''
377 | for media in tweet['extended_entities']['media']:
378 | if 'url' in media and 'media_url' in media:
379 | original_expanded_url = media['media_url']
380 | original_filename = os.path.split(original_expanded_url)[1]
381 | archive_media_filename = tweet_id_str + '-' + original_filename
382 | archive_media_path = os.path.join(paths.dir_input_media, archive_media_filename)
383 | file_output_media = os.path.join(paths.dir_output_media, archive_media_filename)
384 | media_url = rel_url(file_output_media, paths.example_file_output_tweets)
385 | markdown += '' if not markdown and body_markdown == escape_markdown(original_url) else '\n\n'
386 | #markdown += '' if not markdown and body_markdown == original_url else '\n\n'
387 | html += '' if not html and body_html == original_url else '<br>'
388 | if os.path.isfile(archive_media_path):
389 | # Found a matching image, use this one
390 | if not os.path.isfile(file_output_media):
391 | shutil.copy(archive_media_path, file_output_media)
392 | #markdown += f'![]({media_url})'
393 | # ignore media at this moment
394 | markdown += '(media)'
395 | html += f'<img src="{media_url}"/>'
396 | # Save the online location of the best-quality version of this file, for later upgrading if wanted
397 | best_quality_url = f'https://pbs.twimg.com/media/{original_filename}:orig'
398 | media_sources.append(
399 | (os.path.join(paths.dir_output_media, archive_media_filename), best_quality_url)
400 | )
401 | else:
402 | # Is there any other file that includes the tweet_id in its filename?
403 | archive_media_paths = glob.glob(os.path.join(paths.dir_input_media, tweet_id_str + '*'))
404 | if len(archive_media_paths) > 0:
405 | for archive_media_path in archive_media_paths:
406 | archive_media_filename = os.path.split(archive_media_path)[-1]
407 | file_output_media = os.path.join(paths.dir_output_media, archive_media_filename)
408 | media_url = rel_url(file_output_media, paths.example_file_output_tweets)
409 | if not os.path.isfile(file_output_media):
410 | shutil.copy(archive_media_path, file_output_media)
411 | #markdown += f'<video controls><source src="{media_url}">' \
412 | #            f'Your browser does not support the video tag.</video>\n'
413 | markdown += '(video)'
414 | html += f'<video controls><source src="{media_url}">' \
415 | f'Your browser does not support the video tag.</video>\n'
416 | # Save the online location of the best-quality version of this file,
417 | # for later upgrading if wanted
418 | if 'video_info' in media and 'variants' in media['video_info']:
419 | best_quality_url = ''
420 | best_bitrate = -1 # some valid videos are marked with bitrate=0 in the JSON
421 | for variant in media['video_info']['variants']:
422 | if 'bitrate' in variant:
423 | bitrate = int(variant['bitrate'])
424 | if bitrate > best_bitrate:
425 | best_quality_url = variant['url']
426 | best_bitrate = bitrate
427 | if best_bitrate == -1:
428 | print(f"Warning No URL found for {original_url} {original_expanded_url} "
429 | f"{archive_media_path} {media_url}")
430 | print(f"JSON: {tweet}")
431 | else:
432 | media_sources.append(
433 | (os.path.join(paths.dir_output_media, archive_media_filename),
434 | best_quality_url)
435 | )
436 | else:
437 | print(f'Warning: missing local file: {archive_media_path}. Using original link instead: '
438 | f'{original_url} (expands to {original_expanded_url})')
439 | #markdown += f'![]({original_url})'
440 | markdown += "(media)"
441 | html += f'<a href="{original_url}">{original_url}</a>'
442 | body_markdown = body_markdown.replace(escape_markdown(original_url), markdown)
443 | #body_markdown = body_markdown.replace(original_url, markdown)
444 | body_html = body_html.replace(original_url, html)
445 | # make only the html body a quote; the markdown body stays plain for the corpus
446 | body_markdown = body_markdown  # intentional no-op: markdown is not quoted
447 | body_html = '<blockquote>' + '<br>\n'.join(body_html.splitlines()) + '</blockquote>'
448 | # append the original Twitter URL as a link
449 | original_tweet_url = f'https://twitter.com/{username}/status/{tweet_id_str}'
450 | icon_url = rel_url(paths.file_tweet_icon, paths.example_file_output_tweets)
451 |
452 | # we do not like header because it's slightly confusing
453 | # TODO: consider if we do need images in this markdown
454 | # timestamp is not needed as it's for corpus
455 | body_markdown = body_markdown #+ f'\n\n<img src="{icon_url}" width="12" /> ' \
456 | #f'[{timestamp_str}]({original_tweet_url})'
457 | body_html = header_html + body_html + f'<a href="{original_tweet_url}"><img src="{icon_url}" ' \
458 | f'width="12" />&nbsp;{timestamp_str}</a>'
459 | # extract user_id:handle connections
460 | if 'in_reply_to_user_id' in tweet and 'in_reply_to_screen_name' in tweet and \
461 | tweet['in_reply_to_screen_name'] is not None:
462 | reply_to_id = tweet['in_reply_to_user_id']
463 | if int(reply_to_id) >= 0: # some ids are -1, not sure why
464 | handle = tweet['in_reply_to_screen_name']
465 | users[reply_to_id] = UserData(user_id=reply_to_id, handle=handle)
466 | if 'entities' in tweet and 'user_mentions' in tweet['entities'] and tweet['entities']['user_mentions'] is not None:
467 | for mention in tweet['entities']['user_mentions']:
468 | if mention is not None and 'id' in mention and 'screen_name' in mention:
469 | mentioned_id = mention['id']
470 | if int(mentioned_id) >= 0: # some ids are -1, not sure why
471 | handle = mention['screen_name']
472 | if handle is not None:
473 | users[mentioned_id] = UserData(user_id=mentioned_id, handle=handle)
474 |
475 | # extract the type of tweet
476 | in_reply_to, quote, retweet = collect_tweet_references(tweet)
477 |
478 | return timestamp, tweet_id_str, body_markdown, body_html, in_reply_to, quote, retweet
479 |
480 |
481 | def find_files_input_tweets(dir_path_input_data):
482 | """Identify the tweet archive's file and folder names -
483 | they change slightly depending on the archive size it seems."""
484 | input_tweets_file_templates = ['tweet.js', 'tweets.js', 'tweets-part*.js']
485 | files_paths_input_tweets = []
486 | for input_tweets_file_template in input_tweets_file_templates:
487 | files_paths_input_tweets += glob.glob(os.path.join(dir_path_input_data, input_tweets_file_template))
488 | if len(files_paths_input_tweets)==0:
489 | print(f'Error: no files matching {input_tweets_file_templates} in {dir_path_input_data}')
490 | exit()
491 | return files_paths_input_tweets
492 |
493 |
494 | def find_dir_input_media(dir_path_input_data):
495 | input_media_dir_templates = ['tweet_media', 'tweets_media']
496 | input_media_dirs = []
497 | for input_media_dir_template in input_media_dir_templates:
498 | input_media_dirs += glob.glob(os.path.join(dir_path_input_data, input_media_dir_template))
499 | if len(input_media_dirs) == 0:
500 | print(f'Error: no folders matching {input_media_dir_templates} in {dir_path_input_data}')
501 | exit()
502 | if len(input_media_dirs) > 1:
503 | print(f'Error: multiple folders matching {input_media_dir_templates} in {dir_path_input_data}')
504 | exit()
505 | return input_media_dirs[0]
506 |
507 |
508 | def download_file_if_larger(url, filename, index, count, sleep_time):
509 | """Attempts to download from the specified URL. Overwrites file if larger.
510 | Returns whether the file is now known to be the largest available, and the number of bytes downloaded.
511 | """
512 | requests = import_module('requests')
513 | imagesize = import_module('imagesize')
514 |
515 | pref = f'{index:3d}/{count:3d} {filename}: '
516 | # Sleep briefly, in an attempt to minimize the possibility of triggering some auto-cutoff mechanism
517 | if index > 1:
518 | print(f'{pref}Sleeping...', end='\r')
519 | time.sleep(sleep_time)
520 | # Request the URL (in stream mode so that we can conditionally abort depending on the headers)
521 | print(f'{pref}Requesting headers for {url}...', end='\r')
522 | byte_size_before = os.path.getsize(filename)
523 | try:
524 | with requests.get(url, stream=True, timeout=2) as res:
525 | if not res.status_code == 200:
526 | # Try to get content of response as `res.text`.
527 | # For twitter.com, this will be empty in most (all?) cases.
528 | # It is successfully tested with error responses from other domains.
529 | raise Exception(f'Download failed with status "{res.status_code} {res.reason}". '
530 | f'Response content: "{res.text}"')
531 | byte_size_after = int(res.headers['content-length'])
532 | if byte_size_after != byte_size_before:
533 | # Proceed with the full download
534 | tmp_filename = filename+'.tmp'
535 | print(f'{pref}Downloading {url}... ', end='\r')
536 | with open(tmp_filename,'wb') as f:
537 | shutil.copyfileobj(res.raw, f)
538 | post = f'{byte_size_after/2**20:.1f}MB downloaded'
539 | width_before, height_before = imagesize.get(filename)
540 | width_after, height_after = imagesize.get(tmp_filename)
541 | pixels_before, pixels_after = width_before * height_before, width_after * height_after
542 | pixels_percentage_increase = 100.0 * (pixels_after - pixels_before) / pixels_before
543 |
544 | if width_before == -1 and height_before == -1 and width_after == -1 and height_after == -1:
545 | # could not check size of both versions, probably a video or unsupported image format
546 | os.replace(tmp_filename, filename)
547 | bytes_percentage_increase = 100.0 * (byte_size_after - byte_size_before) / byte_size_before
548 | logging.info(f'{pref}SUCCESS. New version is {bytes_percentage_increase:3.0f}% '
549 | f'larger in bytes (pixel comparison not possible). {post}')
550 | return True, byte_size_after
551 | elif width_before == -1 or height_before == -1 or width_after == -1 or height_after == -1:
552 | # could not check size of one version, this should not happen (corrupted download?)
553 | logging.info(f'{pref}SKIPPED. Pixel size comparison inconclusive: '
554 | f'{width_before}*{height_before}px vs. {width_after}*{height_after}px. {post}')
555 | return False, byte_size_after
556 | elif pixels_after >= pixels_before:
557 | os.replace(tmp_filename, filename)
558 | bytes_percentage_increase = 100.0 * (byte_size_after - byte_size_before) / byte_size_before
559 | if bytes_percentage_increase >= 0:
560 | logging.info(f'{pref}SUCCESS. New version is {bytes_percentage_increase:3.0f}% larger in bytes '
561 | f'and {pixels_percentage_increase:3.0f}% larger in pixels. {post}')
562 | else:
563 | logging.info(f'{pref}SUCCESS. New version is actually {-bytes_percentage_increase:3.0f}% '
564 | f'smaller in bytes but {pixels_percentage_increase:3.0f}% '
565 | f'larger in pixels. {post}')
566 | return True, byte_size_after
567 | else:
568 | logging.info(f'{pref}SKIPPED. Online version has {-pixels_percentage_increase:3.0f}% '
569 | f'smaller pixel size. {post}')
570 | return True, byte_size_after
571 | else:
572 | logging.info(f'{pref}SKIPPED. Online version is same byte size, assuming same content. Not downloaded.')
573 | return True, 0
574 | except Exception as err:
575 | logging.error(f"{pref}FAIL. Media couldn't be retrieved from {url} because of exception: {err}")
576 | return False, 0
577 |
578 |
579 | def download_larger_media(media_sources, paths: PathConfig):
580 | """Uses (filename, URL) tuples in media_sources to download files from remote storage.
581 | Aborts downloads if the remote file is the same size or smaller than the existing local version.
582 | Retries the failed downloads several times, with increasing pauses between each to avoid being blocked.
583 | """
584 | # Log to file as well as the console
585 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(message)s')
586 | mkdirs_for_file(paths.file_download_log)
587 | logfile_handler = logging.FileHandler(filename=paths.file_download_log, mode='w')
588 | logfile_handler.setLevel(logging.INFO)
589 | logging.getLogger().addHandler(logfile_handler)
590 | # Download new versions
591 | start_time = time.time()
592 | total_bytes_downloaded = 0
593 | sleep_time = 0.25
594 | remaining_tries = 5
595 | while remaining_tries > 0:
596 | number_of_files = len(media_sources)
597 | success_count = 0
598 | retries = []
599 | for index, (local_media_path, media_url) in enumerate(media_sources):
600 | success, bytes_downloaded = download_file_if_larger(
601 | media_url, local_media_path, index + 1, number_of_files, sleep_time
602 | )
603 | if success:
604 | success_count += 1
605 | else:
606 | retries.append((local_media_path, media_url))
607 | total_bytes_downloaded += bytes_downloaded
608 |
609 | # show % done and estimated remaining time:
610 | time_elapsed: float = time.time() - start_time
611 | estimated_time_per_file: float = time_elapsed / (index + 1)
612 | estimated_time_remaining: datetime.datetime = \
613 | datetime.datetime.fromtimestamp(
614 | (number_of_files - (index + 1)) * estimated_time_per_file,
615 | tz=datetime.timezone.utc
616 | )
617 | if estimated_time_remaining.hour >= 1:
618 | time_remaining_string: str = \
619 | f"{estimated_time_remaining.hour} hour{'' if estimated_time_remaining.hour == 1 else 's'} " \
620 | f"{estimated_time_remaining.minute} minute{'' if estimated_time_remaining.minute == 1 else 's'}"
621 | elif estimated_time_remaining.minute >= 1:
622 | time_remaining_string: str = \
623 | f"{estimated_time_remaining.minute} minute{'' if estimated_time_remaining.minute == 1 else 's'} " \
624 | f"{estimated_time_remaining.second} second{'' if estimated_time_remaining.second == 1 else 's'}"
625 | else:
626 | time_remaining_string: str = \
627 | f"{estimated_time_remaining.second} second{'' if estimated_time_remaining.second == 1 else 's'}"
628 |
629 | if index + 1 == number_of_files:
630 | print(' 100 % done.')
631 | else:
632 | print(f' {(100*(index+1)/number_of_files):.1f} % done, about {time_remaining_string} remaining...')
633 |
634 | media_sources = retries
635 | remaining_tries -= 1
636 | sleep_time += 2
637 | logging.info(f'\n{success_count} of {number_of_files} tested media files '
638 | f'are known to be the best-quality available.\n')
639 | if len(retries) == 0:
640 | break
641 | if remaining_tries > 0:
642 | print(f'----------------------\n\nRetrying the ones that failed, with a longer sleep. '
643 | f'{remaining_tries} tries remaining.\n')
644 | end_time = time.time()
645 |
646 | logging.info(f'Total downloaded: {total_bytes_downloaded/2**20:.1f}MB = {total_bytes_downloaded/2**30:.2f}GB')
647 | logging.info(f'Time taken: {end_time-start_time:.0f}s')
648 | print(f'Wrote log to {paths.file_download_log}')
649 |
650 |
651 | def parse_tweets(username, users, html_template, paths: PathConfig, lang):
652 | """Read tweets from paths.files_input_tweets, write to *.md and *.html.
653 | Copy the media used to paths.dir_output_media.
654 | Collect user_id:user_handle mappings for later use, in 'users'.
655 | Returns the mapping from media filename to best-quality URL.
656 | """
657 | tweets = []
658 | media_sources = []
659 | for tweets_js_filename in paths.files_input_tweets:
660 | jsons = read_json_from_js_file(tweets_js_filename)
661 | for tweet in jsons:
662 | result = convert_tweet(tweet, username, media_sources, users, paths)
663 | if result[0]:
664 | # check valid tweet by timestamp
665 | tweets.append(result)
666 | tweets.sort(key=lambda tup: tup[0]) # oldest first
667 |
668 | # Group tweets by month
669 | grouped_tweets = defaultdict(list)
670 | for timestamp, ids, md, html, in_reply_to, quote, retweet in tweets:
671 | # Use a (markdown) filename that can be imported into Jekyll: YYYY-MM-DD-your-title-here.md
672 | dt = datetime.datetime.fromtimestamp(timestamp)
673 | grouped_tweets[(dt.year, dt.month)].append((ids, md, in_reply_to, quote, retweet))
674 |
675 | final_md = []
676 | for (year, month), content in grouped_tweets.items():
677 | # Write into *.md files
678 | for id, md, in_reply_to, quote, retweet in content:
679 | final_md.append((id, md, in_reply_to, quote, retweet))
680 | #md_path = paths.create_path_for_file_output_tweets(year, month, format="md")
681 | md_path = "tweets.md"
682 |
683 | write_json(md_path, final_md, lang)
684 |
685 |
686 | # Writing *.html files is skipped:
687 | # we only output markdown for the corpus
688 |
689 | print(f'Wrote {len(tweets)} tweets to tweets.md; media was copied to '
690 | f'{paths.dir_output_media} and embeds were replaced by placeholders')
691 |
692 | return media_sources
693 |
694 |
695 | def collect_user_ids_from_followings(paths) -> list:
696 | """
697 | Collect all user ids that appear in the followings archive data.
698 | (For use in bulk online lookup from Twitter.)
699 | """
700 | # read JSON file from archive
701 | following_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'following.js'))
702 | # collect all user ids in a list
703 | following_ids = []
704 | for follow in following_json:
705 | if 'following' in follow and 'accountId' in follow['following']:
706 | following_ids.append(follow['following']['accountId'])
707 | return following_ids
708 |
709 |
710 | def parse_followings(users, user_id_url_template, paths: PathConfig):
711 | """Parse paths.dir_input_data/following.js, write to paths.file_output_following.
712 | """
713 | following = []
714 | following_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'following.js'))
715 | following_ids = []
716 | for follow in following_json:
717 | if 'following' in follow and 'accountId' in follow['following']:
718 | following_ids.append(follow['following']['accountId'])
719 | for following_id in following_ids:
720 | handle = users[following_id].handle if following_id in users else '~unknown~handle~'
721 | following.append(handle + ' ' + user_id_url_template.format(following_id))
722 | following.sort()
723 | following_output_path = paths.create_path_for_file_output_single(format="txt", kind="following")
724 | with open_and_mkdirs(following_output_path) as f:
725 | f.write('\n'.join(following))
726 | print(f"Wrote {len(following)} accounts to {following_output_path}")
727 |
728 |
729 | def collect_user_ids_from_followers(paths) -> list:
730 | """
731 | Collect all user ids that appear in the followers archive data.
732 | (For use in bulk online lookup from Twitter.)
733 | """
734 | # read JSON file from archive
735 | follower_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'follower.js'))
736 | # collect all user ids in a list
737 | follower_ids = []
738 | for follower in follower_json:
739 | if 'follower' in follower and 'accountId' in follower['follower']:
740 | follower_ids.append(follower['follower']['accountId'])
741 | return follower_ids
742 |
743 |
744 | def parse_followers(users, user_id_url_template, paths: PathConfig):
745 | """Parse paths.dir_input_data/followers.js, write to paths.file_output_followers.
746 | """
747 | followers = []
748 | follower_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'follower.js'))
749 | follower_ids = []
750 | for follower in follower_json:
751 | if 'follower' in follower and 'accountId' in follower['follower']:
752 | follower_ids.append(follower['follower']['accountId'])
753 | for follower_id in follower_ids:
754 | handle = users[follower_id].handle if follower_id in users else '~unknown~handle~'
755 | followers.append(handle + ' ' + user_id_url_template.format(follower_id))
756 | followers.sort()
757 | followers_output_path = paths.create_path_for_file_output_single(format="txt", kind="followers")
758 | with open_and_mkdirs(followers_output_path) as f:
759 | f.write('\n'.join(followers))
760 | print(f"Wrote {len(followers)} accounts to {followers_output_path}")
761 |
762 |
763 | def chunks(lst: list, n: int):
764 | """Yield successive n-sized chunks from lst."""
765 | for i in range(0, len(lst), n):
766 | yield lst[i:i + n]
767 |
768 |
769 | def collect_user_ids_from_direct_messages(paths) -> list:
770 | """
771 | Collect all user ids that appear in the direct messages archive data.
772 | (For use in bulk online lookup from Twitter.)
773 | """
774 | # read JSON file from archive
775 | dms_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'direct-messages.js'))
776 | # collect all user ids in a set
777 | dms_user_ids = set()
778 | for conversation in dms_json:
779 | if 'dmConversation' in conversation and 'conversationId' in conversation['dmConversation']:
780 | dm_conversation = conversation['dmConversation']
781 | conversation_id = dm_conversation['conversationId']
782 | user1_id, user2_id = conversation_id.split('-')
783 | dms_user_ids.add(user1_id)
784 | dms_user_ids.add(user2_id)
785 | return list(dms_user_ids)
786 |
787 |
788 | def parse_direct_messages(username, users, user_id_url_template, paths: PathConfig):
789 | """Parse paths.dir_input_data/direct-messages.js, write to one markdown file per conversation.
790 | """
791 | # read JSON file
792 | dms_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'direct-messages.js'))
793 |
794 | # Parse the DMs and store the messages in a dict
795 | conversations_messages = defaultdict(list)
796 | for conversation in dms_json:
797 | if 'dmConversation' in conversation and 'conversationId' in conversation['dmConversation']:
798 | dm_conversation = conversation['dmConversation']
799 | conversation_id = dm_conversation['conversationId']
800 | user1_id, user2_id = conversation_id.split('-')
801 | messages = []
802 | if 'messages' in dm_conversation:
803 | for message in dm_conversation['messages']:
804 | if 'messageCreate' in message:
805 | message_create = message['messageCreate']
806 | if all(tag in message_create for tag in ['senderId', 'recipientId', 'text', 'createdAt']):
807 | from_id = message_create['senderId']
808 | to_id = message_create['recipientId']
809 | body = message_create['text']
810 | # replace t.co URLs with their original versions
811 | if 'urls' in message_create and len(message_create['urls']) > 0:
812 | for url in message_create['urls']:
813 | if 'url' in url and 'expanded' in url:
814 | expanded_url = url['expanded']
815 | body = body.replace(url['url'], expanded_url)
816 | # escape message body for markdown rendering:
817 | body_markdown = escape_markdown(body)
818 | # replace image URLs with image links to local files
819 | if 'mediaUrls' in message_create \
820 | and len(message_create['mediaUrls']) == 1 \
821 | and 'urls' in message_create:
822 | original_expanded_url = message_create['urls'][0]['expanded']
823 | message_id = message_create['id']
824 | media_hash_and_type = message_create['mediaUrls'][0].split('/')[-1]
825 | media_id = message_create['mediaUrls'][0].split('/')[-2]
826 | archive_media_filename = f'{message_id}-{media_hash_and_type}'
827 | new_url = os.path.join(paths.dir_output_media, archive_media_filename)
828 | archive_media_path = \
829 | os.path.join(paths.dir_input_data, 'direct_messages_media', archive_media_filename)
830 | if os.path.isfile(archive_media_path):
831 | # found a matching image, use this one
832 | if not os.path.isfile(new_url):
833 | shutil.copy(archive_media_path, new_url)
834 | image_markdown = f'\n![]({new_url})\n'
835 | body_markdown = body_markdown.replace(
836 | escape_markdown(original_expanded_url), image_markdown
837 | )
838 |
839 | # Save the online location of the best-quality version of this file,
840 | # for later upgrading if wanted
841 | best_quality_url = \
842 | f'https://ton.twitter.com/i//ton/data/dm/' \
843 | f'{message_id}/{media_id}/{media_hash_and_type}'
844 | # there is no ':orig' here, the url without any suffix has the original size
845 |
846 | # TODO: a cookie (and a 'Referer: https://twitter.com' header)
847 | # is needed to retrieve it, so the url might be useless anyway...
848 |
849 | # WARNING: Do not uncomment the statement below until the cookie problem is solved!
850 | # media_sources.append(
851 | # (
852 | # os.path.join(output_media_folder_name, archive_media_filename),
853 | # best_quality_url
854 | # )
855 | # )
856 |
857 | else:
858 | archive_media_paths = glob.glob(
859 | os.path.join(paths.dir_input_data, 'direct_messages_media', message_id + '*'))
860 | if len(archive_media_paths) > 0:
861 | for archive_media_path in archive_media_paths:
862 | archive_media_filename = os.path.split(archive_media_path)[-1]
863 | media_url = os.path.join(paths.dir_output_media, archive_media_filename)
864 | if not os.path.isfile(media_url):
865 | shutil.copy(archive_media_path, media_url)
866 | video_markdown = f'\n<video controls><source src="{media_url}">' \
867 | f'Your browser does not support the video tag.</video>\n'
868 | body_markdown = body_markdown.replace(
869 | escape_markdown(original_expanded_url), video_markdown
870 | )
871 |
872 | # TODO: maybe also save the online location of the best-quality version for videos?
873 | # (see above)
874 |
875 | else:
876 | print(f'Warning: missing local file: {archive_media_path}. '
877 | f'Using original link instead: {original_expanded_url}')
878 |
879 | created_at = message_create['createdAt'] # example: 2022-01-27T15:58:52.744Z
880 | timestamp = \
881 | int(round(datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp()))
882 |
883 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \
884 | else user_id_url_template.format(from_id)
885 | to_handle = escape_markdown(users[to_id].handle) if to_id in users \
886 | else user_id_url_template.format(to_id)
887 |
888 | # make the body a quote
889 | body_markdown = '> ' + '\n> '.join(body_markdown.splitlines())
890 | message_markdown = f'{from_handle} -> {to_handle}: ({created_at}) \n\n' \
891 | f'{body_markdown}'
892 | messages.append((timestamp, message_markdown))
893 |
894 | # find identifier for the conversation
895 | other_user_id = user2_id if (user1_id in users and users[user1_id].handle == username) else user1_id
896 |
897 | # collect messages per identifying user in conversations_messages dict
898 | conversations_messages[other_user_id].extend(messages)
899 |
900 | # output as one file per conversation (or part of long conversation)
901 | num_written_messages = 0
902 | num_written_files = 0
903 | for other_user_id, messages in conversations_messages.items():
904 | # sort messages by timestamp
905 | messages.sort(key=lambda tup: tup[0])
906 |
907 | other_user_name = escape_markdown(users[other_user_id].handle) if other_user_id in users \
908 | else user_id_url_template.format(other_user_id)
909 |
910 | other_user_short_name: str = users[other_user_id].handle if other_user_id in users else other_user_id
911 |
912 | escaped_username = escape_markdown(username)
913 |
914 | # if there are more than 1000 messages, the conversation was split up in the twitter archive.
915 | # following this standard, also split up longer conversations in the output files:
916 |
917 | if len(messages) > 1000:
918 | for chunk_index, chunk in enumerate(chunks(messages, 1000)):
919 | markdown = ''
920 | markdown += f'### Conversation between {escaped_username} and {other_user_name}, ' \
921 | f'part {chunk_index+1}: ###\n\n----\n\n'
922 | markdown += '\n\n----\n\n'.join(md for _, md in chunk)
923 | conversation_output_path = paths.create_path_for_file_output_dms(name=other_user_short_name, index=(chunk_index + 1), format="md")
924 |
925 | # write part to a markdown file
926 | with open_and_mkdirs(conversation_output_path) as f:
927 | f.write(markdown)
928 | print(f'Wrote {len(chunk)} messages to {conversation_output_path}')
929 | num_written_files += 1
930 |
931 | else:
932 | markdown = ''
933 | markdown += f'### Conversation between {escaped_username} and {other_user_name}: ###\n\n----\n\n'
934 | markdown += '\n\n----\n\n'.join(md for _, md in messages)
935 | conversation_output_path = paths.create_path_for_file_output_dms(name=other_user_short_name, format="md")
936 |
937 | with open_and_mkdirs(conversation_output_path) as f:
938 | f.write(markdown)
939 | print(f'Wrote {len(messages)} messages to {conversation_output_path}')
940 | num_written_files += 1
941 |
942 | num_written_messages += len(messages)
943 |
944 | print(f"\nWrote {len(conversations_messages)} direct message conversations "
945 | f"({num_written_messages} total messages) to {num_written_files} markdown files\n")
946 |
947 |
948 | def make_conversation_name_safe_for_filename(conversation_name: str) -> str:
949 | """
950 | Remove/replace characters that could be unsafe in filenames
951 | """
952 | forbidden_chars = \
953 | ['"', "'", '*', '/', '\\', ':', '<', '>', '?', '|', '!', '@', ';', ',', '=', '.', '\n', '\r', '\t']
954 | new_conversation_name = ''
955 | for char in conversation_name:
956 | if char in forbidden_chars:
957 | new_conversation_name = new_conversation_name + '_'
958 | elif char.isspace():
959 | # replace spaces with underscores
960 | new_conversation_name = new_conversation_name + '_'
961 | elif ord(char) == 0x7F or (0x1F >= ord(char) >= 0x00):
962 | # 0x00 - 0x1F and 0x7F are also forbidden, just discard them
963 | continue
964 | else:
965 | new_conversation_name = new_conversation_name + char
966 |
967 | return new_conversation_name
968 |
969 |
970 | def find_group_dm_conversation_participant_ids(conversation: dict) -> set:
971 | """
972 | Find IDs of all participating Users in a group direct message conversation
973 | """
974 | group_user_ids = set()
975 | if 'dmConversation' in conversation and 'conversationId' in conversation['dmConversation']:
976 | dm_conversation = conversation['dmConversation']
977 | if 'messages' in dm_conversation:
978 | for message in dm_conversation['messages']:
979 | if 'messageCreate' in message:
980 | group_user_ids.add(message['messageCreate']['senderId'])
981 | elif 'joinConversation' in message:
982 | group_user_ids.add(message['joinConversation']['initiatingUserId'])
983 | for participant_id in message['joinConversation']['participantsSnapshot']:
984 | group_user_ids.add(participant_id)
985 | elif "participantsJoin" in message:
986 | group_user_ids.add(message['participantsJoin']['initiatingUserId'])
987 | for participant_id in message['participantsJoin']['userIds']:
988 | group_user_ids.add(participant_id)
989 | return group_user_ids
990 |
991 |
992 | def collect_user_ids_from_group_direct_messages(paths) -> list:
993 | """
994 | Collect all user ids that appear in the group direct messages archive data.
995 | (For use in bulk online lookup from Twitter.)
996 | """
997 | # read JSON file from archive
998 | group_dms_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'direct-messages-group.js'))
999 | # collect all user ids in a set
1000 | group_dms_user_ids = set()
1001 | for conversation in group_dms_json:
1002 | participants = find_group_dm_conversation_participant_ids(conversation)
1003 | for participant_id in participants:
1004 | group_dms_user_ids.add(participant_id)
1005 | return list(group_dms_user_ids)
1006 |
1007 |
1008 | def parse_group_direct_messages(username, users, user_id_url_template, paths):
1009 | """Parse data_folder/direct-messages-group.js, write to one markdown file per conversation.
1010 | """
1011 | # read JSON file from archive
1012 | group_dms_json = read_json_from_js_file(os.path.join(paths.dir_input_data, 'direct-messages-group.js'))
1013 |
1014 | # Parse the group DMs, store messages and metadata in a dict
1015 | group_conversations_messages = defaultdict(list)
1016 | group_conversations_metadata = defaultdict(dict)
1017 | for conversation in group_dms_json:
1018 | if 'dmConversation' in conversation and 'conversationId' in conversation['dmConversation']:
1019 | dm_conversation = conversation['dmConversation']
1020 | conversation_id = dm_conversation['conversationId']
1021 | participants = find_group_dm_conversation_participant_ids(conversation)
1022 | participant_names = []
1023 | for participant_id in participants:
1024 | if participant_id in users:
1025 | participant_names.append(users[participant_id].handle)
1026 | else:
1027 | participant_names.append(user_id_url_template.format(participant_id))
1028 |
1029 | # save names in metadata
1030 | group_conversations_metadata[conversation_id]['participants'] = participants
1031 | group_conversations_metadata[conversation_id]['participant_names'] = participant_names
1032 | group_conversations_metadata[conversation_id]['conversation_names'] = [(0, conversation_id)]
1033 | group_conversations_metadata[conversation_id]['participant_message_count'] = defaultdict(int)
1034 | for participant_id in participants:
1035 | # init every participant's message count with 0, so that users with no activity are not ignored
1036 | group_conversations_metadata[conversation_id]['participant_message_count'][participant_id] = 0
1037 | messages = []
1038 | if 'messages' in dm_conversation:
1039 | for message in dm_conversation['messages']:
1040 | if 'messageCreate' in message:
1041 | message_create = message['messageCreate']
1042 | if all(tag in message_create for tag in ['senderId', 'text', 'createdAt']):
1043 | from_id = message_create['senderId']
1044 | # count how many messages this user has sent to the group
1045 | group_conversations_metadata[conversation_id]['participant_message_count'][from_id] += 1
1046 | body = message_create['text']
1047 | # replace t.co URLs with their original versions
1048 | if 'urls' in message_create:
1049 | for url in message_create['urls']:
1050 | if 'url' in url and 'expanded' in url:
1051 | expanded_url = url['expanded']
1052 | body = body.replace(url['url'], expanded_url)
1053 | # escape message body for markdown rendering:
1054 | body_markdown = escape_markdown(body)
1055 | # replace image URLs with image links to local files
1056 | if 'mediaUrls' in message_create \
1057 | and len(message_create['mediaUrls']) == 1 \
1058 | and 'urls' in message_create:
1059 | original_expanded_url = message_create['urls'][0]['expanded']
1060 | message_id = message_create['id']
1061 | media_hash_and_type = message_create['mediaUrls'][0].split('/')[-1]
1062 | media_id = message_create['mediaUrls'][0].split('/')[-2]
1063 | archive_media_filename = f'{message_id}-{media_hash_and_type}'
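                                    # e.g. (made-up ids) a message id '123' and a media url
                                    # ending in '.../456/pic.jpg' give the file name '123-pic.jpg'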
1064 | new_url = os.path.join(paths.dir_output_media, archive_media_filename)
1065 | archive_media_path = \
1066 | os.path.join(paths.dir_input_data, 'direct_messages_group_media',
1067 | archive_media_filename)
1068 | if os.path.isfile(archive_media_path):
1069 | # found a matching image, use this one
1070 | if not os.path.isfile(new_url):
1071 | shutil.copy(archive_media_path, new_url)
1072 |                                     image_markdown = f'\n![]({new_url})\n'
1073 | body_markdown = body_markdown.replace(
1074 | escape_markdown(original_expanded_url), image_markdown
1075 | )
1076 |
1077 | # Save the online location of the best-quality version of this file,
1078 | # for later upgrading if wanted
1079 | best_quality_url = \
1080 | f'https://ton.twitter.com/i//ton/data/dm/' \
1081 | f'{message_id}/{media_id}/{media_hash_and_type}'
1082 | # there is no ':orig' here, the url without any suffix has the original size
1083 |
1084 | # TODO: a cookie (and a 'Referer: https://twitter.com' header)
1085 | # is needed to retrieve it, so the url might be useless anyway...
1086 |
1087 | # WARNING: Do not uncomment the statement below until the cookie problem is solved!
1088 | # media_sources.append(
1089 | # (
1090 | # os.path.join(output_media_folder_name, archive_media_filename),
1091 | # best_quality_url
1092 | # )
1093 | # )
1094 |
1095 | else:
1096 | archive_media_paths = glob.glob(
1097 | os.path.join(paths.dir_input_data, 'direct_messages_group_media',
1098 | message_id + '*'))
1099 | if len(archive_media_paths) > 0:
1100 | for archive_media_path in archive_media_paths:
1101 | archive_media_filename = os.path.split(archive_media_path)[-1]
1102 | media_url = os.path.join(paths.dir_output_media,
1103 | archive_media_filename)
1104 | if not os.path.isfile(media_url):
1105 | shutil.copy(archive_media_path, media_url)
1106 |                                             video_markdown = f'\n<video controls><source src="{media_url}">' \
1107 |                                                              f'Your browser does not support the video tag.</video>\n'
1108 |                                             body_markdown = body_markdown.replace(
1109 | escape_markdown(original_expanded_url), video_markdown
1110 | )
1111 |
1112 | # TODO: maybe also save the online location of the best-quality version for videos?
1113 | # (see above)
1114 |
1115 | else:
1116 |                                         print(f'Warning: missing local file: {archive_media_path}. '
1117 |                                               f'Using original link instead: {original_expanded_url}')
1118 | created_at = message_create['createdAt'] # example: 2022-01-27T15:58:52.744Z
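                            # strptime() yields a naive datetime, so .timestamp() interprets it in the
                            # machine's local timezone; e.g. '2022-01-27T15:58:52.744Z' becomes
                            # 1643299133 (rounded) when run on a UTC machine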
1119 | timestamp = int(round(
1120 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp()
1121 | ))
1122 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \
1123 | else user_id_url_template.format(from_id)
1124 | # make the body a quote
1125 | body_markdown = '> ' + '\n> '.join(body_markdown.splitlines())
1126 | message_markdown = f'{from_handle}: ({created_at})\n\n' \
1127 | f'{body_markdown}'
1128 | messages.append((timestamp, message_markdown))
1129 | elif "conversationNameUpdate" in message:
1130 | conversation_name_update = message['conversationNameUpdate']
1131 | if all(tag in conversation_name_update for tag in ['initiatingUserId', 'name', 'createdAt']):
1132 | from_id = conversation_name_update['initiatingUserId']
1133 | body_markdown = f"_changed group name to: {escape_markdown(conversation_name_update['name'])}_"
1134 | created_at = conversation_name_update['createdAt'] # example: 2022-01-27T15:58:52.744Z
1135 | timestamp = int(round(
1136 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp()
1137 | ))
1138 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \
1139 | else user_id_url_template.format(from_id)
1140 | message_markdown = f'{from_handle}: ({created_at})\n\n{body_markdown}'
1141 | messages.append((timestamp, message_markdown))
1142 | # save metadata about name change:
1143 | group_conversations_metadata[conversation_id]['conversation_names'].append(
1144 | (timestamp, conversation_name_update['name'])
1145 | )
1146 | elif "joinConversation" in message:
1147 | join_conversation = message['joinConversation']
1148 | if all(tag in join_conversation for tag in ['initiatingUserId', 'createdAt']):
1149 | from_id = join_conversation['initiatingUserId']
1150 | created_at = join_conversation['createdAt'] # example: 2022-01-27T15:58:52.744Z
1151 | timestamp = int(round(
1152 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp()
1153 | ))
1154 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \
1155 | else user_id_url_template.format(from_id)
1156 | escaped_username = escape_markdown(username)
1157 | body_markdown = f'_{from_handle} added {escaped_username} to the group_'
1158 | message_markdown = f'{from_handle}: ({created_at})\n\n{body_markdown}'
1159 | messages.append((timestamp, message_markdown))
1160 | elif "participantsJoin" in message:
1161 | participants_join = message['participantsJoin']
1162 | if all(tag in participants_join for tag in ['initiatingUserId', 'userIds', 'createdAt']):
1163 | from_id = participants_join['initiatingUserId']
1164 | created_at = participants_join['createdAt'] # example: 2022-01-27T15:58:52.744Z
1165 | timestamp = int(round(
1166 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp()
1167 | ))
1168 | from_handle = escape_markdown(users[from_id].handle) if from_id in users \
1169 | else user_id_url_template.format(from_id)
1170 | joined_ids = participants_join['userIds']
1171 | joined_handles = [escape_markdown(users[joined_id].handle) if joined_id in users
1172 | else user_id_url_template.format(joined_id) for joined_id in joined_ids]
1173 | name_list = ', '.join(joined_handles[:-1]) + \
1174 | (f' and {joined_handles[-1]}' if len(joined_handles) > 1 else
1175 | joined_handles[0])
1176 | body_markdown = f'_{from_handle} added {name_list} to the group_'
1177 | message_markdown = f'{from_handle}: ({created_at})\n\n{body_markdown}'
1178 | messages.append((timestamp, message_markdown))
1179 | elif "participantsLeave" in message:
1180 | participants_leave = message['participantsLeave']
1181 | if all(tag in participants_leave for tag in ['userIds', 'createdAt']):
1182 | created_at = participants_leave['createdAt'] # example: 2022-01-27T15:58:52.744Z
1183 | timestamp = int(round(
1184 | datetime.datetime.strptime(created_at, '%Y-%m-%dT%X.%fZ').timestamp()
1185 | ))
1186 | left_ids = participants_leave['userIds']
1187 | left_handles = [escape_markdown(users[left_id].handle) if left_id in users
1188 | else user_id_url_template.format(left_id) for left_id in left_ids]
1189 | name_list = ', '.join(left_handles[:-1]) + \
1190 | (f' and {left_handles[-1]}' if len(left_handles) > 1 else
1191 | left_handles[0])
1192 | body_markdown = f'_{name_list} left the group_'
1193 | message_markdown = f'{name_list}: ({created_at})\n\n{body_markdown}'
1194 | messages.append((timestamp, message_markdown))
1195 |
1196 | # collect messages per conversation in group_conversations_messages dict
1197 | group_conversations_messages[conversation_id].extend(messages)
1198 |
1199 | # output as one file per conversation (or part of long conversation)
1200 | num_written_messages = 0
1201 | num_written_files = 0
1202 | for conversation_id, messages in group_conversations_messages.items():
1203 | # sort messages by timestamp
1204 | messages.sort(key=lambda tup: tup[0])
1205 | # create conversation name for use in filename:
1206 | # first, try to find an official name in the parsed conversation data
1207 |
1208 | # Not-so-fun fact:
1209 | # If the name was set before the archive's owner joined the group, the name is not included
1210 | # in the archive data and can't be found anywhere (except by looking it up from twitter,
1211 | # and that would probably need a cookie). So there are many groups that do actually have a name,
1212 | # but it can't be used here because we don't know it.
1213 |
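        # sort the recorded name changes newest-first, so that element 0 holds the most recent name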
1214 | group_conversations_metadata[conversation_id]['conversation_names'].sort(key=lambda tup: tup[0], reverse=True)
1215 | official_name = group_conversations_metadata[conversation_id]['conversation_names'][0][1]
1216 | safe_group_name = make_conversation_name_safe_for_filename(official_name)
1217 | if len(safe_group_name) < 2:
1218 | # discard name if it's too short (because of collision risk)
1219 | group_name = conversation_id
1220 | else:
1221 | group_name = safe_group_name
1222 |
1223 | if group_name == conversation_id:
1224 | # try to make a nice list of participant handles for the conversation name
1225 | handles = []
1226 | for participant_id, message_count in \
1227 | group_conversations_metadata[conversation_id]['participant_message_count'].items():
1228 | if participant_id in users:
1229 | participant_handle = users[participant_id].handle
1230 | if participant_handle != username:
1231 | handles.append((participant_handle, message_count))
1232 | # sort alphabetically by handle first, for a more deterministic order
1233 | handles.sort(key=lambda tup: tup[0])
1234 | # sort so that the most active users are at the start of the list
1235 | handles.sort(key=lambda tup: tup[1], reverse=True)
1236 | if len(handles) == 1:
1237 | group_name = \
1238 | f'{handles[0][0]}_and_{len(group_conversations_metadata[conversation_id]["participants"]) - 1}_more'
1239 | elif len(handles) == 2 and len(group_conversations_metadata[conversation_id]["participants"]) == 3:
1240 | group_name = f'{handles[0][0]}_and_{handles[1][0]}_and_{username}'
1241 | elif len(handles) >= 2:
1242 | group_name = \
1243 | f'{handles[0][0]}_and_{handles[1][0]}_and' \
1244 | f'_{len(group_conversations_metadata[conversation_id]["participants"]) - 2}_more'
1245 | else:
1246 | # just use the conversation id
1247 | group_name = conversation_id
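        # Illustration of the fallback names built above (handles and counts made up;
        # <username> stands for the archive owner's handle):
        #   3 participants: alice + bob + the archive's owner  ->  'alice_and_bob_and_<username>'
        #   5 participants, alice and bob most active          ->  'alice_and_bob_and_3_more'
        #   5 participants, only alice's handle known          ->  'alice_and_4_more'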
1248 |
1249 | # create a list of names of the form '@name1, @name2 and @name3'
1250 | # to use as a headline in the output file
1251 | escaped_participant_names = [
1252 | escape_markdown(participant_name)
1253 | for participant_name in group_conversations_metadata[conversation_id]['participant_names']
1254 | ]
1255 | name_list = ', '.join(escaped_participant_names[:-1]) + \
1256 | (f' and {escaped_participant_names[-1]}'
1257 | if len(escaped_participant_names) > 1
1258 | else escaped_participant_names[0])
1259 |
1260 | if len(messages) > 1000:
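            # chunks() (defined elsewhere in this script) is assumed to yield successive
            # 1000-message slices, e.g. 2345 messages -> parts of 1000, 1000 and 345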
1261 | for chunk_index, chunk in enumerate(chunks(messages, 1000)):
1262 | markdown = ''
1263 | markdown += f'## {official_name} ##\n\n'
1264 | markdown += f'### Group conversation between {name_list}, part {chunk_index + 1}: ###\n\n----\n\n'
1265 | markdown += '\n\n----\n\n'.join(md for _, md in chunk)
1266 | conversation_output_filename = paths.create_path_for_file_output_dms(
1267 | name=group_name, format="md", kind="DMs-Group", index=chunk_index + 1
1268 | )
1269 |
1270 | # write part to a markdown file
1271 | with open_and_mkdirs(conversation_output_filename) as f:
1272 | f.write(markdown)
1273 | print(f'Wrote {len(chunk)} messages to {conversation_output_filename}')
1274 | num_written_files += 1
1275 | else:
1276 | markdown = ''
1277 | markdown += f'## {official_name} ##\n\n'
1278 | markdown += f'### Group conversation between {name_list}: ###\n\n----\n\n'
1279 | markdown += '\n\n----\n\n'.join(md for _, md in messages)
1280 | conversation_output_filename = \
1281 | paths.create_path_for_file_output_dms(name=group_name, format="md", kind="DMs-Group")
1282 |
1283 | with open_and_mkdirs(conversation_output_filename) as f:
1284 | f.write(markdown)
1285 | print(f'Wrote {len(messages)} messages to {conversation_output_filename}')
1286 | num_written_files += 1
1287 |
1288 | num_written_messages += len(messages)
1289 |
1290 | print(f"\nWrote {len(group_conversations_messages)} direct message group conversations "
1291 | f"({num_written_messages} total messages) to {num_written_files} markdown files")
1292 |
1293 |
1294 | def migrate_old_output(paths: PathConfig):
1295 | """If present, moves media and cache files from the archive root to the new locations in
1296 | `paths.dir_output_media` and `paths.dir_output_cache`. Then deletes old output files
1297 | (md, html, txt) from the archive root, if the user consents."""
1298 |
1299 | # Create new folders, so we can potentially use them to move files there
1300 | os.makedirs(paths.dir_output_media, exist_ok=True)
1301 | os.makedirs(paths.dir_output_cache, exist_ok=True)
1302 |
1303 | # Move files that we can re-use:
1304 | if os.path.exists(os.path.join(paths.dir_archive, "media")):
1305 | files_to_move = glob.glob(os.path.join(paths.dir_archive, "media", "*"))
1306 | if len(files_to_move) > 0:
1307 | print(f"Moving {len(files_to_move)} files from 'media' to '{paths.dir_output_media}'")
1308 | for file_path_to_move in files_to_move:
1309 | file_name_to_move = os.path.split(file_path_to_move)[1]
1310 | os.rename(file_path_to_move, os.path.join(paths.dir_output_media, file_name_to_move))
1311 | os.rmdir(os.path.join(paths.dir_archive, "media"))
1312 |
1313 | known_tweets_old_path = os.path.join(paths.dir_archive, "known_tweets.json")
1314 | known_tweets_new_path = os.path.join(paths.dir_output_cache, "known_tweets.json")
1315 | if os.path.exists(known_tweets_old_path):
1316 | os.rename(known_tweets_old_path, known_tweets_new_path)
1317 |
1318 | # Delete files that would be overwritten anyway (if user consents):
1319 | output_globs = [
1320 | "TweetArchive.html",
1321 | "*Tweet-Archive*.html",
1322 | "*Tweet-Archive*.md",
1323 | "DMs-Archive-*.html",
1324 | "DMs-Archive-*.md",
1325 | "DMs-Group-Archive-*.html",
1326 | "DMs-Group-Archive-*.md",
1327 | "followers.txt",
1328 | "following.txt",
1329 | ]
1330 | files_to_delete = []
1331 |
1332 | for output_glob in output_globs:
1333 | files_to_delete += glob.glob(os.path.join(paths.dir_archive, output_glob))
1334 |
1335 | # TODO maybe remove those files only after the new ones have been generated? This way, the user would never
1336 | # end up with less output than before. On the other hand, they might end up with old *and* new versions
1337 | # of the output, if the script crashes before it reaches the code to delete the old version.
1338 | if len(files_to_delete) > 0:
1339 | print(f"\nThere are {len(files_to_delete)} files in the root of the archive,")
1340 | print("which were probably generated from an older version of this script.")
1341 | print("Since then, the directory layout of twitter-archive-parser has changed")
1342 | print("and these files are generated into the sub-directory 'parser-output' or")
1343 | print("various sub-sub-directories therein. These are the affected files:")
1344 |
1345 | for file_to_delete in files_to_delete:
1346 | print(file_to_delete)
1347 |
1348 |         user_input = input('\nOK to delete these files? (If the directory layout had not changed, they would be overwritten anyway.) [y/N] ')
1349 | if user_input.lower() in ('y', 'yes'):
1350 | for file_to_delete in files_to_delete:
1351 | os.remove(file_to_delete)
1352 |             print("Files have been deleted. New versions of these files will be generated into 'parser-output' soon.")
1353 |
1354 |
1355 | def is_archive(path):
1356 | """Return true if there is a Twitter archive at the given path"""
1357 | return os.path.isfile(os.path.join(path, 'data', 'account.js'))
1358 |
1359 |
1360 | def find_archive():
1361 | """
1362 | Search for the archive
1363 | 1. First try the working directory.
1364 | 2. Then try the script directory.
1365 | 3. Finally prompt the user.
1366 | """
1367 | if is_archive('.'):
1368 | return '.'
1369 | script_dir = os.path.dirname(__file__)
1370 | if script_dir != os.getcwd():
1371 | if is_archive(script_dir):
1372 | return script_dir
1373 | print('Archive not found in working directory or script directory.\n'
1374 | 'Please enter the path of your Twitter archive, or just press Enter to exit.\n'
1375 | 'On most operating systems, you can also try to drag and drop your archive folder '
1376 | 'into the terminal window, and it will paste its path automatically.\n')
1377 | # Give the user as many attempts as they need.
1378 | while True:
1379 | input_path = input('Archive path: ')
1380 | if not input_path:
1381 | exit()
1382 | if is_archive(input_path):
1383 | return input_path
1384 | print(f'Archive not found at {input_path}')
1385 |
1386 |
1387 | def main(lang):
1388 | archive_path = find_archive()
1389 | paths = PathConfig(dir_archive=archive_path)
1390 |
1391 | # Extract the archive owner's username from data/account.js
1392 | username = extract_username(paths)
1393 |
1394 | user_id_url_template = 'https://twitter.com/i/user/{}'
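    # e.g. user_id_url_template.format('12345') -> 'https://twitter.com/i/user/12345'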
1395 |
1396 |     html_template = """\
1397 | <!doctype html>
1398 | <html lang="en">
1399 | <head>
1400 |     <meta charset="utf-8" />
1401 |     <meta name="viewport" content="width=device-width, initial-scale=1" />
1402 |     <link rel="stylesheet" href="https://unpkg.com/@picocss/pico@latest/css/pico.min.css">
1403 |     <title>Your Twitter archive!</title>
1404 | </head>
1405 | <body>
1406 |     <h1>Your twitter archive</h1>
1407 |     <main class="container">
1408 |     {}
1409 |     </main>
1410 | </body>
1411 | </html>
1412 | """
1413 |
1414 | users = {}
1415 |
1416 | migrate_old_output(paths)
1417 |
1418 | # Make a folder to copy the images and videos into.
1419 | os.makedirs(paths.dir_output_media, exist_ok=True)
1420 | if not os.path.isfile(paths.file_tweet_icon):
1421 | shutil.copy('assets/images/favicon.ico', paths.file_tweet_icon)
1422 |
1423 | media_sources = parse_tweets(username, users, html_template, paths, lang)
1424 |
1425 | following_ids = collect_user_ids_from_followings(paths)
1426 | print(f'found {len(following_ids)} user IDs in followings.')
1427 | follower_ids = collect_user_ids_from_followers(paths)
1428 | print(f'found {len(follower_ids)} user IDs in followers.')
1429 | dms_user_ids = collect_user_ids_from_direct_messages(paths)
1430 | print(f'found {len(dms_user_ids)} user IDs in direct messages.')
1431 | group_dms_user_ids = collect_user_ids_from_group_direct_messages(paths)
1432 | print(f'found {len(group_dms_user_ids)} user IDs in group direct messages.')
1433 |
1434 | # bulk lookup for user handles from followers, followings, direct messages and group direct messages
1435 | collected_user_ids_without_followers = list(
1436 | set(following_ids).union(set(dms_user_ids)).union(set(group_dms_user_ids))
1437 | )
1438 | collected_user_ids_only_in_followers: set = set(follower_ids).difference(set(collected_user_ids_without_followers))
1439 | collected_user_ids: list = list(set(collected_user_ids_without_followers)
1440 | .union(collected_user_ids_only_in_followers))
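    # With hypothetical IDs, following={'A','B'}, dms={'B','C'}, group_dms={'C'} and
    # followers={'A','D'} give:
    #   collected_user_ids_without_followers -> ['A', 'B', 'C'] (in arbitrary order)
    #   collected_user_ids_only_in_followers -> {'D'}
    #   collected_user_ids                   -> ['A', 'B', 'C', 'D'] (in arbitrary order)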
1441 |
1442 | print(f'\nfound {len(collected_user_ids)} user IDs overall.')
1443 |
1444 |     # let the user choose whether followers should be included in the lookup
1445 |     # (but only when they make up a large share):
1446 | unknown_collected_user_ids: set = set(collected_user_ids).difference(users.keys())
1447 | unknown_follower_user_ids: set = unknown_collected_user_ids.intersection(collected_user_ids_only_in_followers)
1448 | if len(unknown_follower_user_ids) > 5000:
1449 | # Account metadata observed at ~2.1KB on average.
1450 | estimated_follower_lookup_size = int(2.1 * len(unknown_follower_user_ids))
1451 | # we can look up at least 3000 users per minute.
1452 | estimated_max_follower_lookup_time_in_minutes = len(unknown_follower_user_ids) / 3000
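        # e.g. 9,000 unknown follower IDs -> roughly 9000 * 2.1 ≈ 18,900 KB of extra data
        # and up to 9000 / 3000 = 3.0 extra minutes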
1453 | print(
1454 | f'For some user IDs, the @handle is not included in the archive data. '
1455 |         f'Unknown user handles can be looked up online. '
1456 | f'{len(unknown_follower_user_ids)} of {len(unknown_collected_user_ids)} total '
1457 | f'user IDs with unknown handles are from your followers. Online lookup would be '
1458 | f'about {estimated_follower_lookup_size:,} KB smaller and up to '
1459 | f'{estimated_max_follower_lookup_time_in_minutes:.1f} minutes faster without them.\n'
1460 | )
1461 |
1462 |         if not get_consent('Do you want to include handles of your followers '
1463 |                            'in the online lookup of user handles anyway?', default_to_yes=True):
1464 | collected_user_ids = collected_user_ids_without_followers
1465 |
1466 | lookup_users(collected_user_ids, users)
1467 |
1468 | parse_followings(users, user_id_url_template, paths)
1469 | parse_followers(users, user_id_url_template, paths)
1470 | #parse_direct_messages(username, users, user_id_url_template, paths)
1471 | #parse_group_direct_messages(username, users, user_id_url_template, paths)
1472 |
1473 | # Download larger images, if the user agrees
1474 |     print("\nThe archive doesn't contain the original-size images. We can attempt to download them from twimg.com.")
1475 |     print('Please be aware that this script may download a lot of data, which will cost you money if you')
1476 |     print('are paying for bandwidth, and that the servers might block these requests if they are too frequent.')
1477 |     print('This script may not work if your account is protected. You may want to set it to public')
1478 |     print('before starting the download.\n')
1479 |
1480 | if get_consent('OK to start downloading?'):
1481 | download_larger_media(media_sources, paths)
1482 | print('In case you set your account to public before initiating the download, '
1483 | 'do not forget to protect it again.')
1484 |
1485 |
1486 | if __name__ == "__main__":
1487 | main(lang=sys.argv[1] if len(sys.argv) > 1 else 'zh_hans')
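# Usage sketch (script name hypothetical): `python parser.py [lang]`; lang is passed
# through to parse_tweets() and presumably selects a localization, defaulting to 'zh_hans'.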
1488 |
--------------------------------------------------------------------------------