├── dataset
└── test
│ ├── pickle
│ └── .keep
│ ├── spiece.model
│ └── vocab.txt
├── requirements.txt
├── configs
├── __init__.py
├── finetune.py
├── train.large.py
├── train.medium.py
├── train.small.py
├── train.py
└── test.py
├── .gitignore
├── predict.py
├── README.md
├── finetune.py
├── cut_words.py
├── examples
├── gpt2_quickly.ipynb
├── ai_noval_demo.ipynb
├── gpt2_medium_chinese.ipynb
└── mixed_precision_test.ipynb
├── predata.py
├── performer.py
├── train.py
├── util.py
└── fast_attention.py
/dataset/test/pickle/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dataset/test/spiece.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mymusise/gpt2-quickly/HEAD/dataset/test/spiece.model
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.18.0
2 | tokenizers==0.12.1
3 | tensorflow==2.9.1
4 | keras==2.9.0
5 |
6 | tqdm
7 | click
8 | jieba
9 | sentencepiece==0.1.91
10 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
"""Select the active configuration module from the ``ENV`` environment variable.

``ENV=PRO`` loads ``configs.train``; any other value (default ``DEV``) loads
``configs.test``. Both define ``path``, ``model_path``, ``data`` and ``model``,
which are re-exported from this package.
"""
import os

if os.environ.get('ENV', 'DEV') == 'PRO':
    from .train import *  # noqa: F401,F403
else:
    from .test import *  # noqa: F401,F403

# Names exported by ``from configs import *``. Every entry must exist in the
# selected config module, otherwise the star-import raises AttributeError.
__all__ = ['path', 'model_path', 'data', 'model']
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.log
2 | *.txt
3 | *.pickle
4 | logs/
5 |
6 | ./dataset/*
7 | !*/test/raw.txt
8 | !*/test/vocab.txt
9 | !requirements.txt
10 |
11 | venv/
12 | envs/
13 | env/
14 | .vscode/
15 |
16 | models*
17 | checkpoint
18 | configs/train.py
19 |
20 | *.pyc
21 |
--------------------------------------------------------------------------------
/configs/finetune.py:
--------------------------------------------------------------------------------
"""Fine-tuning configuration: dataset and checkpoint paths under /data/novels."""
import os

path = '/data/novels'
model_path = path + '/models/'

data = {
    'path': path + '/train/',
}
data = {
    **data,
    'raw': data['path'] + 'raw.txt',
    # raw_cut / pickle_path mirror configs/test.py: cut_words.py writes
    # data.raw_cut, and predata.py / train.get_dataset read data.pickle_path.
    'raw_cut': data['path'] + 'raw.cut.txt',
    'vocab': data['path'] + 'vocab.txt',
    'pickle': data['path'] + 'data.pickle',
    'pickle_path': os.path.join(data['path'], 'pickle'),
}

model = {
    'max_length': 1024,
    'batch_size': 2,
}

# Expose the dicts as attribute namespaces (e.g. data.raw, model.max_length).
data = type('data', (), data)
model = type('model', (), model)
--------------------------------------------------------------------------------
/configs/train.large.py:
--------------------------------------------------------------------------------
"""Large GPT-2 training configuration (36 layers, 1280-dim embeddings, 20 heads)."""
import os

path = '/data/wiki_zh'
model_path = path + '/models/'

data = {
    'path': path + '/train/',
}
data = {
    **data,
    'raw': data['path'] + 'raw.txt',
    # raw_cut / pickle_path mirror configs/test.py: cut_words.py writes
    # data.raw_cut, and predata.py / train.get_dataset read data.pickle_path.
    'raw_cut': data['path'] + 'raw.cut.txt',
    'vocab': data['path'] + 'vocab.txt',
    'pickle': data['path'] + 'data.pickle',
    'pickle_path': os.path.join(data['path'], 'pickle'),
}

model = {
    'max_length': 1024,
    'n_positions': 1024,
    'n_ctx': 1024,
    'n_embd': 1280,
    'n_layer': 36,
    'n_head': 20,
    'batch_size': 2
}

# Expose the dicts as attribute namespaces (e.g. data.raw, model.n_layer).
data = type('data', (), data)
model = type('model', (), model)
27 |
--------------------------------------------------------------------------------
/configs/train.medium.py:
--------------------------------------------------------------------------------
"""Medium GPT-2 training configuration (24 layers, 1024-dim embeddings, 16 heads)."""
import os

path = '/data/wiki_zh'
model_path = path + '/models/'

data = {
    'path': path + '/train/',
}
data = {
    **data,
    'raw': data['path'] + 'raw.txt',
    # raw_cut / pickle_path mirror configs/test.py: cut_words.py writes
    # data.raw_cut, and predata.py / train.get_dataset read data.pickle_path.
    'raw_cut': data['path'] + 'raw.cut.txt',
    'vocab': data['path'] + 'vocab.txt',
    'pickle': data['path'] + 'data.pickle',
    'pickle_path': os.path.join(data['path'], 'pickle'),
}

model = {
    'max_length': 1024,
    'n_positions': 1024,
    'n_ctx': 1024,
    'n_embd': 1024,
    'n_layer': 24,
    'n_head': 16,
    'batch_size': 6
}

# Expose the dicts as attribute namespaces (e.g. data.raw, model.n_layer).
data = type('data', (), data)
model = type('model', (), model)
27 |
--------------------------------------------------------------------------------
/configs/train.small.py:
--------------------------------------------------------------------------------
"""Small GPT-2 training configuration (12 layers, 768-dim embeddings, 12 heads)."""
import os

path = '/data/wiki_zh'
model_path = path + '/models/'

data = {
    'path': path + '/train/',
}
data = {
    **data,
    'raw': data['path'] + 'raw.txt',
    # raw_cut / pickle_path mirror configs/test.py: cut_words.py writes
    # data.raw_cut, and predata.py / train.get_dataset read data.pickle_path.
    'raw_cut': data['path'] + 'raw.cut.txt',
    'vocab': data['path'] + 'vocab.txt',
    'pickle': data['path'] + 'data.pickle',
    'pickle_path': os.path.join(data['path'], 'pickle'),
}

model = {
    'max_length': 1024,
    'n_positions': 1024,
    'n_ctx': 1024,
    'n_embd': 768,
    'n_layer': 12,
    'n_head': 12,
    'batch_size': 6
}

# Expose the dicts as attribute namespaces (e.g. data.raw, model.n_layer).
data = type('data', (), data)
model = type('model', (), model)
27 |
--------------------------------------------------------------------------------
/configs/train.py:
--------------------------------------------------------------------------------
"""Production configuration, loaded by configs/__init__.py when ENV == 'PRO'."""
import os

path = '/data/wiki_zh'
model_path = path + '/models/'

data = {
    'path': path + '/train/',
}
data = {
    **data,
    'raw': data['path'] + 'raw.txt',
    'raw_cut': data['path'] + 'raw.cut.txt',
    'vocab': data['path'] + 'vocab.txt',
    'pickle': data['path'] + 'data.pickle',
    # Directory of the shard / bucketed pickle files. predata.py and
    # train.get_dataset read data.pickle_path; without it they raise
    # AttributeError under ENV=PRO.
    'pickle_path': os.path.join(data['path'], 'pickle'),
}

model = {
    'max_length': 1024,
    'n_positions': 1024,
    'n_ctx': 1024,
    'n_embd': 1024,
    'n_layer': 24,
    'n_head': 16,
    'batch_size': 6
}

# Expose the dicts as attribute namespaces (e.g. data.raw, model.n_layer).
data = type('data', (), data)
model = type('model', (), model)
28 |
--------------------------------------------------------------------------------
/configs/test.py:
--------------------------------------------------------------------------------
"""Development configuration, loaded by configs/__init__.py whenever ENV != 'PRO'.

Uses the small bundled dataset under ./dataset/test and a 4-layer model.
"""
import os

path = './dataset'
model_path = f'{path}/models/'

# Everything in the test dataset lives under one directory.
_data_dir = f'{path}/test/'
data = type('data', (), {
    'path': _data_dir,
    'raw': f'{_data_dir}raw.txt',
    'raw_cut': f'{_data_dir}raw.cut.txt',
    'vocab': f'{_data_dir}vocab.txt',
    'pickle': f'{_data_dir}data.pickle',
    'pickle_path': os.path.join(_data_dir, 'pickle'),
})
del _data_dir

# Model hyper-parameters exposed as attributes (model.n_layer, ...).
model = type('model', (), dict(
    max_length=512,
    n_positions=512,
    n_ctx=512,
    n_embd=768,
    n_layer=4,
    n_head=4,
    batch_size=8,
))
31 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from transformers import TextGenerationPipeline
2 | from transformers import GPT2Tokenizer
3 | from train import init_model, load_tokenizer
4 |
tokenizer = load_tokenizer()
model = init_model(tokenizer)

text_generator = TextGenerationPipeline(model, tokenizer)

# Prompts continued with top-k (k=10) sampling, stopping at the "】" token.
sampling_prompts = ("唐诗:", "此地是我开", "一只乌鸦")
# Prompts continued greedily, without an explicit eos_token_id.
greedy_prompts = ("走向森林 ", "拿出一本秘籍", "今日", "大江东去")

# Id of "】" in the vocab, or 0 when the vocab does not contain it.
end_bracket_id = tokenizer.get_vocab().get("】", 0)

for prompt in sampling_prompts:
    print(text_generator(prompt, max_length=64, do_sample=True, top_k=10, eos_token_id=end_bracket_id))

for prompt in greedy_prompts:
    print(text_generator(prompt, max_length=64, do_sample=False))
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | GPT2 Quickly
3 |
4 |
5 |
6 |
Build your own GPT2 quickly, without doing a lot of unnecessary work.
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | This project is based on 🤗 Transformers. This tutorial shows you how to train a GPT2 model for your own language (such as Chinese or Japanese) with just a few lines of code in TensorFlow 2.
16 |
17 | You can try this project in [colab](https://colab.research.google.com/github/mymusise/gpt2-quickly/blob/main/examples/gpt2_quickly.ipynb) right now.
18 |
19 | ## Main file
20 |
21 | ```
22 |
23 | ├── configs
24 | │ ├── test.py
25 | │ └── train.py
26 | ├── build_tokenizer.py
27 | ├── predata.py
28 | ├── predict.py
29 | └── train.py
30 | ```
31 |
32 | ## Preparation
33 |
34 |
35 | ### virtualenv
36 | ``` bash
37 | git clone git@github.com:mymusise/gpt2-quickly.git
38 | cd gpt2-quickly
39 | python3 -m venv venv
40 | source venv/bin/activate
41 |
42 | pip install -r requirements.txt
43 | ```
44 |
45 | ### Install google/sentencepiece
46 |
47 | - see [https://github.com/google/sentencepiece#installation](https://github.com/google/sentencepiece#installation)
48 |
49 |
50 | ## 0x00. prepare your raw dataset
51 |
52 | This is an example of a raw dataset: [raw.txt](dataset/test/raw.txt)
53 |
54 |
55 | ## 0x01. Build vocab
56 |
57 | ```bash
58 | python cut_words.py
59 | python build_tokenizer.py
60 | ```
61 |
62 |
63 | ## 0x02. Tokenize
64 |
65 | ```bash
66 | python predata.py --n_processes=2
67 | ```
68 |
69 |
70 | ## 0x03 Train
71 |
72 | ```bash
73 | python train.py
74 | ```
75 |
76 |
77 | ## 0x04 Predict
78 |
79 | ```bash
80 | python predict.py
81 | ```
82 |
83 | ## 0x05 Fine-Tune
84 |
85 | ```bash
86 | ENV=FINETUNE python finetune.py
87 | ```
88 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 |
2 | from transformers import BertTokenizer, TFGPT2LMHeadModel
3 | from transformers import GPT2Config, TFGPT2LMHeadModel
4 | from transformers import TextGenerationPipeline
5 | from official import nlp
6 | import official.nlp.optimization
7 | from train import load_tokenizer, train, get_dataset
8 | import tensorflow as tf
9 | from configs import finetune as configs
10 | import click
11 |
12 |
def load_model(train_steps, num_warmup_steps):
    """Load a GPT-2 LM head model and compile it for fine-tuning.

    First tries the locally fine-tuned checkpoint in ``configs.model_path``.
    The local tokenizer must also load. On any failure it falls back to the
    ``mymusise/gpt2-medium-chinese`` checkpoint and prints the reason.

    Args:
        train_steps: number of steps the optimizer's learning-rate decay spans.
        num_warmup_steps: number of linear warm-up steps.

    Returns:
        The compiled ``TFGPT2LMHeadModel``.
    """
    try:  # try to load finetuned model at local.
        # The tokenizer is not returned; loading it checks that the local vocab
        # is usable before trusting the local checkpoint.
        load_tokenizer()
        model = TFGPT2LMHeadModel.from_pretrained(configs.model_path, return_dict=False)
        print("model loaded from local!")
    except Exception as e:
        # Best-effort fallback, but say why instead of hiding the error.
        print(f"local model unavailable ({type(e).__name__}: {e}), falling back to remote")
        BertTokenizer.from_pretrained(
            "mymusise/gpt2-medium-chinese")
        model = TFGPT2LMHeadModel.from_pretrained(
            "mymusise/gpt2-medium-chinese", return_dict=False)
        print("model loaded from remote!")

    loss = model.compute_loss
    optimizer = nlp.optimization.create_optimizer(
        5e-5, num_train_steps=train_steps, num_warmup_steps=num_warmup_steps)

    model.compile(
        optimizer=optimizer,
        # Loss only on the first output (logits); the following n_layer outputs
        # (presumably the per-layer cache when return_dict=False — confirm) get none.
        loss=[loss, *[None] * model.config.n_layer],
    )
    return model
37 |
38 |
@click.command()
@click.option('--epochs', default=20, help='number of epochs')
@click.option('--train_steps', default=2000, help='number of train_steps')
def finetune(epochs, train_steps):
    """Fine-tune the GPT-2 model for ``epochs`` epochs of ``train_steps`` steps each."""
    # The learning-rate schedule must span the whole run. Previously the decay
    # length was a single epoch (train_steps) while warmup was 10% of all
    # epochs; with the defaults warmup (4000) exceeded the decay length (2000),
    # so the learning rate fell to 0 right after warmup.
    # NOTE(review): assumes train() runs train_steps steps per epoch — confirm.
    total_steps = train_steps * epochs
    warmup_steps = int(total_steps * 0.1)

    train_dataset = get_dataset()
    model = load_model(total_steps, warmup_steps)
    train(model, train_dataset, epochs, train_steps)


if __name__ == '__main__':
    finetune()
52 |
--------------------------------------------------------------------------------
/cut_words.py:
--------------------------------------------------------------------------------
1 | import configs
2 | import jieba
3 | from tqdm import tqdm
4 | import click
5 | from typing import List
6 | import os
7 | from multiprocessing import Process, Manager, Queue
8 | from pathlib import Path
9 |
10 |
def cut_words(processer_num, text, result_dict):
    """Segment each line of ``text`` with jieba into a per-worker temp file.

    Writes ``raw.cut.temp.{processer_num}.txt`` under ``configs.data.path``.
    Each output line holds the words of one input line, joined by single
    spaces. ``result_dict`` is accepted for the ``multiply_cut`` handler
    signature but is not used.

    Lines that fail to segment or write are skipped, not fatal. Their count is
    printed so the loss is visible.
    """
    out_path = os.path.join(configs.data.path, f'raw.cut.temp.{processer_num}.txt')
    skipped = 0
    with open(out_path, 'w') as out_f:
        for line in tqdm(text.split('\n')):
            try:
                out_f.write(" ".join(jieba.cut(line)) + '\n')
            except Exception:
                # Best effort: a single bad line must not abort the worker.
                skipped += 1
    if skipped:
        print(f"[worker {processer_num}] skipped {skipped} line(s) that could not be segmented")
24 |
25 |
def multiply_cut(handler, tasks):
    """Run ``handler(i, task, shared_dict)`` in one process per task and wait.

    ``shared_dict`` is a ``Manager().dict()`` proxy passed to every handler.
    The manager is shut down once all workers have finished. Workers that exit
    with a non-zero status are reported.
    """
    with Manager() as manager:
        # cut_words ignores this dict; it is kept for the handler signature.
        result_dict = manager.dict()
        jobs = [
            Process(target=handler, args=(processer_num, task, result_dict))
            for processer_num, task in enumerate(tasks)
        ]

        for job in jobs:
            job.start()

        for job in jobs:
            job.join()

        # Read exit codes before close(), which invalidates the Process object.
        failed = [i for i, job in enumerate(jobs) if job.exitcode != 0]

        for job in jobs:
            try:
                job.close()  # Process.close() only exists on Python >= 3.7
            except AttributeError:
                pass

    if failed:
        print(f"[warning] workers {failed} exited with a non-zero status")
    print("[all_task done]")
47 |
48 |
def split_data(
    text,
    n_processes
) -> List[str]:
    """Split ``text`` into at most ``n_processes`` contiguous, near-equal chunks.

    Concatenating the returned chunks gives back ``text``.

    Args:
        text: the string to split.
        n_processes: maximum number of chunks (one per worker process).

    Returns:
        The chunks, in order. Returns an empty list for empty ``text``.

    Raises:
        ValueError: if ``n_processes`` < 1.
    """
    if n_processes < 1:
        raise ValueError(f"n_processes must be >= 1, got {n_processes}")
    if not text:
        return []
    # Ceil division. A floor step left a remainder chunk (n_processes + 1
    # workers), and was 0 when len(text) < n_processes, which made range() raise.
    num_pre_task = -(-len(text) // n_processes)
    return [text[i: i + num_pre_task] for i in range(0, len(text), num_pre_task)]
58 |
59 |
@click.command()
@click.option('--n_processes', default=1, help='Number of processes.')
def preprocess(n_processes):
    """Word-segment ``configs.data.raw`` with jieba into ``configs.data.raw_cut``.

    Steps:
    1. Split the raw text into up to ``n_processes`` chunks.
    2. Segment each chunk in a ``cut_words`` worker, which writes
       ``raw.cut.temp.<i>.txt``.
    3. Concatenate the temp files, in worker order, into
       ``configs.data.raw_cut`` and delete them.
    """
    print(f'reading {configs.data.raw}')
    with open(configs.data.raw, 'r') as f:
        # NOTE(review): as shown, the first replace maps a space to a space
        # (a no-op); it may originally have targeted another whitespace
        # character such as U+3000 — confirm against the repository.
        data = f.read().replace(' ', ' ').replace('\n\n', '\n')
    print(f"total words: {len(data)}")

    print(f"split data into {n_processes} pieces")
    text_task = split_data(data, n_processes)

    multiply_cut(cut_words, text_task)

    def _worker_index(temp_file):
        # 'raw.cut.temp.<i>.txt' -> (0, i, ''); unexpected names sort last.
        suffix = temp_file.stem.rpartition('.')[2]
        if suffix.isdecimal():
            return (0, int(suffix), '')
        return (1, 0, temp_file.name)

    path = Path(configs.data.path)
    with open(configs.data.raw_cut, 'w') as all_cut_file:
        # glob() order is arbitrary; merge by worker index so raw_cut follows
        # the order of the raw text.
        for filename in sorted(path.glob('raw.cut.temp.*'), key=_worker_index):
            with open(filename) as cut_file:
                all_cut_file.write(cut_file.read()+'\n')
            print(f'dropping {filename}')
            # Delete directly rather than via `rm` in a shell, which is
            # non-portable and breaks on paths with spaces or metacharacters.
            filename.unlink()


if __name__ == '__main__':
    preprocess()
84 |
--------------------------------------------------------------------------------
/examples/gpt2_quickly.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "gpt2-quickly.ipynb",
7 | "provenance": []
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3",
12 | "language": "python"
13 | },
14 | "accelerator": "GPU"
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "code",
19 | "metadata": {
20 | "id": "omMmlBcEhAnZ",
21 | "outputId": "db84964d-fd61-4243-962b-2132de000d47",
22 | "colab": {
23 | "base_uri": "https://localhost:8080/"
24 | }
25 | },
26 | "source": [
27 | "!git clone https://github.com/mymusise/gpt2-quickly.git\n",
28 | "%cd gpt2-quickly\n",
29 | "!pip install -r requirements.txt"
30 | ],
31 | "execution_count": null,
32 | "outputs": []
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 1,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "#@title install google/sentencepiece\n",
41 | "\n",
42 | "!git clone https://github.com/google/sentencepiece.git \n",
43 | "%cd sentencepiece\n",
44 | "!mkdir build\n",
45 | "%cd build\n",
46 | "!cmake ..\n",
47 | "!make -j $(nproc)\n",
48 | "!sudo make install\n",
49 | "!sudo ldconfig -v\n",
50 | "%cd ../../"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "metadata": {
56 | "id": "oIV0l_g-i_lK",
57 | "outputId": "17cedfe4-b491-49c3-dadb-4bba3079a9d0",
58 | "colab": {
59 | "base_uri": "https://localhost:8080/"
60 | }
61 | },
62 | "source": [
63 | "!head -n 5 dataset/test/raw.txt"
64 | ],
65 | "execution_count": null,
66 | "outputs": []
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "!python cut_words.py\n",
75 | "!python build_tokenizer.py\n",
76 | "!head -n 20 dataset/test/vocab.txt"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "metadata": {
82 | "id": "6sw50oOIi-Be",
83 | "outputId": "88a689fa-2a14-41d0-dbf3-9149ec729836",
84 | "colab": {
85 | "base_uri": "https://localhost:8080/"
86 | }
87 | },
88 | "source": [
89 | "!python predata.py\n",
90 | "!python train.py"
91 | ],
92 | "execution_count": null,
93 | "outputs": []
94 | },
95 | {
96 | "cell_type": "code",
97 | "metadata": {
98 | "id": "xSFuK-tgs9xf"
99 | },
100 | "source": [
101 | "!python predict.py"
102 | ],
103 | "execution_count": null,
104 | "outputs": []
105 | }
106 | ]
107 | }
--------------------------------------------------------------------------------
/predata.py:
--------------------------------------------------------------------------------
1 | from train import load_tokenizer
2 | from multiprocessing import Process
3 | from tqdm import tqdm
4 | from typing import List
5 | import tensorflow as tf
6 | import pickle
7 | import json
8 | import configs
9 | import numpy as np
10 | import click
11 | import os
12 | import random
13 | import time
14 | import gc
15 |
16 |
def _bucket_for(text_len):
    """Return (row_len, bucket) for a document of ``text_len`` characters."""
    if text_len <= 64 - 2:
        return 64, 's'
    if text_len <= 128 - 2:
        return 128, 'm'
    return configs.model.max_length, 'l'


def _split_into_rows(token_ids, row_len, pad_id, overlap=64):
    """Cut ``token_ids`` into rows of exactly ``row_len`` contiguous tokens.

    Sequences of at most ``row_len`` tokens become one row, right-padded with
    ``pad_id``. Longer ones become windows where each window after the first
    repeats the last ``overlap`` tokens of the previous one. A short final
    window is replaced by the last ``row_len`` tokens of the sequence, so every
    row is unbroken text.
    """
    if len(token_ids) <= row_len:
        return [token_ids + [pad_id] * (row_len - len(token_ids))]
    # Keep the stride positive; row_len <= overlap used to divide by zero.
    overlap = min(overlap, row_len // 2)
    stride = row_len - overlap
    n_more = -(-(len(token_ids) - row_len) // stride)  # ceil division
    rows = [token_ids[:row_len]]
    for i in range(n_more):
        start = row_len + i * stride - overlap
        rows.append(token_ids[start:row_len + (i + 1) * stride])
    if len(rows[-1]) < row_len:
        # Back-fill with the preceding tokens of the document. The previous
        # code prepended the previous row's tail, which duplicated the
        # overlapping tokens and produced a non-contiguous row.
        rows[-1] = token_ids[-row_len:]
    return rows


def encode_processer(processer_num: int):
    """Tokenize one shard of documents and write fixed-length id blocks.

    Reads ``data.{processer_num}.jsonp`` (a pickled string written by
    split_data) from ``configs.data.pickle_path``. The string is split into
    documents on the ``'\\n\\n|-|\\n\\n'`` separator; documents shorter than
    24 characters are skipped.

    Each remaining document is written, according to its length, to
    ``data.{processer_num}.s/m/l.pickle``. The record is a pickled int32 array
    of shape (n_rows, row_len), followed by the ``'换行'`` byte delimiter.
    """
    tokenizer = load_tokenizer()
    pickle_dir = configs.data.pickle_path
    # Shards are produced locally by split_data; never point this at
    # untrusted pickle files.
    with open(os.path.join(pickle_dir, f'data.{processer_num}.jsonp'), 'rb') as shard:
        contents = pickle.load(shard)
    contents = contents.split('\n\n|-|\n\n')

    with open(os.path.join(pickle_dir, f'data.{processer_num}.l.pickle'), 'wb') as output_file_l, \
            open(os.path.join(pickle_dir, f'data.{processer_num}.m.pickle'), 'wb') as output_file_m, \
            open(os.path.join(pickle_dir, f'data.{processer_num}.s.pickle'), 'wb') as output_file_s:
        outputs = {'s': output_file_s, 'm': output_file_m, 'l': output_file_l}
        for content in tqdm(contents, desc=f"processer_{processer_num}"):
            if len(content) < 24:
                continue
            pre_size, bucket = _bucket_for(len(content))
            content = tokenizer.sep_token + content + tokenizer.cls_token
            content_decoded = tokenizer(content, return_attention_mask=False,
                                        return_token_type_ids=False, add_special_tokens=False)['input_ids']
            rows = _split_into_rows(content_decoded, pre_size, tokenizer.pad_token_id)
            input_ids = np.array(rows, dtype=np.int32)
            outputs[bucket].write(pickle.dumps(input_ids)+'换行'.encode())
60 |
61 |
62 |
def multiply_encode(handler, n_processes):
    """Run ``handler(i)`` in its own process for each i in range(n_processes) and wait.

    Workers that exit with a non-zero status are reported after all have
    finished.
    """
    jobs = [
        Process(target=handler, args=(processer_num, ))
        for processer_num in range(n_processes)
    ]

    for job in jobs:
        job.start()

    for job in jobs:
        job.join()

    # Read exit codes before close(), which invalidates the Process object.
    failed = [i for i, job in enumerate(jobs) if job.exitcode != 0]

    for job in jobs:
        try:
            job.close()  # Process.close() only exists on Python >= 3.7
        except AttributeError:
            pass

    if failed:
        print(f"[warning] encode workers {failed} exited with a non-zero status")
    print("[all_task done]")
79 |
80 |
def split_data(
    texts,
    n_processes,
    block_size,
    split_token_re=r"(。|?|!|\n)",
    separator="\n\n|-|\n\n",
) -> List[str]:
    """Write ``texts`` into exactly ``n_processes`` shard files ``data.{i}.jsonp``.

    Each shard is a pickled string saved under ``configs.data.pickle_path``.
    Shard boundaries move forward to the next ``separator`` so that no document
    is split across two shards. A shard may be empty when there are fewer
    documents than processes.

    Exactly ``n_processes`` shards are written because ``multiply_encode`` runs
    one encoder per index in ``range(n_processes)``. The old floor-division step
    could write an extra remainder shard that no worker ever encoded.

    ``block_size`` and ``split_token_re`` are accepted for backward
    compatibility but are not used. Nothing is returned, despite the annotation.
    """
    if n_processes < 1:
        raise ValueError(f"n_processes must be >= 1, got {n_processes}")
    total = len(texts)
    bounds = [0]
    for k in range(1, n_processes):
        nominal = max(k * total // n_processes, bounds[-1])
        found = texts.find(separator, nominal)
        bounds.append(total if found == -1 else found)
    bounds.append(total)
    print(total // n_processes, total, n_processes)
    for _i in tqdm(range(n_processes), desc="spliting data..."):
        current_text = texts[bounds[_i]:bounds[_i + 1]]
        with open(os.path.join(configs.data.pickle_path, f'data.{_i}.jsonp'), 'wb') as output_file:
            pickle.dump(current_text, output_file)
        print("pre task num: ", len(current_text))
        del current_text
95 |
96 |
@click.command()
@click.option('--n_processes', default=1, help='Number of processes.')
def preprocess(n_processes):
    """Tokenize ``configs.data.raw`` into bucketed pickle files.

    Steps:
    1. Read the raw lines and shuffle them.
    2. Join them with the ``'\\n\\n|-|\\n\\n'`` document separator.
    3. Shard the result with ``split_data``.
    4. Encode each shard in parallel with ``encode_processer``.
    """
    block_size = configs.model.max_length

    print(f'reading {configs.data.raw}')

    print('reading raw data ...')
    with open(configs.data.raw, 'r') as f:
        # Lines keep their trailing newline. The old trailing `del line` raised
        # NameError when the file was empty.
        data = [line for line in tqdm(f.readlines()) if len(line) > 0]
    random.shuffle(data)
    data = "\n\n|-|\n\n".join(data)
    split_data(data, n_processes, block_size)
    # Free the joined text before forking the encoder workers.
    del data
    gc.collect()
    multiply_encode(encode_processer, n_processes)


if __name__ == '__main__':
    preprocess()
122 |
--------------------------------------------------------------------------------
/examples/ai_noval_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "ai_noval_demo.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "display_name": "Python 3",
12 | "name": "python3"
13 | }
14 | },
15 | "cells": [
16 | {
17 | "cell_type": "code",
18 | "metadata": {
19 | "id": "-irg8cbucTA6"
20 | },
21 | "source": [
22 | "!pip install transformers sentencepiece"
23 | ],
24 | "execution_count": null,
25 | "outputs": []
26 | },
27 | {
28 | "cell_type": "code",
29 | "metadata": {
30 | "id": "4ba0k0hkcNks",
31 | "outputId": "b023991e-da51-4344-acf1-905b59ce1b44",
32 | "colab": {
33 | "base_uri": "https://localhost:8080/"
34 | }
35 | },
36 | "source": [
37 | "from transformers import XLNetTokenizer, TFGPT2LMHeadModel\n",
38 | "\n",
39 | "tokenizer = XLNetTokenizer.from_pretrained('mymusise/EasternFantasyNoval-small')\n",
40 | "model = TFGPT2LMHeadModel.from_pretrained(\"mymusise/EasternFantasyNoval-small\")"
41 | ],
42 | "execution_count": null,
43 | "outputs": [
44 | {
45 | "output_type": "stream",
46 | "text": [
47 | "All model checkpoint layers were used when initializing TFGPT2LMHeadModel.\n",
48 | "\n",
49 | "All the layers of TFGPT2LMHeadModel were initialized from the model checkpoint at mymusise/EasternFantasyNoval.\n",
50 | "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2LMHeadModel for predictions without further training.\n"
51 | ],
52 | "name": "stderr"
53 | }
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "metadata": {
59 | "id": "AHfpYLsPdEgU",
60 | "outputId": "a4c82e77-6840-44a3-c298-f11c6731a741",
61 | "colab": {
62 | "background_save": true
63 | }
64 | },
65 | "source": [
66 | "from transformers import TextGenerationPipeline\n",
67 | "import jieba\n",
68 | "\n",
69 | "text_generater = TextGenerationPipeline(model, tokenizer)\n",
70 | "\n",
71 | "texts = [\n",
72 | " '少年对面站着一位中年人,这位中年人两鬓略有些斑白,穿着一套青衫。尽管衣衫有些脱色,但却洗得很干净。',\n",
73 | " '风凡答应了一声,将地上的东西都收了起来,然后端坐在地上,开始默想绿简的样子。他的脑海中出现了一片片字迹,这正是五行玄蒙经的内容。',\n",
74 | "]\n",
75 | "\n",
76 | "for text in texts:\n",
77 | " print(text_generater(text, max_length=120 + len(text), do_sample=True, top_k=0))\n",
78 | " print(text_generater(text, max_length=120 + len(text), do_sample=True, top_k=20))\n",
79 | " print(text_generater(text, max_length=120 + len(text), do_sample=True, top_k=0, no_repeat_ngram_size=2))\n",
80 | " print(text_generater(text, max_length=120 + len(text), do_sample=True, top_k=10, no_repeat_ngram_size=3))\n",
81 | " print(text_generater(text, max_length=120 + len(text), do_sample=True, top_k=10, no_repeat_ngram_size=3))"
82 | ],
83 | "execution_count": null,
84 | "outputs": [
85 | {
86 | "output_type": "stream",
87 | "text": [
88 | "[{'generated_text': '少年对面站着一位中年人,这位中年人两鬓略有些斑白,穿着一套青衫。尽管衣衫有些脱色,但却洗得很干净。他双目微微闪烁,沉声道:“真是一位好人,刚才还问我是自己的同伴,是否喜欢我?” 那中年人正是夜岚。 句芒已经带着一群手下退了回去。 是夜岚,夜天岚,夜岚,夜无痕要跟随夜天璇们同时外出历练。 夜天岚和夜风一边赶路,一边追赶夜天痕到处赶路,夕亦岚和夜风也是紧跟其'}]\n",
89 | "[{'generated_text': '少年对面站着一位中年人,这位中年人两鬓略有些斑白,穿着一套青衫。尽管衣衫有些脱色,但却洗得很干净。 “看来还是得去看看。这几年的苦苦劳动,我们可都忘记了。”那人看着自己的长辈,微笑道,声音虽然有些颤抖,却并没有像其他人一样,但却是那种感觉。 “嗯......不要太过紧张,那些人我很好。”这时,那中年人却是忽然道:“你可以走了。” 中年男子冷笑一声,道:“有你的份,我绝'}]\n",
90 | "[{'generated_text': '少年对面站着一位中年人,这位中年人两鬓略有些斑白,穿着一套青衫。尽管衣衫有些脱色,但却洗得很干净。他身下的灰尘看起来颇为华丽。 此人身边跟着个虽还是那具尸体,不过这人是福星将他关进了蓬莱仙岛的大弟子璟最后给他的一张画像。 “桑瑟瑟。”那中壮汉手挥了下手中的玉瓶。如这画上装着的不是。那原本应该是合天雷的长老。昭日神功在这儿完全探得一干二净,了却流通,退而'}]\n",
91 | "[{'generated_text': '少年对面站着一位中年人,这位中年人两鬓略有些斑白,穿着一套青衫。尽管衣衫有些脱色,但却洗得很干净。他的目光扫过那人,发觉他的年纪不大不小。但是他身上的伤痕,却很清晰:“你就是青云山弟子?” “我是云霄峰弟子,青云派弟子!”这位看上去只有二十多岁的青云弟子,脸色一阵苍白。云霄阁弟子一向是这位青云阁弟子,这一位,也是青云门弟子。 云霄阁的弟子是青云宗最出名的弟子之一,也有几个是青'}]\n",
92 | "[{'generated_text': '少年对面站着一位中年人,这位中年人两鬓略有些斑白,穿着一套青衫。尽管衣衫有些脱色,但却洗得很干净。看到那青衫中年男子,便是其余几人,也纷纷点头应声。 “哈哈......好啊!”青衫人大笑道,看到众人,纷纷赞叹起来。 “不错,正是小兄弟,不知这位小兄弟是?”中年中年修士看到他,心头顿喜,但脸上却是笑容满面的说道,“小兄弟不才,在下青衫老祖,是小友,是天地间最顶级别的修士。'}]\n",
93 | "[{'generated_text': '风凡答应了一声,将地上的东西都收了起来,然后端坐在地上,开始默想绿简的样子。他的脑海中出现了一片片字迹,这正是五行玄蒙经的内容。 “随意一卷,入心入肺,真正的内功是一重修行的客户,这里是每一个国家的地盘,其他国家需要经过系统的考验,只有读取更广博的才能才能得到其他修行的出身,不过现在开始修炼了!” “谢谢大家关注!” 虽然保持着先前寇立的觉悟,可是一个修行者,就能在毕方眼中看到什么呢? “待会儿我'}]\n",
94 | "[{'generated_text': '风凡答应了一声,将地上的东西都收了起来,然后端坐在地上,开始默想绿简的样子。他的脑海中出现了一片片字迹,这正是五行玄蒙经的内容。而此刻,此刻在他的脑海中也出现了“五行玄蒙经”! 这是一种奇怪的情况,一种奇怪的气体散发出来,仿佛是要把所有的东西都吞噬掉一样。这是一种无所不能的气体,这无形的无形之中,有着无形的阻碍,让人产生不解的感觉。但是,这种气体却是能量,而不是能量,而是能量的克星,这就是无形的无'}]\n",
95 | "[{'generated_text': '风凡答应了一声,将地上的东西都收了起来,然后端坐在地上,开始默想绿简的样子。他的脑海中出现了一片片字迹,这正是五行玄蒙经的内容。 天下第一高僧是一个个年轻年纪,长相粗狂、连自己也是老样子,而且一身黄毛颜色都是紫色,另外两个看起来就是年长已久的美少年。至于老者,因为头发上扎着紫『色』的胡须所在,所以整个人就像黑鹰一样的目光,阴森的盯着欧阳宁。 四大高手面对欧英文那金袍男子,气势汹汹,一步一步的朝着绿明离去的方向行去,'}]\n",
96 | "[{'generated_text': '风凡答应了一声,将地上的东西都收了起来,然后端坐在地上,开始默想绿简的样子。他的脑海中出现了一片片字迹,这正是五行玄蒙经的内容。只见这个字迹在地下的时候,竟然发出了阵阵的声音,仿佛这里就是地府一样,只是这样看起来像是一个地下室。 “好,我答应你,一定会回来的,你一定可以找到我,我也会回去的,我要去找一个合适的地方。” “好。”绿简答应一声,然後就转过身,开口说道。 “我会去找你的。”林阳昊点点头,然'}]\n",
97 | "[{'generated_text': '风凡答应了一声,将地上的东西都收了起来,然后端坐在地上,开始默想绿简的样子。他的脑海中出现了一片片字迹,这正是五行玄蒙经的内容。 他将绿简收拾了一下,然後开始思考该如何修炼的方法。因为在这里,他已经没有时间去想了,现在要去的只有一条路可以走。这里离绿笔山有一段距离,这里只要他修炼成功的地方,不管是什么样的东西,他都会选择走下去的路。 绿笔山,是一种奇怪的地脉的中心,这一条山脉,就是一条长约百米的通道,通'}]\n"
98 | ],
99 | "name": "stdout"
100 | }
101 | ]
102 | }
103 | ]
104 | }
105 |
--------------------------------------------------------------------------------
/performer.py:
--------------------------------------------------------------------------------
1 | from transformers import GPT2Tokenizer
2 | import tensorflow as tf
3 | from transformers import TFPerformerAttention
4 | from transformers import GPT2Config
5 | from transformers import TFGPT2MainLayer, TFGPT2LMHeadModel
6 | from transformers.models.gpt2.modeling_tf_gpt2 import TFMLP, TFAttention, TFConv1D
7 | from enum import Enum
8 | from typing import Sequence, Optional, Union
9 | from transformers.configuration_performer_attention import PerformerAttentionConfig, PerformerKernel, OrthogonalFeatureAlgorithm
10 | from fast_attention import SelfAttention as Attention
11 | import json
12 |
13 |
class EnumEncoder(json.JSONEncoder):
    """JSON encoder that serializes Enum members as ``{"__enum__": str(member)}``."""

    def default(self, obj):
        is_enum_member = issubclass(type(obj), Enum)
        if not is_enum_member:
            # Defer to the stock encoder, which raises TypeError for
            # unsupported types.
            return json.JSONEncoder.default(self, obj)
        return {"__enum__": str(obj)}
19 |
20 |
class PerformerConfig(GPT2Config):
    """GPT2Config extended with Performer-attention settings.

    Every Performer-specific keyword is stored verbatim as an attribute of the
    same name. All remaining keyword arguments are forwarded to GPT2Config.
    """

    def __init__(self,
                 attention_dropout: float = 0.1,
                 kernel_type: Union[str, PerformerKernel] = PerformerKernel.exp,
                 causal: bool = False,
                 use_recurrent_decoding: bool = False,
                 kernel_epsilon: float = 1e-4,
                 normalize_output: bool = True,
                 normalization_stabilizer: float = 1e-6,
                 use_linear_layers: bool = True,
                 linear_layer_names: Sequence[str] = ('q_linear', 'k_linear', 'v_linear', 'out_linear'),
                 num_random_features: Optional[int] = None,
                 use_thick_features: bool = False,
                 regularize_feature_norms: bool = True,
                 use_orthogonal_features: bool = True,
                 orthogonal_feature_algorithm: Union[str, OrthogonalFeatureAlgorithm] = OrthogonalFeatureAlgorithm.auto,
                 feature_redraw_interval: Optional[int] = 100,
                 redraw_stochastically: bool = False,
                 redraw_verbose: bool = False,
                 d_model: Optional[int] = None,
                 num_heads: Optional[int] = None,
                 **kwargs):
        # Assigned in declaration order, before the GPT2Config initializer runs.
        performer_options = {
            'attention_dropout': attention_dropout,
            'kernel_type': kernel_type,
            'causal': causal,
            'use_recurrent_decoding': use_recurrent_decoding,
            'kernel_epsilon': kernel_epsilon,
            'normalize_output': normalize_output,
            'normalization_stabilizer': normalization_stabilizer,
            'use_linear_layers': use_linear_layers,
            'linear_layer_names': linear_layer_names,
            'num_random_features': num_random_features,
            'use_thick_features': use_thick_features,
            'regularize_feature_norms': regularize_feature_norms,
            'use_orthogonal_features': use_orthogonal_features,
            'orthogonal_feature_algorithm': orthogonal_feature_algorithm,
            'feature_redraw_interval': feature_redraw_interval,
            'redraw_stochastically': redraw_stochastically,
            'redraw_verbose': redraw_verbose,
            'd_model': d_model,
            'num_heads': num_heads,
        }
        for option_name, option_value in performer_options.items():
            setattr(self, option_name, option_value)

        super().__init__(**kwargs)

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Enum-valued settings are encoded with EnumEncoder.

        Args:
            use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                If set to ``True``, only the difference between the config instance and the default
                ``PretrainedConfig()`` is serialized to JSON string.

        Returns:
            :obj:`str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        config_dict = self.to_diff_dict() if use_diff is True else self.to_dict()
        serialized = json.dumps(config_dict, indent=2, sort_keys=True, cls=EnumEncoder)
        return serialized + "\n"
84 |
85 |
class TFBlock(tf.keras.layers.Layer):
    """GPT-2 style transformer block using ``fast_attention.SelfAttention``.

    Pre-norm layout: ``x + attn(ln_1(x))``, then ``x + mlp(ln_2(x))``.
    """
    def __init__(self, n_ctx, config, scale=False, **kwargs):
        # n_ctx and scale are accepted but unused here (presumably kept for
        # signature compatibility with transformers' GPT-2 block — confirm).
        super().__init__(**kwargs)
        nx = config.n_embd
        # MLP hidden width: config.n_inner when set, otherwise 4 * n_embd.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
        # self.c_attn = TFConv1D(config.n_embd * 3, nx, initializer_range=config.initializer_range, name="c_attn")
        # Attention hyper-parameters are hard-coded (dropout 0.1, causal, 768
        # random features); the Performer fields of PerformerConfig are not read.
        self.attn = Attention(hidden_size=config.n_embd, num_heads=config.n_head, attention_dropout=0.1, causal=True, nb_random_features=768, name="attn")
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
        self.mlp = TFMLP(inner_dim, config, name="mlp")

    def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
        # layer_past, attention_mask, head_mask, use_cache and output_attentions
        # are accepted for interface compatibility but ignored.
        # NOTE(review): `training` is not passed to self.attn, so its dropout
        # may not follow the training flag — confirm against fast_attention.
        a = self.ln_1(x)
        # a = self.c_attn(a)
        # query, key, value = tf.split(a, 3, axis=2)
        output_attn = self.attn(a)
        # NOTE(review): [0] assumes self.attn returns a sequence, like
        # transformers' TFAttention. If it returns a single tensor, this selects
        # the first batch item and the residual add below broadcasts it over
        # the batch — confirm.
        a = output_attn[0]  # output_attn: a, present, (attentions)
        x = x + a

        m = self.ln_2(x)
        m = self.mlp(m, training=training)
        x = x + m

        # The hidden state is returned twice. The second slot is where callers
        # expect `present` (the key/value cache), so no real cache is produced.
        outputs = [x] + [x]
        return outputs  # x, present, (attentions)
111 |
112 |
class TFGPT2MainLayer(TFGPT2MainLayer):
    """GPT-2 main layer whose blocks are replaced by this module's TFBlock.

    The parent constructor builds the stock layers first. ``self.h`` and
    ``self.ln_f`` are then rebuilt under the same layer names.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        blocks = []
        for layer_index in range(config.n_layer):
            blocks.append(TFBlock(config.n_ctx, config, scale=True, name=f"h_._{layer_index}"))
        self.h = blocks
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
119 |
120 |
class TFGPT2LMHeadModel(TFGPT2LMHeadModel):
    """Thin subclass of transformers' TFGPT2LMHeadModel.

    NOTE(review): as written this adds no behaviour. The line that would swap
    in this module's TFGPT2MainLayer is commented out, so the model still uses
    the stock GPT-2 main layer.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # assert hasattr(config, 'attention_dropout')
        # self.transformer = TFGPT2MainLayer(config, name="transformer")
127 |
128 |
129 | # pconfig = PerformerConfig()
130 | # configuration = GPT2Config()
131 | # model = TFGPT2LMHeadModel(pconfig)
132 |
133 | # tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
134 | # inputs = tokenizer("aa aa aa aa", return_tensors="tf")
135 | # model(**inputs)
136 | # out = model(**inputs)
137 | # print(out[0].shape)
138 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from transformers import GPT2Config
3 | from transformers import TFGPT2LMHeadModel
4 | from transformers import XLNetTokenizer
5 | from transformers import BertTokenizer
6 | from transformers.modeling_tf_utils import shape_list
7 | import configs
8 | import random
9 | import click
10 | import time
11 | import pickle
12 | from pathlib import Path
13 | import numpy as np
14 | import gc
15 | from tqdm import tqdm
16 |
17 | # from tensorflow.keras.mixed_precision import experimental as mixed_precision
18 |
19 |
20 | # policy = mixed_precision.Policy('mixed_float16')
21 | # mixed_precision.set_policy(policy)
22 |
23 |
def load_tokenizer() -> BertTokenizer:
    """Load the vocabulary-based ``BertTokenizer`` stored under ``configs.data.path``.

    ``configs.model.max_length`` is passed as ``max_len``. After loading, the
    tokenizer's ``return_attention_mask`` attribute is set to None.
    """
    vocab_dir = configs.data.path
    tokenizer = BertTokenizer.from_pretrained(
        vocab_dir,
        max_len=configs.model.max_length,
        # NOTE(review): looks like a typo for ``add_special_tokens``; as a
        # constructor kwarg it is probably just stored and ignored -- confirm.
        add_special_token=False,
    )
    tokenizer.return_attention_mask = None
    return tokenizer
29 |
30 |
def get_dataset() -> tf.data.Dataset:
    """Yield ``(num_rows, dataset)`` pairs, one per shuffled group of pickle files.

    Pickle files in ``configs.data.pickle_path`` are bucketed by suffix
    (``*.s.pickle``, ``*.m.pickle``, ``*.l.pickle``). Each bucket is chunked
    into groups of 8 / 4 / 2 files, and all groups are shuffled together with
    the global ``random`` module.

    For each group, every stored row is loaded and stacked into one array.
    A repeated, shuffled, batched ``tf.data.Dataset`` of
    ``(ids[:, :-1], ids[:, 1:])`` pairs (inputs and next-token labels) is then
    built, with batch size 64 / 48 / 32 for the s / m / l buckets.

    Groups that contain no rows are skipped instead of crashing.

    Note: despite the return annotation, this function is a generator.

    Yields:
        ``(number of rows in the group, tf.data.Dataset)``.
    """
    # size tag -> (files per group, batch size)
    size_settings = {"s": (8, 64), "m": (4, 48), "l": (2, 32)}
    pickle_dir = Path(configs.data.pickle_path)

    pickle_groups = []
    for size, (group_num, _) in size_settings.items():
        files = list(pickle_dir.glob(f"*.{size}.pickle"))
        pickle_groups.extend(
            (files[i: i + group_num], size)
            for i in range(0, len(files), group_num)
        )
    random.shuffle(pickle_groups)

    for sub_pickle_files, size in pickle_groups:
        input_ids = []
        for pickle_file in sub_pickle_files:
            print(f"loading {pickle_file}")
            input_ids.extend(_load_pickle_rows(pickle_file))
        if not input_ids:
            # Slicing an empty 1-D array with [:, :-1] raises IndexError, and
            # there is nothing to train on anyway.
            print(f"skipping empty group {sub_pickle_files}")
            continue
        input_ids = np.array(input_ids)

        # Shift by one token: the model sees ids[:, :-1] and predicts ids[:, 1:].
        ids = input_ids[:, :-1]
        labels = input_ids[:, 1:]
        print(ids.shape, labels.shape, ids.dtype, labels.dtype)
        batch_size = size_settings[size][1]
        dataset = (
            tf.data.Dataset.from_tensor_slices((ids, labels))
            .shuffle(ids.shape[0], reshuffle_each_iteration=False)
            .repeat()
            .batch(batch_size)
        )
        yield len(input_ids), dataset


def _load_pickle_rows(pickle_file):
    """Return every row stored in one pickle file of ``"换行"``-separated records."""
    # Context manager so the file handle is closed as soon as it is read.
    with open(pickle_file, "rb") as f:
        records = f.read().split("换行".encode())
    rows = []
    for record in tqdm(records):
        if record:
            # SECURITY: pickle.loads can execute arbitrary code. Only point
            # this at files produced locally (presumably by the preprocessing
            # script), never at untrusted data.
            rows.extend(pickle.loads(record))
    return rows
91 |
92 |
def build_loss(tokenizer):
    """Build a Keras loss function that ignores padding positions.

    Args:
        tokenizer: object exposing ``pad_token_id``. Label positions equal to
            it are dropped before the loss is computed.

    Returns:
        ``custom_loss(labels, logits)``. It returns the unreduced per-token
        sparse categorical cross-entropy (computed from logits) over the
        non-pad positions only, as a 1-D tensor.
    """
    # Built once and reused: the loss object carries no per-call state.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )

    def custom_loss(labels, logits):
        # Keep only positions whose label is not the pad token. Assumes logits
        # are (batch, seq, vocab), since the vocab size is read from axis 2.
        flat_labels = tf.reshape(labels, (-1,))
        active = tf.not_equal(flat_labels, tokenizer.pad_token_id)
        vocab_size = shape_list(logits)[2]
        reduced_logits = tf.boolean_mask(
            tf.reshape(logits, (-1, vocab_size)), active
        )
        return loss_fn(tf.boolean_mask(flat_labels, active), reduced_logits)

    return custom_loss
107 |
108 |
class CustomAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
    """Sparse categorical accuracy that excludes padding positions.

    Args:
        *args, **kwargs: forwarded to ``SparseCategoricalAccuracy`` (e.g. the
            metric name).
        tokenizer: object exposing ``pad_token_id``. Label positions equal to
            it do not count towards the accuracy.
    """

    def __init__(self, *args, tokenizer=None, **kwargs):
        # Initialise the Keras metric/layer machinery before assigning
        # attributes on it.
        super().__init__(*args, **kwargs)
        self.tokenizer = tokenizer

    def update_state(self, labels, logits, sample_weight=None):
        """Accumulate accuracy over non-pad positions only.

        NOTE(review): ``sample_weight`` is forwarded unmasked. A per-token
        weight tensor would no longer line up with the masked labels.
        """
        flat_labels = tf.reshape(labels, (-1,))
        active = tf.not_equal(flat_labels, self.tokenizer.pad_token_id)
        reduced_logits = tf.boolean_mask(
            tf.reshape(logits, (-1, shape_list(logits)[2])), active
        )
        return super().update_state(
            tf.boolean_mask(flat_labels, active), reduced_logits, sample_weight
        )
122 |
123 |
def init_model(
    tokenizer: BertTokenizer,
    model_path: str = configs.model_path,
    learning_rate: float = 5e-5,
) -> TFGPT2LMHeadModel:
    """Load (or freshly create) a GPT-2 LM head model and compile it for training.

    First tries ``TFGPT2LMHeadModel.from_pretrained(model_path)``. If that
    raises ``EnvironmentError`` (presumably because nothing has been saved
    there yet), a new model is built from the ``configs.model`` hyper-parameters
    and the tokenizer's vocabulary size and pad id.

    Args:
        tokenizer: tokenizer providing ``vocab_size`` and ``pad_token_id``.
        model_path: directory to load a pretrained checkpoint from.
        learning_rate: Adam learning rate (the default equals the value that
            used to be hard-coded).

    Returns:
        The compiled model, using the pad-masking loss and accuracy.
    """
    try:
        model = TFGPT2LMHeadModel.from_pretrained(model_path, return_dict=False)
    except EnvironmentError:
        config = GPT2Config(
            architectures=["TFGPT2LMHeadModel"],
            model_type="TFGPT2LMHeadModel",
            # Record the tokenizer class actually in use, so tokenizer
            # auto-loading from the saved config does not pick a different
            # class than the one that produced the vocabulary ids.
            tokenizer_class=type(tokenizer).__name__,
            vocab_size=tokenizer.vocab_size,
            n_positions=configs.model.n_positions,
            n_ctx=configs.model.n_ctx,
            n_embd=configs.model.n_embd,
            n_layer=configs.model.n_layer,
            n_head=configs.model.n_head,
            pad_token_id=tokenizer.pad_token_id,
            task_specific_params={
                "text-generation": {"do_sample": True, "max_length": 120}
            },
            return_dict=False,
            output_attentions=False,
            output_hidden_states=False,
            use_cache=False,
        )
        model = TFGPT2LMHeadModel(config)

    loss = build_loss(tokenizer)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08
    )
    metric = CustomAccuracy("accuracy", tokenizer=tokenizer)

    model.compile(
        optimizer=optimizer,
        # Only the first output (the LM logits) gets a loss. The None entries
        # presumably cover one extra output per layer -- confirm against the
        # model's output structure.
        loss=[loss, *[None] * model.config.n_layer],
        metrics=[metric],
    )
    return model
165 |
166 |
def train(model: TFGPT2LMHeadModel, train_dataset, epochs: int, train_steps: int):
    """Fit ``model`` for ``epochs`` epochs of ``train_steps`` steps and print the wall time.

    Callbacks write TensorBoard logs to ``{configs.model_path}/logs`` every 50
    batches and call ``save_pretrained(configs.model_path)`` after every epoch.
    """

    class AutoSaveCallback(tf.keras.callbacks.Callback):
        # Checkpoint after each epoch so a later init_model() call can load it.
        def on_epoch_end(self, epoch, logs=None):
            self.model.save_pretrained(f"{configs.model_path}")

    tensorboard_cb = tf.keras.callbacks.TensorBoard(
        log_dir=f"{configs.model_path}/logs", update_freq=50
    )
    callbacks = [tensorboard_cb, AutoSaveCallback()]

    started_at = time.time()
    model.fit(
        train_dataset,
        epochs=epochs,
        steps_per_epoch=train_steps,
        callbacks=callbacks,
        batch_size=None,  # datasets from get_dataset() are already batched
    )
    elapsed = time.time() - started_at
    print(f"total train time {elapsed}")
190 |
191 |
@click.command()
@click.option('--epochs', default=4, help='number of epochs')
@click.option('--train_steps', default=500, help='number of train_steps')
def main(epochs, train_steps):
    # No docstring on purpose: click would turn it into the --help text.
    # Each dataset group gets a model (re)loaded from configs.model_path, so
    # training continues from whatever the previous group's epochs saved.
    tokenizer = load_tokenizer()
    for _, group_dataset in get_dataset():
        model = init_model(tokenizer, configs.model_path)
        train(model, group_dataset, epochs, train_steps)
        # Release the previous group's model and data before loading the next.
        del group_dataset, model
        gc.collect()
207 |
208 |
if __name__ == '__main__':
    # Build and train everything inside a MirroredStrategy scope so model
    # variables are created under the distribution strategy.
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        main()
213 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
17 | #
18 | # Licensed under the Apache License, Version 2.0 (the "License");
19 | # you may not use this file except in compliance with the License.
20 | # You may obtain a copy of the License at
21 | #
22 | # http://www.apache.org/licenses/LICENSE-2.0
23 | #
24 | # Unless required by applicable law or agreed to in writing, software
25 | # distributed under the License is distributed on an "AS IS" BASIS,
26 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27 | # See the License for the specific language governing permissions and
28 | # limitations under the License.
29 | # ==============================================================================
30 | """Keras-based einsum layer.
31 |
32 | Copied from
33 | https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/dense_einsum.py.
34 | """
35 | # pylint: disable=g-classes-have-attributes
36 |
37 | import tensorflow as tf
38 |
# Letters used as einsum subscripts. DenseEinsum draws free, summed and
# output dimensions from this list, so it caps their total at 13.
_CHR_IDX = list("abcdefghijklm")
40 |
41 |
@tf.keras.utils.register_keras_serializable(package="Text")
class DenseEinsum(tf.keras.layers.Layer):
    """A densely connected layer that uses tf.einsum as the backing computation.

    This layer can perform einsum calculations of arbitrary dimensionality
    (up to 13 free + summed + output dimensions in total).

    Arguments:
      output_shape: Positive integer or tuple, dimensionality of the output space.
      num_summed_dimensions: The number of dimensions to sum over. Standard 2D
        matmul should use 1, 3D matmul should use 2, and so forth.
      activation: Activation function to use. If you don't specify anything, no
        activation is applied (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation").
      kernel_constraint: Constraint function applied to the `kernel` weights
        matrix.
      bias_constraint: Constraint function applied to the bias vector.
    Input shape:
      N-D tensor with shape: `(batch_size, ..., input_dim)`. The most common
      situation would be a 2D input with shape `(batch_size, input_dim)`.
    Output shape:
      N-D tensor with shape: `(batch_size, ..., units)`. For instance, for a 2D
      input with shape `(batch_size, input_dim)`, the output would have shape
      `(batch_size, units)`.
    """

    def __init__(self,
                 output_shape,
                 num_summed_dimensions=1,
                 activation=None,
                 use_bias=True,
                 kernel_initializer="glorot_uniform",
                 bias_initializer="zeros",
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # The base Layer owns the activity regularizer: it stores it as
        # `self._activity_regularizer`, which get_config() serializes and Keras
        # applies to the layer output. It must be forwarded, otherwise the
        # argument is silently ignored.
        super(DenseEinsum, self).__init__(
            activity_regularizer=activity_regularizer, **kwargs)
        self._output_shape = output_shape if isinstance(
            output_shape, (list, tuple)) else (output_shape,)
        self._activation = tf.keras.activations.get(activation)
        self._use_bias = use_bias
        self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self._bias_initializer = tf.keras.initializers.get(bias_initializer)
        self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
        self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self._bias_constraint = tf.keras.constraints.get(bias_constraint)
        self._num_summed_dimensions = num_summed_dimensions
        self._einsum_string = None

    def _build_einsum_string(self, free_input_dims, bound_dims, output_dims):
        """Return e.g. "ab,bc->ac": free dims pass through, bound dims are summed.

        Raises:
            ValueError: if more dimensions are needed than `_CHR_IDX` has letters.
        """
        total_dims = free_input_dims + bound_dims + output_dims
        if total_dims > len(_CHR_IDX):
            raise ValueError(
                "DenseEinsum supports at most %d total dimensions, got %d" %
                (len(_CHR_IDX), total_dims))

        input_str = ""
        kernel_str = ""
        output_str = ""
        letter_offset = 0
        # Free input dimensions appear in the input and the output.
        for i in range(free_input_dims):
            char = _CHR_IDX[i + letter_offset]
            input_str += char
            output_str += char

        # Summed dimensions appear in the input and the kernel only.
        letter_offset += free_input_dims
        for i in range(bound_dims):
            char = _CHR_IDX[i + letter_offset]
            input_str += char
            kernel_str += char

        # Output dimensions appear in the kernel and the output.
        letter_offset += bound_dims
        for i in range(output_dims):
            char = _CHR_IDX[i + letter_offset]
            kernel_str += char
            output_str += char

        return input_str + "," + kernel_str + "->" + output_str

    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
        input_rank = input_shape.rank
        free_input_dims = input_rank - self._num_summed_dimensions
        output_dims = len(self._output_shape)

        self._einsum_string = self._build_einsum_string(free_input_dims,
                                                        self._num_summed_dimensions,
                                                        output_dims)

        # This is only saved for testing purposes.
        self._kernel_shape = (
            input_shape[free_input_dims:].concatenate(self._output_shape))

        self._kernel = self.add_weight(
            "kernel",
            shape=self._kernel_shape,
            initializer=self._kernel_initializer,
            regularizer=self._kernel_regularizer,
            constraint=self._kernel_constraint,
            dtype=self.dtype,
            trainable=True)
        if self._use_bias:
            self._bias = self.add_weight(
                "bias",
                shape=self._output_shape,
                initializer=self._bias_initializer,
                regularizer=self._bias_regularizer,
                constraint=self._bias_constraint,
                dtype=self.dtype,
                trainable=True)
        else:
            self._bias = None
        super(DenseEinsum, self).build(input_shape)

    def get_config(self):
        config = {
            "output_shape":
                self._output_shape,
            "num_summed_dimensions":
                self._num_summed_dimensions,
            "activation":
                tf.keras.activations.serialize(self._activation),
            "use_bias":
                self._use_bias,
            "kernel_initializer":
                tf.keras.initializers.serialize(self._kernel_initializer),
            "bias_initializer":
                tf.keras.initializers.serialize(self._bias_initializer),
            "kernel_regularizer":
                tf.keras.regularizers.serialize(self._kernel_regularizer),
            "bias_regularizer":
                tf.keras.regularizers.serialize(self._bias_regularizer),
            "activity_regularizer":
                tf.keras.regularizers.serialize(self._activity_regularizer),
            "kernel_constraint":
                tf.keras.constraints.serialize(self._kernel_constraint),
            "bias_constraint":
                tf.keras.constraints.serialize(self._bias_constraint)
        }
        base_config = super(DenseEinsum, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        ret = tf.einsum(self._einsum_string, inputs, self._kernel)
        if self._use_bias:
            ret += self._bias
        if self._activation is not None:
            ret = self._activation(ret)
        return ret
196 |
197 |
--------------------------------------------------------------------------------
/dataset/test/vocab.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | [UNK]
3 | [CLS]
4 | [SEP]
5 | [MASK]
6 | ?
7 | □
8 | 。
9 | 【
10 | 】
11 | 一
12 | 七
13 | 万
14 | 三
15 | 上
16 | 下
17 | 不
18 | 与
19 | 且
20 | 世
21 | 丘
22 | 业
23 | 丛
24 | 东
25 | 丝
26 | 丞
27 | 两
28 | 严
29 | 中
30 | 临
31 | 丹
32 | 为
33 | 主
34 | 举
35 | 乃
36 | 久
37 | 义
38 | 之
39 | 乌
40 | 乐
41 | 乘
42 | 九
43 | 也
44 | 乡
45 | 书
46 | 乱
47 | 乾
48 | 了
49 | 予
50 | 争
51 | 事
52 | 二
53 | 于
54 | 云
55 | 五
56 | 井
57 | 交
58 | 亦
59 | 享
60 | 京
61 | 亭
62 | 亲
63 | 人
64 | 仁
65 | 今
66 | 仍
67 | 从
68 | 他
69 | 仙
70 | 代
71 | 令
72 | 以
73 | 仪
74 | 仰
75 | 任
76 | 休
77 | 众
78 | 会
79 | 传
80 | 伤
81 | 伴
82 | 似
83 | 但
84 | 低
85 | 住
86 | 何
87 | 佛
88 | 作
89 | 佩
90 | 佳
91 | 使
92 | 侍
93 | 依
94 | 侣
95 | 侯
96 | 侵
97 | 便
98 | 俗
99 | 信
100 | 修
101 | 俱
102 | 倒
103 | 倚
104 | 倾
105 | 偈
106 | 偏
107 | 停
108 | 偶
109 | 傍
110 | 僧
111 | 儿
112 | 元
113 | 兄
114 | 先
115 | 光
116 | 入
117 | 全
118 | 八
119 | 公
120 | 六
121 | 兮
122 | 兰
123 | 共
124 | 关
125 | 兴
126 | 兵
127 | 其
128 | 兹
129 | 养
130 | 兼
131 | 内
132 | 再
133 | 军
134 | 冠
135 | 冥
136 | 冬
137 | 冰
138 | 况
139 | 冷
140 | 净
141 | 凉
142 | 凌
143 | 凝
144 | 几
145 | 凤
146 | 凭
147 | 出
148 | 刀
149 | 分
150 | 列
151 | 刘
152 | 则
153 | 初
154 | 利
155 | 别
156 | 到
157 | 制
158 | 前
159 | 剑
160 | 力
161 | 功
162 | 动
163 | 劳
164 | 势
165 | 勤
166 | 化
167 | 北
168 | 十
169 | 千
170 | 升
171 | 半
172 | 华
173 | 南
174 | 占
175 | 卢
176 | 卧
177 | 即
178 | 却
179 | 卷
180 | 卿
181 | 历
182 | 厌
183 | 原
184 | 去
185 | 县
186 | 参
187 | 又
188 | 及
189 | 友
190 | 双
191 | 发
192 | 取
193 | 受
194 | 变
195 | 口
196 | 古
197 | 句
198 | 只
199 | 可
200 | 台
201 | 史
202 | 叶
203 | 司
204 | 叹
205 | 各
206 | 合
207 | 同
208 | 名
209 | 后
210 | 吏
211 | 向
212 | 君
213 | 吟
214 | 含
215 | 听
216 | 启
217 | 吴
218 | 吹
219 | 吾
220 | 呈
221 | 周
222 | 呼
223 | 命
224 | 和
225 | 咏
226 | 咸
227 | 咽
228 | 哀
229 | 响
230 | 哭
231 | 唐
232 | 唯
233 | 啼
234 | 善
235 | 喜
236 | 喧
237 | 嗟
238 | 嘉
239 | 四
240 | 回
241 | 因
242 | 园
243 | 国
244 | 图
245 | 圆
246 | 土
247 | 圣
248 | 在
249 | 地
250 | 坐
251 | 坛
252 | 垂
253 | 城
254 | 堂
255 | 堪
256 | 塞
257 | 境
258 | 壁
259 | 士
260 | 声
261 | 处
262 | 备
263 | 复
264 | 夏
265 | 夕
266 | 外
267 | 多
268 | 夜
269 | 大
270 | 天
271 | 太
272 | 夫
273 | 失
274 | 头
275 | 夷
276 | 奇
277 | 奈
278 | 奉
279 | 奏
280 | 女
281 | 好
282 | 如
283 | 妇
284 | 妖
285 | 妾
286 | 始
287 | 威
288 | 子
289 | 字
290 | 存
291 | 孙
292 | 孝
293 | 孤
294 | 学
295 | 宁
296 | 宅
297 | 宇
298 | 守
299 | 安
300 | 宗
301 | 官
302 | 定
303 | 宜
304 | 宝
305 | 实
306 | 客
307 | 室
308 | 宫
309 | 宴
310 | 宵
311 | 家
312 | 容
313 | 宿
314 | 寂
315 | 寄
316 | 密
317 | 寒
318 | 寥
319 | 对
320 | 寺
321 | 寻
322 | 寿
323 | 封
324 | 将
325 | 尊
326 | 小
327 | 少
328 | 尔
329 | 尘
330 | 尚
331 | 就
332 | 尺
333 | 尽
334 | 居
335 | 屋
336 | 展
337 | 山
338 | 岁
339 | 岂
340 | 岛
341 | 岩
342 | 岭
343 | 岳
344 | 岸
345 | 峰
346 | 崇
347 | 崔
348 | 川
349 | 州
350 | 巢
351 | 差
352 | 已
353 | 巴
354 | 市
355 | 布
356 | 帆
357 | 师
358 | 帘
359 | 帝
360 | 带
361 | 席
362 | 常
363 | 干
364 | 平
365 | 年
366 | 并
367 | 幸
368 | 幽
369 | 广
370 | 庆
371 | 床
372 | 序
373 | 应
374 | 底
375 | 庙
376 | 府
377 | 废
378 | 度
379 | 庭
380 | 延
381 | 开
382 | 异
383 | 弃
384 | 引
385 | 弟
386 | 张
387 | 弦
388 | 弹
389 | 归
390 | 当
391 | 形
392 | 彩
393 | 影
394 | 彼
395 | 往
396 | 征
397 | 径
398 | 待
399 | 徒
400 | 得
401 | 御
402 | 微
403 | 德
404 | 心
405 | 必
406 | 忆
407 | 志
408 | 忘
409 | 忧
410 | 念
411 | 忽
412 | 怀
413 | 怅
414 | 怜
415 | 思
416 | 急
417 | 性
418 | 怨
419 | 怪
420 | 恐
421 | 恨
422 | 恩
423 | 息
424 | 恶
425 | 悟
426 | 悠
427 | 悬
428 | 悲
429 | 情
430 | 惊
431 | 惜
432 | 惟
433 | 惭
434 | 想
435 | 愁
436 | 意
437 | 愚
438 | 感
439 | 愿
440 | 戎
441 | 戏
442 | 成
443 | 我
444 | 战
445 | 户
446 | 房
447 | 所
448 | 手
449 | 才
450 | 扬
451 | 承
452 | 投
453 | 折
454 | 报
455 | 披
456 | 抱
457 | 拂
458 | 招
459 | 拜
460 | 拟
461 | 持
462 | 指
463 | 接
464 | 推
465 | 掩
466 | 摇
467 | 攀
468 | 收
469 | 改
470 | 放
471 | 故
472 | 教
473 | 敢
474 | 散
475 | 数
476 | 文
477 | 斋
478 | 斗
479 | 斜
480 | 断
481 | 斯
482 | 新
483 | 方
484 | 施
485 | 旋
486 | 旌
487 | 无
488 | 既
489 | 日
490 | 旧
491 | 早
492 | 时
493 | 昌
494 | 明
495 | 昏
496 | 易
497 | 昔
498 | 星
499 | 映
500 | 春
501 | 昨
502 | 昭
503 | 是
504 | 昼
505 | 晓
506 | 晚
507 | 晨
508 | 景
509 | 晴
510 | 智
511 | 暂
512 | 暑
513 | 暗
514 | 暮
515 | 曙
516 | 曲
517 | 更
518 | 曹
519 | 曾
520 | 最
521 | 月
522 | 有
523 | 服
524 | 望
525 | 朝
526 | 期
527 | 木
528 | 未
529 | 本
530 | 朱
531 | 机
532 | 杀
533 | 杂
534 | 李
535 | 村
536 | 条
537 | 来
538 | 杨
539 | 杯
540 | 松
541 | 极
542 | 枕
543 | 林
544 | 果
545 | 枝
546 | 枯
547 | 柏
548 | 柳
549 | 树
550 | 栖
551 | 校
552 | 根
553 | 桂
554 | 桃
555 | 桐
556 | 桑
557 | 桥
558 | 梁
559 | 梅
560 | 梦
561 | 楚
562 | 楼
563 | 横
564 | 檐
565 | 次
566 | 欢
567 | 欲
568 | 歌
569 | 止
570 | 正
571 | 此
572 | 步
573 | 武
574 | 死
575 | 殊
576 | 残
577 | 殷
578 | 殿
579 | 每
580 | 比
581 | 毛
582 | 氏
583 | 气
584 | 水
585 | 永
586 | 求
587 | 汉
588 | 江
589 | 池
590 | 沈
591 | 沙
592 | 没
593 | 沧
594 | 河
595 | 泉
596 | 法
597 | 泛
598 | 波
599 | 泥
600 | 泪
601 | 泽
602 | 洛
603 | 洞
604 | 洲
605 | 流
606 | 浅
607 | 济
608 | 浑
609 | 浦
610 | 浪
611 | 浮
612 | 海
613 | 消
614 | 涛
615 | 涯
616 | 深
617 | 添
618 | 清
619 | 渐
620 | 渔
621 | 渚
622 | 游
623 | 湖
624 | 湘
625 | 湿
626 | 源
627 | 溪
628 | 满
629 | 滴
630 | 漫
631 | 潜
632 | 潭
633 | 潮
634 | 澄
635 | 火
636 | 灭
637 | 灯
638 | 灵
639 | 炎
640 | 点
641 | 烛
642 | 烟
643 | 烦
644 | 烧
645 | 然
646 | 煌
647 | 照
648 | 燕
649 | 爱
650 | 片
651 | 牛
652 | 物
653 | 牵
654 | 犬
655 | 犹
656 | 狂
657 | 独
658 | 献
659 | 猿
660 | 玄
661 | 玉
662 | 王
663 | 珠
664 | 理
665 | 琴
666 | 琼
667 | 瑞
668 | 甘
669 | 生
670 | 用
671 | 田
672 | 由
673 | 画
674 | 界
675 | 留
676 | 疏
677 | 疑
678 | 病
679 | 登
680 | 白
681 | 百
682 | 皆
683 | 皇
684 | 盈
685 | 盘
686 | 盛
687 | 目
688 | 直
689 | 相
690 | 眉
691 | 看
692 | 真
693 | 眠
694 | 眼
695 | 瞻
696 | 知
697 | 短
698 | 石
699 | 破
700 | 碧
701 | 磬
702 | 示
703 | 礼
704 | 祀
705 | 祖
706 | 神
707 | 禁
708 | 禅
709 | 福
710 | 离
711 | 秀
712 | 秋
713 | 种
714 | 秦
715 | 积
716 | 称
717 | 移
718 | 稀
719 | 程
720 | 穆
721 | 穴
722 | 穷
723 | 空
724 | 穿
725 | 窗
726 | 立
727 | 竞
728 | 章
729 | 童
730 | 端
731 | 竹
732 | 笑
733 | 笔
734 | 第
735 | 笼
736 | 筵
737 | 管
738 | 篇
739 | 类
740 | 粉
741 | 精
742 | 系
743 | 素
744 | 紫
745 | 繁
746 | 红
747 | 纵
748 | 纷
749 | 细
750 | 终
751 | 经
752 | 结
753 | 绕
754 | 绝
755 | 绿
756 | 缘
757 | 罗
758 | 罢
759 | 羁
760 | 美
761 | 群
762 | 羽
763 | 翁
764 | 翠
765 | 翻
766 | 老
767 | 者
768 | 而
769 | 耶
770 | 聊
771 | 肃
772 | 肠
773 | 肯
774 | 胜
775 | 胡
776 | 能
777 | 臣
778 | 自
779 | 至
780 | 舍
781 | 舒
782 | 舞
783 | 舟
784 | 船
785 | 良
786 | 色
787 | 艳
788 | 节
789 | 花
790 | 芳
791 | 苍
792 | 苑
793 | 苔
794 | 若
795 | 苦
796 | 英
797 | 茅
798 | 茫
799 | 荆
800 | 草
801 | 荐
802 | 荒
803 | 荡
804 | 荣
805 | 药
806 | 荷
807 | 莫
808 | 莲
809 | 莺
810 | 营
811 | 萧
812 | 落
813 | 著
814 | 蓬
815 | 薄
816 | 藏
817 | 虎
818 | 虚
819 | 虫
820 | 虽
821 | 蛇
822 | 蜀
823 | 蝉
824 | 行
825 | 衣
826 | 表
827 | 衰
828 | 被
829 | 裴
830 | 西
831 | 要
832 | 见
833 | 观
834 | 觉
835 | 角
836 | 解
837 | 言
838 | 计
839 | 许
840 | 论
841 | 访
842 | 识
843 | 词
844 | 诗
845 | 诚
846 | 语
847 | 说
848 | 诸
849 | 谁
850 | 调
851 | 谢
852 | 谣
853 | 谷
854 | 象
855 | 贞
856 | 贤
857 | 贫
858 | 贵
859 | 赋
860 | 赏
861 | 赠
862 | 赤
863 | 走
864 | 赴
865 | 起
866 | 越
867 | 足
868 | 路
869 | 身
870 | 车
871 | 轩
872 | 转
873 | 轮
874 | 轻
875 | 载
876 | 辞
877 | 边
878 | 达
879 | 过
880 | 迎
881 | 运
882 | 近
883 | 还
884 | 进
885 | 远
886 | 连
887 | 迟
888 | 迢
889 | 迷
890 | 迹
891 | 送
892 | 逐
893 | 途
894 | 通
895 | 逢
896 | 逸
897 | 遂
898 | 遇
899 | 遍
900 | 道
901 | 遗
902 | 遣
903 | 遥
904 | 避
905 | 那
906 | 邻
907 | 郊
908 | 郎
909 | 郡
910 | 郭
911 | 都
912 | 酒
913 | 酬
914 | 醉
915 | 采
916 | 里
917 | 重
918 | 野
919 | 金
920 | 钓
921 | 钟
922 | 钱
923 | 销
924 | 锦
925 | 镜
926 | 长
927 | 门
928 | 问
929 | 闲
930 | 间
931 | 闻
932 | 阁
933 | 阙
934 | 阳
935 | 阴
936 | 阶
937 | 陆
938 | 陇
939 | 陈
940 | 降
941 | 限
942 | 院
943 | 除
944 | 陪
945 | 陵
946 | 随
947 | 隐
948 | 隔
949 | 难
950 | 雁
951 | 雄
952 | 集
953 | 雨
954 | 雪
955 | 雷
956 | 雾
957 | 霄
958 | 霜
959 | 霞
960 | 露
961 | 青
962 | 静
963 | 非
964 | 面
965 | 鞭
966 | 音
967 | 韵
968 | 顺
969 | 须
970 | 顾
971 | 频
972 | 题
973 | 颜
974 | 风
975 | 飘
976 | 飞
977 | 食
978 | 饮
979 | 馀
980 | 馆
981 | 首
982 | 香
983 | 马
984 | 驱
985 | 驾
986 | 驿
987 | 骑
988 | 骨
989 | 高
990 | 鬓
991 | 魂
992 | 鱼
993 | 鳞
994 | 鸟
995 | 鸡
996 | 鸣
997 | 鸿
998 | 鹤
999 | 黄
1000 | 鼓
1001 | 齐
1002 | 龙
1003 | ,
1004 | :
1005 | ?
1006 | ##□
1007 | □□
1008 | □□□
1009 |
--------------------------------------------------------------------------------
/fast_attention.py:
--------------------------------------------------------------------------------
1 |
2 | """Implementation of multiheaded FAVOR-attention & FAVOR-self-attention layers.
3 | Prefix Sum Tensorflow implementation by Valerii Likhosherstov.
4 | """
5 | import math
6 | import numpy as np
7 | import tensorflow as tf
8 | import util
9 |
# Large float constant (1e8).
BIG_CONSTANT = 100_000_000.0
11 |
12 |
def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False):
    r"""Constructs the matrix of random projections.

    Constructs a matrix of random orthogonal projections. Each projection vector
    has direction chosen uniformly at random and either deterministic length
    \sqrt{d} or length taken from the \chi(d) distribution (in the latter case
    marginal distributions of the projections are d-dimensional Gaussian vectors
    with associated identity covariance matrix).

    Args:
      m: number of random projections.
      d: dimensionality of each random projection.
      seed: random seed used to construct projections. Block i uses seed + i.
      scaling: 1 if all the random projections need to be renormalized to have
        length \sqrt{d}, 0 if the lengths of random projections should follow
        \chi(d) distribution.
      struct_mode: if True then products of Givens rotations will be used to
        construct random orthogonal matrix. This bypasses Gram-Schmidt
        orthogonalization.

    Returns:
      The matrix of random projections of the shape [m, d].

    Raises:
      ValueError: if `scaling` is not 0 or 1.
    """
    nb_full_blocks = int(m / d)
    block_list = []
    current_seed = seed
    for _ in range(nb_full_blocks):
        if struct_mode:
            # Per-block seed so each d x d block is a different rotation (the
            # QR branch below already advances its seed per block).
            q = create_products_of_givens_rotations(d, current_seed)
        else:
            unstructured_block = tf.random.normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q)
        current_seed += 1
    remaining_rows = m - nb_full_blocks * d
    if remaining_rows > 0:
        if struct_mode:
            q = create_products_of_givens_rotations(d, current_seed)
        else:
            unstructured_block = tf.random.normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        # Only the first `remaining_rows` rows of the last block are needed.
        block_list.append(q[0:remaining_rows])
    final_matrix = tf.experimental.numpy.vstack(block_list)
    current_seed += 1

    if scaling == 0:
        # Row lengths drawn from chi(d): norms of fresh Gaussian vectors.
        multiplier = tf.norm(tf.random.normal((m, d), seed=current_seed), axis=1)
    elif scaling == 1:
        multiplier = tf.math.sqrt(float(d)) * tf.ones((m))
    else:
        raise ValueError("Scaling must be one of {0, 1}. Was %s" % scaling)

    return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
65 |
66 |
def create_products_of_givens_rotations(dim, seed):
    r"""Constructs a 2D-tensor which is a product of Givens random rotations.

    Constructs a 2D-tensor of the form G_1 * ... * G_k, where G_i is a Givens
    random rotation. The resulting tensor mimics a matrix taken uniformly at
    random from the orthogonal group.

    Args:
      dim: number of rows/columns of the resulting 2D-tensor.
      seed: random seed. It seeds a private generator; NumPy's global random
        state is left untouched.

    Returns:
      The product of Givens random rotations, as a float32 tensor of shape
      [dim, dim].
    """
    nb_givens_rotations = dim * int(math.ceil(math.log(float(dim))))
    q = np.eye(dim, dim)
    # Private generator: same seeding as np.random.seed(seed), without
    # clobbering the global NumPy RNG for unrelated code.
    rng = np.random.RandomState(seed)
    for _ in range(nb_givens_rotations):
        random_angle = math.pi * rng.uniform()
        # The two indices must be distinct. If i == j, the second assignment
        # overwrites the row with a scaled copy of itself instead of rotating
        # two rows, which breaks orthogonality.
        random_indices = rng.choice(dim, 2, replace=False)
        index_i = min(random_indices[0], random_indices[1])
        index_j = max(random_indices[0], random_indices[1])
        slice_i = q[index_i]
        slice_j = q[index_j]
        cos_a = math.cos(random_angle)
        sin_a = math.sin(random_angle)
        new_slice_i = cos_a * slice_i + sin_a * slice_j
        new_slice_j = -sin_a * slice_i + cos_a * slice_j
        q[index_i] = new_slice_i
        q[index_j] = new_slice_j
    return tf.cast(tf.constant(q), dtype=tf.float32)
95 |
96 |
def relu_kernel_transformation(data,
                               is_query,
                               projection_matrix=None,
                               numerical_stabilizer=0.001):
    """Computes features for the ReLU-kernel.

    Computes random features for the ReLU kernel from
    https://arxiv.org/pdf/2009.14794.pdf.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - attention dimensions, H - heads, D - features.
      is_query: indicates whether input data is a query or key tensor. Unused:
        queries and keys share the same feature map here.
      projection_matrix: random Gaussian matrix of shape [M, D], where M stands
        for the number of random features and each D x D sub-block has pairwise
        orthogonal rows. If None, ReLU is applied to `data` directly.
      numerical_stabilizer: small positive constant for numerical stability.

    Returns:
      Corresponding kernel feature map.
    """
    del is_query
    if projection_matrix is None:
        features = data
    else:
        # Project onto the M random directions, scaled by 1/sqrt(M).
        ratio = 1.0 / tf.math.sqrt(
            tf.dtypes.cast(projection_matrix.shape[0], tf.float32))
        features = ratio * tf.einsum("blhd,md->blhm", data, projection_matrix)
    return tf.nn.relu(features) + numerical_stabilizer
123 |
124 |
def softmax_kernel_transformation(data,
                                  is_query,
                                  projection_matrix=None,
                                  numerical_stabilizer=0.000001):
    """Computes random features for the softmax kernel using FAVOR+ mechanism.

    Computes random features for the softmax kernel using FAVOR+ mechanism from
    https://arxiv.org/pdf/2009.14794.pdf.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - attention dimensions, H - heads, D - features.
      is_query: indicates whether input data is a query or key tensor. Queries
        are stabilized with a max over the feature axis only; keys with a max
        over the feature and attention (L) axes.
      projection_matrix: random Gaussian matrix of shape [M, D], where M stands
        for the number of random features and each D x D sub-block has pairwise
        orthogonal rows. Required despite the None default.
      numerical_stabilizer: small positive constant for numerical stability.

    Returns:
      Corresponding kernel feature map.

    Raises:
      ValueError: if `projection_matrix` is None.
    """
    if projection_matrix is None:
        raise ValueError(
            "softmax_kernel_transformation requires a projection_matrix "
            "(got None).")
    # Scale by D^{-1/4} on each side so that q.k is effectively scaled by 1/sqrt(D).
    data_normalizer = 1.0 / (
        tf.math.sqrt(tf.math.sqrt(tf.dtypes.cast(data.shape[-1], tf.float32))))
    data = data_normalizer * data
    ratio = 1.0 / tf.math.sqrt(
        tf.dtypes.cast(projection_matrix.shape[0], tf.float32))
    data_dash = tf.einsum("blhd,md->blhm", data, projection_matrix)
    # ||x||^2 / 2 term of the positive random-feature estimator.
    diag_data = tf.math.square(data)
    diag_data = tf.math.reduce_sum(
        diag_data, axis=tf.keras.backend.ndim(data) - 1)
    diag_data = diag_data / 2.0
    diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)
    last_dims_t = (len(data_dash.shape) - 1,)
    attention_dims_t = (len(data_dash.shape) - 3,)
    # Subtracting a max before exp avoids overflow.
    if is_query:
        data_dash = ratio * (
            tf.math.exp(data_dash - diag_data - tf.math.reduce_max(
                data_dash, axis=last_dims_t, keepdims=True)) + numerical_stabilizer)
    else:
        data_dash = ratio * (
            tf.math.exp(data_dash - diag_data - tf.math.reduce_max(
                data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) +
            numerical_stabilizer)

    return data_dash
167 |
168 |
def noncausal_numerator(qs, ks, vs):
    """Computes not-normalized FAVOR noncausal attention AV.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
      vs: value tensor of the shape [L,B,H,D].

    Returns:
      Not-normalized FAVOR noncausal attention AV, shape [L,B,H,D].
    """
    # Contract keys with values over L first (an M x D summary per batch/head),
    # so no L x L attention matrix is ever formed.
    key_value_summary = tf.einsum("lbhm,lbhd->bhmd", ks, vs)
    numerator = tf.einsum("lbhm,bhmd->lbhd", qs, key_value_summary)
    return numerator
180 |
181 |
def noncausal_denominator(qs, ks):
    """Computes FAVOR normalizer in noncausal attention.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].

    Returns:
      FAVOR normalizer in noncausal attention, shape [L,B,H].
    """
    # Sum keys over the length axis. reduce_sum (instead of an einsum against
    # tf.ones([L])) also works when L is not statically known, and keeps ks'
    # dtype rather than forcing float32.
    ks_sum = tf.reduce_sum(ks, axis=0)
    return tf.einsum("lbhm,bhm->lbh", qs, ks_sum)
193 |
194 |
def causal_attention_mask(nd, ns, dtype):
    """Return an [nd, ns] mask with 1's in the lower triangle, anchored at the lower-right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't
    produce garbage on TPUs.
    """
    rows = tf.range(nd)[:, None]
    cols = tf.range(ns)
    # Entry (r, c) is kept when c - r <= ns - nd.
    visible = rows >= cols - ns + nd
    return tf.cast(visible, dtype)
204 |
205 |
@tf.custom_gradient
def causal_numerator(qs, ks, vs):
    """Computes not-normalized FAVOR causal attention A_{masked}V.

    Uses a prefix sum over the length axis: output position t only sees keys
    and values at positions <= t. A hand-written backward pass walks the
    positions in reverse.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
      vs: value tensor of the shape [L,B,H,D].

    Returns:
      Not-normalized FAVOR causal attention A_{masked}V, plus the gradient
      function required by tf.custom_gradient.
    """
    seq_len = qs.shape[0]

    # prefix = sum_{s <= t} k_s (outer) v_s, per batch and head.
    prefix = tf.zeros_like(tf.einsum("ijk,ijl->ijkl", ks[0], vs[0]))
    outputs = []
    for t in range(seq_len):
        prefix = prefix + tf.einsum("ijk,ijl->ijkl", ks[t], vs[t])
        outputs.append(tf.einsum("ijkl,ijk->ijl", prefix, qs[t])[None, Ellipsis])
    result = tf.concat(outputs, axis=0)

    def grad(res_grad):
        # qg_suffix = sum_{s >= t} q_s (outer) g_s, built up while walking backwards.
        qg_suffix = tf.zeros_like(tf.einsum("ijk,ijl->ijkl", ks[0], vs[0]))
        # kv_prefix starts as the full prefix sum and is peeled back one step
        # per iteration, so at step t it equals sum_{s <= t} k_s (outer) v_s.
        kv_prefix = prefix

        q_parts = []
        k_parts = []
        v_parts = []
        for t in reversed(range(seq_len)):
            q_parts.append(
                tf.einsum("ijkl,ijl->ijk", kv_prefix, res_grad[t])[None, Ellipsis])
            qg_suffix = qg_suffix + tf.einsum("ijk,ijl->ijkl", qs[t], res_grad[t])
            k_parts.append(tf.einsum("ijkl,ijl->ijk", qg_suffix, vs[t])[None, Ellipsis])
            v_parts.append(tf.einsum("ijkl,ijk->ijl", qg_suffix, ks[t])[None, Ellipsis])
            kv_prefix = kv_prefix - tf.einsum("ijk,ijl->ijkl", ks[t], vs[t])

        # Parts were collected from the last position to the first; restore order.
        q_grads = tf.concat(q_parts[::-1], axis=0)
        k_grads = tf.concat(k_parts[::-1], axis=0)
        v_grads = tf.concat(v_parts[::-1], axis=0)

        return q_grads, k_grads, v_grads

    return result, grad
252 |
253 |
@tf.custom_gradient
def causal_denominator(qs, ks):
    """Computes FAVOR normalizer in causal attention.

    Output position t is the dot product, over the feature axis M, of
    ``qs[t]`` with ``K_t = sum_{s<=t} ks[s]``, a running prefix sum of the
    keys. The loop runs in Python over ``qs.shape[0]``, so the length axis
    must be statically known; the hand-written gradient walks the prefix
    sums backwards.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
    Returns:
      FAVOR normalizer in causal attention, shape [L,B,H]. (The undecorated
      body also returns ``grad``, as tf.custom_gradient requires.)
    """

    result = []
    # Running prefix sum K_t of the keys, shape [B,H,M].
    sums = tf.zeros_like(ks[0])

    for index in range(qs.shape[0]):
        sums = sums + ks[index]
        # Dot product over M -> [B,H]; [None] restores the L axis.
        result.append(tf.reduce_sum(qs[index] * sums, axis=2)[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):
        # res_grad: upstream gradient w.r.t. `result`, shape [L,B,H].

        # Suffix sum over t >= index of qs[t] * res_grad[t], shape [B,H,M].
        k_grad = tf.zeros_like(ks[0])

        # Starts at the final prefix sum K_{L-1}; stepped back each iteration
        # so it equals K_index when used below.
        gr_sums = sums

        q_grads = []
        k_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            # d result / d qs[index] = K_index scaled per (B,H) by res_grad[index].
            q_grads.append(
                tf.einsum("ijk,ij->ijk", gr_sums, res_grad[index])[None, Ellipsis])
            # ks[index] contributes to every output at t >= index.
            k_grad = k_grad + tf.einsum("ijk,ij->ijk", qs[index], res_grad[index])
            k_grads.append(k_grad[None, Ellipsis])
            # K_index -> K_{index-1}.
            gr_sums = gr_sums - ks[index]

        # Gradients were collected back-to-front; restore positional order.
        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)

        return q_grads, k_grads

    return result, grad
296 |
297 |
def favor_attention(query,
                    key,
                    value,
                    kernel_transformation,
                    causal,
                    projection_matrix=None):
    """Computes FAVOR normalized attention.

    Queries and keys are mapped through ``kernel_transformation`` into a
    feature space; the (causal or noncausal) numerator AV is then divided
    by the matching normalizer.

    Args:
      query: query tensor, [B,L,H,D] (per the layout comments below).
      key: key tensor.
      value: value tensor.
      kernel_transformation: called as ``kernel_transformation(x, is_query,
        projection_matrix)`` to get finite kernel features.
      causal: whether attention is causal or not.
      projection_matrix: projection matrix to be used.
    Returns:
      FAVOR normalized attention, batch-major ([B,L,H,D]).
    """
    length_first = [1, 0, 2, 3]

    # Kernel features, moved from [B,L,H,M] to length-major [L,B,H,M].
    q_features = tf.transpose(
        kernel_transformation(query, True, projection_matrix), length_first)
    k_features = tf.transpose(
        kernel_transformation(key, False, projection_matrix), length_first)
    v_lengthwise = tf.transpose(value, length_first)  # [L,B,H,D]

    if causal:
        numerator = causal_numerator(q_features, k_features, v_lengthwise)
        denominator = causal_denominator(q_features, k_features)
    else:
        numerator = noncausal_numerator(q_features, k_features, v_lengthwise)
        denominator = noncausal_denominator(q_features, k_features)

    # Back to batch-major: [L,B,H,D] -> [B,L,H,D] and [L,B,H] -> [B,L,H].
    numerator = tf.transpose(numerator, length_first)
    denominator = tf.transpose(denominator, [1, 0, 2])
    # Trailing singleton axis so the normalizer broadcasts over D.
    return numerator / tf.expand_dims(denominator, -1)
334 |
335 |
class Attention(tf.keras.layers.Layer):
    """Multi-headed attention layer using FAVOR attention.

    Queries, keys and values are projected per head with ``util.DenseEinsum``,
    combined with ``favor_attention`` and merged back to ``hidden_size`` by an
    output projection.
    """

    def __init__(self,
                 hidden_size,
                 num_heads,
                 attention_dropout,
                 kernel_transformation=relu_kernel_transformation,
                 numerical_stabilizer=0.001,
                 causal=False,
                 projection_matrix_type=None,
                 nb_random_features=0, **kwargs):
        """Initialize Attention.

        Args:
          hidden_size: int, output dim of hidden layer.
          num_heads: int, number of heads to repeat the same attention structure.
          attention_dropout: float, dropout rate inside attention for training.
            NOTE(review): stored and serialized, but not applied in ``call``.
          kernel_transformation: transformation used to produce kernel features
            for attention; invoked as ``kernel_transformation(x, is_query,
            projection_matrix)``.
          numerical_stabilizer: used to bound away from zero kernel values.
            NOTE(review): stored and serialized, but ``call`` does not pass it
            to ``kernel_transformation`` (presumably the kernel's own default
            applies) — confirm this is intended.
          causal: whether attention is causal or not.
          projection_matrix_type: None if Identity should be used, otherwise
            random projection matrix will be applied.
          nb_random_features: number of random features to be used (relevant
            only if projection_matrix_type is not None).
          **kwargs: forwarded to ``tf.keras.layers.Layer``.

        Raises:
          ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
        """
        if hidden_size % num_heads:
            raise ValueError(
                "Hidden size ({}) must be divisible by the number of heads ({})."
                .format(hidden_size, num_heads))

        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.kernel_transformation = kernel_transformation
        self.numerical_stabilizer = numerical_stabilizer
        self.causal = causal
        self.projection_matrix_type = projection_matrix_type
        self.nb_random_features = nb_random_features

        # Projection sub-layers are created directly in the constructor.
        size_per_head = self.hidden_size // self.num_heads

        def _glorot_initializer(fan_in, fan_out):
            # Glorot/Xavier uniform: U(-limit, limit),
            # limit = sqrt(6 / (fan_in + fan_out)).
            limit = math.sqrt(6.0 / (fan_in + fan_out))
            return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)

        # Q/K/V projections split heads as part of the projection:
        # [batch, length, hidden] -> [batch, length, num_heads, size_per_head].
        # The same initializer object is shared by all three.
        attention_initializer = _glorot_initializer(hidden_size, self.hidden_size)
        self.query_dense_layer = util.DenseEinsum(
            output_shape=(self.num_heads, size_per_head),
            kernel_initializer=attention_initializer,
            use_bias=True,
            name="query")
        self.key_dense_layer = util.DenseEinsum(
            output_shape=(self.num_heads, size_per_head),
            kernel_initializer=attention_initializer,
            use_bias=True,
            name="key")
        self.value_dense_layer = util.DenseEinsum(
            output_shape=(self.num_heads, size_per_head),
            kernel_initializer=attention_initializer,
            use_bias=True,
            name="value")

        # Output projection back to hidden_size; num_summed_dimensions=2
        # presumably contracts the (num_heads, size_per_head) axes — confirm
        # against util.DenseEinsum.
        output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
        self.output_dense_layer = util.DenseEinsum(
            output_shape=self.hidden_size,
            num_summed_dimensions=2,
            kernel_initializer=output_initializer,
            use_bias=True,
            name="output_transform")

    def get_config(self):
        """Returns the config used to re-create this layer via ``from_config``.

        Includes the base ``Layer`` entries (name, trainable, dtype, ...) and
        every plain-value constructor argument, so that e.g. a causal layer is
        rebuilt as causal. ``kernel_transformation`` is a Python callable and
        is not serialized; a layer rebuilt from this config uses the
        constructor default for it.
        """
        config = super().get_config()
        config.update({
            "hidden_size": self.hidden_size,
            "num_heads": self.num_heads,
            "attention_dropout": self.attention_dropout,
            "numerical_stabilizer": self.numerical_stabilizer,
            "causal": self.causal,
            "projection_matrix_type": self.projection_matrix_type,
            "nb_random_features": self.nb_random_features,
        })
        return config

    def call(self,
             query_input,
             source_input,
             bias=None,
             training=None,
             cache=None,
             decode_loop_step=None):
        """Apply attention mechanism to query_input and source_input.

        Args:
          query_input: A tensor with shape [batch_size, length_query, hidden_size].
          source_input: A tensor with shape [batch_size, length_source,
            hidden_size].
          bias: A tensor with shape [batch_size, 1, length_query, length_source],
            the attention bias that will be added to the result of the dot
            product. NOTE(review): accepted for interface compatibility but not
            used by this implementation.
          training: A bool, whether in training mode or not. NOTE(review): not
            used by this implementation (no dropout is applied).
          cache: (Used during prediction) A dictionary with tensors containing
            results of previous attentions. The dictionary must have the items:
                {"k": tensor with shape [batch_size, i, heads, dim_per_head],
                 "v": tensor with shape [batch_size, i, heads, dim_per_head]}
            where i is the current decoded length for non-padded decode, or max
            sequence length for padded decode. Updated in place with the
            combined keys/values.
          decode_loop_step: An integer, step number of the decoding loop. Used
            only for autoregressive inference on TPU.
        Returns:
          Attention layer output with shape [batch_size, length_query, hidden_size]
        """
        # Linearly project the query, key and value using different learned
        # projections. Splitting heads is automatically done during the linear
        # projections --> [batch_size, length, num_heads, dim_per_head].
        query = self.query_dense_layer(query_input)
        key = self.key_dense_layer(source_input)
        value = self.value_dense_layer(source_input)

        if self.projection_matrix_type is None:
            projection_matrix = None
        else:
            # Fixed seed: presumably the same random projection is regenerated
            # on every call (depends on create_projection_matrix).
            dim = query.shape[-1]
            seed = 0
            projection_matrix = create_projection_matrix(
                self.nb_random_features, dim, seed=seed)

        if cache is not None:
            # Combine cached keys and values with new keys and values.
            if decode_loop_step is not None:
                # Padded decode: the cache spans the full length; the new
                # key/value is added at slot `decode_loop_step` via a one-hot
                # mask (assumes that slot is still zero — confirm with caller).
                cache_k_shape = cache["k"].shape.as_list()
                indices = tf.reshape(
                    tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
                    [1, cache_k_shape[1], 1, 1])
                key = cache["k"] + key * indices
                cache_v_shape = cache["v"].shape.as_list()
                indices = tf.reshape(
                    tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
                    [1, cache_v_shape[1], 1, 1])
                value = cache["v"] + value * indices
            else:
                # Non-padded decode: append along the length axis.
                key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
                value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

            # Update cache
            cache["k"] = key
            cache["v"] = value

        attention_output = favor_attention(query, key, value,
                                           self.kernel_transformation, self.causal,
                                           projection_matrix)
        attention_output = self.output_dense_layer(attention_output)
        return attention_output
486 |
487 |
class SelfAttention(Attention):
    """Multiheaded self-attention layer.

    Thin wrapper over ``Attention`` that attends a sequence to itself by
    using ``query_input`` as both the query and the source input.
    """

    def call(self,
             query_input,
             bias=None,
             training=None,
             cache=None,
             decode_loop_step=None):
        """Applies self-attention; arguments are as in ``Attention.call``."""
        source_input = query_input
        return super().call(query_input,
                            source_input,
                            bias=bias,
                            training=training,
                            cache=cache,
                            decode_loop_step=decode_loop_step)
--------------------------------------------------------------------------------
/examples/gpt2_medium_chinese.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "ai_noval_demo.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "display_name": "Python 3",
12 | "name": "python3"
13 | },
14 | "widgets": {
15 | "application/vnd.jupyter.widget-state+json": {
16 | "64e7fa9092d64e759e3754ac8f87bb74": {
17 | "model_module": "@jupyter-widgets/controls",
18 | "model_name": "HBoxModel",
19 | "state": {
20 | "_view_name": "HBoxView",
21 | "_dom_classes": [],
22 | "_model_name": "HBoxModel",
23 | "_view_module": "@jupyter-widgets/controls",
24 | "_model_module_version": "1.5.0",
25 | "_view_count": null,
26 | "_view_module_version": "1.5.0",
27 | "box_style": "",
28 | "layout": "IPY_MODEL_13241f65231f4e3cb9726ec72a60a458",
29 | "_model_module": "@jupyter-widgets/controls",
30 | "children": [
31 | "IPY_MODEL_4bced51e28214e4da179bb977aba565a",
32 | "IPY_MODEL_756213dc534144d484a294aa6fa96719"
33 | ]
34 | }
35 | },
36 | "13241f65231f4e3cb9726ec72a60a458": {
37 | "model_module": "@jupyter-widgets/base",
38 | "model_name": "LayoutModel",
39 | "state": {
40 | "_view_name": "LayoutView",
41 | "grid_template_rows": null,
42 | "right": null,
43 | "justify_content": null,
44 | "_view_module": "@jupyter-widgets/base",
45 | "overflow": null,
46 | "_model_module_version": "1.2.0",
47 | "_view_count": null,
48 | "flex_flow": null,
49 | "width": null,
50 | "min_width": null,
51 | "border": null,
52 | "align_items": null,
53 | "bottom": null,
54 | "_model_module": "@jupyter-widgets/base",
55 | "top": null,
56 | "grid_column": null,
57 | "overflow_y": null,
58 | "overflow_x": null,
59 | "grid_auto_flow": null,
60 | "grid_area": null,
61 | "grid_template_columns": null,
62 | "flex": null,
63 | "_model_name": "LayoutModel",
64 | "justify_items": null,
65 | "grid_row": null,
66 | "max_height": null,
67 | "align_content": null,
68 | "visibility": null,
69 | "align_self": null,
70 | "height": null,
71 | "min_height": null,
72 | "padding": null,
73 | "grid_auto_rows": null,
74 | "grid_gap": null,
75 | "max_width": null,
76 | "order": null,
77 | "_view_module_version": "1.2.0",
78 | "grid_template_areas": null,
79 | "object_position": null,
80 | "object_fit": null,
81 | "grid_auto_columns": null,
82 | "margin": null,
83 | "display": null,
84 | "left": null
85 | }
86 | },
87 | "4bced51e28214e4da179bb977aba565a": {
88 | "model_module": "@jupyter-widgets/controls",
89 | "model_name": "FloatProgressModel",
90 | "state": {
91 | "_view_name": "ProgressView",
92 | "style": "IPY_MODEL_8c4f1cf0e91d4b43ab590c52f3b15a31",
93 | "_dom_classes": [],
94 | "description": "Downloading: 100%",
95 | "_model_name": "FloatProgressModel",
96 | "bar_style": "success",
97 | "max": 35058,
98 | "_view_module": "@jupyter-widgets/controls",
99 | "_model_module_version": "1.5.0",
100 | "value": 35058,
101 | "_view_count": null,
102 | "_view_module_version": "1.5.0",
103 | "orientation": "horizontal",
104 | "min": 0,
105 | "description_tooltip": null,
106 | "_model_module": "@jupyter-widgets/controls",
107 | "layout": "IPY_MODEL_82d0150d639b47adaf29fbd8935aa5a3"
108 | }
109 | },
110 | "756213dc534144d484a294aa6fa96719": {
111 | "model_module": "@jupyter-widgets/controls",
112 | "model_name": "HTMLModel",
113 | "state": {
114 | "_view_name": "HTMLView",
115 | "style": "IPY_MODEL_3bf33e08b650446282581f38856f7a9d",
116 | "_dom_classes": [],
117 | "description": "",
118 | "_model_name": "HTMLModel",
119 | "placeholder": "",
120 | "_view_module": "@jupyter-widgets/controls",
121 | "_model_module_version": "1.5.0",
122 | "value": " 35.1k/35.1k [00:01<00:00, 22.3kB/s]",
123 | "_view_count": null,
124 | "_view_module_version": "1.5.0",
125 | "description_tooltip": null,
126 | "_model_module": "@jupyter-widgets/controls",
127 | "layout": "IPY_MODEL_acf3afce64f1476380ad9bae2da3979b"
128 | }
129 | },
130 | "8c4f1cf0e91d4b43ab590c52f3b15a31": {
131 | "model_module": "@jupyter-widgets/controls",
132 | "model_name": "ProgressStyleModel",
133 | "state": {
134 | "_view_name": "StyleView",
135 | "_model_name": "ProgressStyleModel",
136 | "description_width": "initial",
137 | "_view_module": "@jupyter-widgets/base",
138 | "_model_module_version": "1.5.0",
139 | "_view_count": null,
140 | "_view_module_version": "1.2.0",
141 | "bar_color": null,
142 | "_model_module": "@jupyter-widgets/controls"
143 | }
144 | },
145 | "82d0150d639b47adaf29fbd8935aa5a3": {
146 | "model_module": "@jupyter-widgets/base",
147 | "model_name": "LayoutModel",
148 | "state": {
149 | "_view_name": "LayoutView",
150 | "grid_template_rows": null,
151 | "right": null,
152 | "justify_content": null,
153 | "_view_module": "@jupyter-widgets/base",
154 | "overflow": null,
155 | "_model_module_version": "1.2.0",
156 | "_view_count": null,
157 | "flex_flow": null,
158 | "width": null,
159 | "min_width": null,
160 | "border": null,
161 | "align_items": null,
162 | "bottom": null,
163 | "_model_module": "@jupyter-widgets/base",
164 | "top": null,
165 | "grid_column": null,
166 | "overflow_y": null,
167 | "overflow_x": null,
168 | "grid_auto_flow": null,
169 | "grid_area": null,
170 | "grid_template_columns": null,
171 | "flex": null,
172 | "_model_name": "LayoutModel",
173 | "justify_items": null,
174 | "grid_row": null,
175 | "max_height": null,
176 | "align_content": null,
177 | "visibility": null,
178 | "align_self": null,
179 | "height": null,
180 | "min_height": null,
181 | "padding": null,
182 | "grid_auto_rows": null,
183 | "grid_gap": null,
184 | "max_width": null,
185 | "order": null,
186 | "_view_module_version": "1.2.0",
187 | "grid_template_areas": null,
188 | "object_position": null,
189 | "object_fit": null,
190 | "grid_auto_columns": null,
191 | "margin": null,
192 | "display": null,
193 | "left": null
194 | }
195 | },
196 | "3bf33e08b650446282581f38856f7a9d": {
197 | "model_module": "@jupyter-widgets/controls",
198 | "model_name": "DescriptionStyleModel",
199 | "state": {
200 | "_view_name": "StyleView",
201 | "_model_name": "DescriptionStyleModel",
202 | "description_width": "",
203 | "_view_module": "@jupyter-widgets/base",
204 | "_model_module_version": "1.5.0",
205 | "_view_count": null,
206 | "_view_module_version": "1.2.0",
207 | "_model_module": "@jupyter-widgets/controls"
208 | }
209 | },
210 | "acf3afce64f1476380ad9bae2da3979b": {
211 | "model_module": "@jupyter-widgets/base",
212 | "model_name": "LayoutModel",
213 | "state": {
214 | "_view_name": "LayoutView",
215 | "grid_template_rows": null,
216 | "right": null,
217 | "justify_content": null,
218 | "_view_module": "@jupyter-widgets/base",
219 | "overflow": null,
220 | "_model_module_version": "1.2.0",
221 | "_view_count": null,
222 | "flex_flow": null,
223 | "width": null,
224 | "min_width": null,
225 | "border": null,
226 | "align_items": null,
227 | "bottom": null,
228 | "_model_module": "@jupyter-widgets/base",
229 | "top": null,
230 | "grid_column": null,
231 | "overflow_y": null,
232 | "overflow_x": null,
233 | "grid_auto_flow": null,
234 | "grid_area": null,
235 | "grid_template_columns": null,
236 | "flex": null,
237 | "_model_name": "LayoutModel",
238 | "justify_items": null,
239 | "grid_row": null,
240 | "max_height": null,
241 | "align_content": null,
242 | "visibility": null,
243 | "align_self": null,
244 | "height": null,
245 | "min_height": null,
246 | "padding": null,
247 | "grid_auto_rows": null,
248 | "grid_gap": null,
249 | "max_width": null,
250 | "order": null,
251 | "_view_module_version": "1.2.0",
252 | "grid_template_areas": null,
253 | "object_position": null,
254 | "object_fit": null,
255 | "grid_auto_columns": null,
256 | "margin": null,
257 | "display": null,
258 | "left": null
259 | }
260 | },
261 | "dc5e861e2018456dab1952ed4866fbb5": {
262 | "model_module": "@jupyter-widgets/controls",
263 | "model_name": "HBoxModel",
264 | "state": {
265 | "_view_name": "HBoxView",
266 | "_dom_classes": [],
267 | "_model_name": "HBoxModel",
268 | "_view_module": "@jupyter-widgets/controls",
269 | "_model_module_version": "1.5.0",
270 | "_view_count": null,
271 | "_view_module_version": "1.5.0",
272 | "box_style": "",
273 | "layout": "IPY_MODEL_f54bdf7663114aeeba19d659501f84da",
274 | "_model_module": "@jupyter-widgets/controls",
275 | "children": [
276 | "IPY_MODEL_591d323671ef40e98b86470090465ebe",
277 | "IPY_MODEL_9a71e13f98f84c84ac91e4c64f5cd949"
278 | ]
279 | }
280 | },
281 | "f54bdf7663114aeeba19d659501f84da": {
282 | "model_module": "@jupyter-widgets/base",
283 | "model_name": "LayoutModel",
284 | "state": {
285 | "_view_name": "LayoutView",
286 | "grid_template_rows": null,
287 | "right": null,
288 | "justify_content": null,
289 | "_view_module": "@jupyter-widgets/base",
290 | "overflow": null,
291 | "_model_module_version": "1.2.0",
292 | "_view_count": null,
293 | "flex_flow": null,
294 | "width": null,
295 | "min_width": null,
296 | "border": null,
297 | "align_items": null,
298 | "bottom": null,
299 | "_model_module": "@jupyter-widgets/base",
300 | "top": null,
301 | "grid_column": null,
302 | "overflow_y": null,
303 | "overflow_x": null,
304 | "grid_auto_flow": null,
305 | "grid_area": null,
306 | "grid_template_columns": null,
307 | "flex": null,
308 | "_model_name": "LayoutModel",
309 | "justify_items": null,
310 | "grid_row": null,
311 | "max_height": null,
312 | "align_content": null,
313 | "visibility": null,
314 | "align_self": null,
315 | "height": null,
316 | "min_height": null,
317 | "padding": null,
318 | "grid_auto_rows": null,
319 | "grid_gap": null,
320 | "max_width": null,
321 | "order": null,
322 | "_view_module_version": "1.2.0",
323 | "grid_template_areas": null,
324 | "object_position": null,
325 | "object_fit": null,
326 | "grid_auto_columns": null,
327 | "margin": null,
328 | "display": null,
329 | "left": null
330 | }
331 | },
332 | "591d323671ef40e98b86470090465ebe": {
333 | "model_module": "@jupyter-widgets/controls",
334 | "model_name": "FloatProgressModel",
335 | "state": {
336 | "_view_name": "ProgressView",
337 | "style": "IPY_MODEL_26d7ab47ef944127963bf23ceaa9db69",
338 | "_dom_classes": [],
339 | "description": "Downloading: 100%",
340 | "_model_name": "FloatProgressModel",
341 | "bar_style": "success",
342 | "max": 811,
343 | "_view_module": "@jupyter-widgets/controls",
344 | "_model_module_version": "1.5.0",
345 | "value": 811,
346 | "_view_count": null,
347 | "_view_module_version": "1.5.0",
348 | "orientation": "horizontal",
349 | "min": 0,
350 | "description_tooltip": null,
351 | "_model_module": "@jupyter-widgets/controls",
352 | "layout": "IPY_MODEL_30759c36349240229a60499721a458da"
353 | }
354 | },
355 | "9a71e13f98f84c84ac91e4c64f5cd949": {
356 | "model_module": "@jupyter-widgets/controls",
357 | "model_name": "HTMLModel",
358 | "state": {
359 | "_view_name": "HTMLView",
360 | "style": "IPY_MODEL_f911485dcd324e5e8064ca5198b13b49",
361 | "_dom_classes": [],
362 | "description": "",
363 | "_model_name": "HTMLModel",
364 | "placeholder": "",
365 | "_view_module": "@jupyter-widgets/controls",
366 | "_model_module_version": "1.5.0",
367 | "value": " 811/811 [00:44<00:00, 18.4B/s]",
368 | "_view_count": null,
369 | "_view_module_version": "1.5.0",
370 | "description_tooltip": null,
371 | "_model_module": "@jupyter-widgets/controls",
372 | "layout": "IPY_MODEL_f515c31b840141fdbbaf14c272cd1c8d"
373 | }
374 | },
375 | "26d7ab47ef944127963bf23ceaa9db69": {
376 | "model_module": "@jupyter-widgets/controls",
377 | "model_name": "ProgressStyleModel",
378 | "state": {
379 | "_view_name": "StyleView",
380 | "_model_name": "ProgressStyleModel",
381 | "description_width": "initial",
382 | "_view_module": "@jupyter-widgets/base",
383 | "_model_module_version": "1.5.0",
384 | "_view_count": null,
385 | "_view_module_version": "1.2.0",
386 | "bar_color": null,
387 | "_model_module": "@jupyter-widgets/controls"
388 | }
389 | },
390 | "30759c36349240229a60499721a458da": {
391 | "model_module": "@jupyter-widgets/base",
392 | "model_name": "LayoutModel",
393 | "state": {
394 | "_view_name": "LayoutView",
395 | "grid_template_rows": null,
396 | "right": null,
397 | "justify_content": null,
398 | "_view_module": "@jupyter-widgets/base",
399 | "overflow": null,
400 | "_model_module_version": "1.2.0",
401 | "_view_count": null,
402 | "flex_flow": null,
403 | "width": null,
404 | "min_width": null,
405 | "border": null,
406 | "align_items": null,
407 | "bottom": null,
408 | "_model_module": "@jupyter-widgets/base",
409 | "top": null,
410 | "grid_column": null,
411 | "overflow_y": null,
412 | "overflow_x": null,
413 | "grid_auto_flow": null,
414 | "grid_area": null,
415 | "grid_template_columns": null,
416 | "flex": null,
417 | "_model_name": "LayoutModel",
418 | "justify_items": null,
419 | "grid_row": null,
420 | "max_height": null,
421 | "align_content": null,
422 | "visibility": null,
423 | "align_self": null,
424 | "height": null,
425 | "min_height": null,
426 | "padding": null,
427 | "grid_auto_rows": null,
428 | "grid_gap": null,
429 | "max_width": null,
430 | "order": null,
431 | "_view_module_version": "1.2.0",
432 | "grid_template_areas": null,
433 | "object_position": null,
434 | "object_fit": null,
435 | "grid_auto_columns": null,
436 | "margin": null,
437 | "display": null,
438 | "left": null
439 | }
440 | },
441 | "f911485dcd324e5e8064ca5198b13b49": {
442 | "model_module": "@jupyter-widgets/controls",
443 | "model_name": "DescriptionStyleModel",
444 | "state": {
445 | "_view_name": "StyleView",
446 | "_model_name": "DescriptionStyleModel",
447 | "description_width": "",
448 | "_view_module": "@jupyter-widgets/base",
449 | "_model_module_version": "1.5.0",
450 | "_view_count": null,
451 | "_view_module_version": "1.2.0",
452 | "_model_module": "@jupyter-widgets/controls"
453 | }
454 | },
455 | "f515c31b840141fdbbaf14c272cd1c8d": {
456 | "model_module": "@jupyter-widgets/base",
457 | "model_name": "LayoutModel",
458 | "state": {
459 | "_view_name": "LayoutView",
460 | "grid_template_rows": null,
461 | "right": null,
462 | "justify_content": null,
463 | "_view_module": "@jupyter-widgets/base",
464 | "overflow": null,
465 | "_model_module_version": "1.2.0",
466 | "_view_count": null,
467 | "flex_flow": null,
468 | "width": null,
469 | "min_width": null,
470 | "border": null,
471 | "align_items": null,
472 | "bottom": null,
473 | "_model_module": "@jupyter-widgets/base",
474 | "top": null,
475 | "grid_column": null,
476 | "overflow_y": null,
477 | "overflow_x": null,
478 | "grid_auto_flow": null,
479 | "grid_area": null,
480 | "grid_template_columns": null,
481 | "flex": null,
482 | "_model_name": "LayoutModel",
483 | "justify_items": null,
484 | "grid_row": null,
485 | "max_height": null,
486 | "align_content": null,
487 | "visibility": null,
488 | "align_self": null,
489 | "height": null,
490 | "min_height": null,
491 | "padding": null,
492 | "grid_auto_rows": null,
493 | "grid_gap": null,
494 | "max_width": null,
495 | "order": null,
496 | "_view_module_version": "1.2.0",
497 | "grid_template_areas": null,
498 | "object_position": null,
499 | "object_fit": null,
500 | "grid_auto_columns": null,
501 | "margin": null,
502 | "display": null,
503 | "left": null
504 | }
505 | },
506 | "c70eb9c425b242d9afd22368720f85aa": {
507 | "model_module": "@jupyter-widgets/controls",
508 | "model_name": "HBoxModel",
509 | "state": {
510 | "_view_name": "HBoxView",
511 | "_dom_classes": [],
512 | "_model_name": "HBoxModel",
513 | "_view_module": "@jupyter-widgets/controls",
514 | "_model_module_version": "1.5.0",
515 | "_view_count": null,
516 | "_view_module_version": "1.5.0",
517 | "box_style": "",
518 | "layout": "IPY_MODEL_8bcd569c425d433daa6cc2f6b0ee5049",
519 | "_model_module": "@jupyter-widgets/controls",
520 | "children": [
521 | "IPY_MODEL_aaf7468202274d45b189bfca4be3e78d",
522 | "IPY_MODEL_2c84b9cac1634532b19a3cd3630ddc41"
523 | ]
524 | }
525 | },
526 | "8bcd569c425d433daa6cc2f6b0ee5049": {
527 | "model_module": "@jupyter-widgets/base",
528 | "model_name": "LayoutModel",
529 | "state": {
530 | "_view_name": "LayoutView",
531 | "grid_template_rows": null,
532 | "right": null,
533 | "justify_content": null,
534 | "_view_module": "@jupyter-widgets/base",
535 | "overflow": null,
536 | "_model_module_version": "1.2.0",
537 | "_view_count": null,
538 | "flex_flow": null,
539 | "width": null,
540 | "min_width": null,
541 | "border": null,
542 | "align_items": null,
543 | "bottom": null,
544 | "_model_module": "@jupyter-widgets/base",
545 | "top": null,
546 | "grid_column": null,
547 | "overflow_y": null,
548 | "overflow_x": null,
549 | "grid_auto_flow": null,
550 | "grid_area": null,
551 | "grid_template_columns": null,
552 | "flex": null,
553 | "_model_name": "LayoutModel",
554 | "justify_items": null,
555 | "grid_row": null,
556 | "max_height": null,
557 | "align_content": null,
558 | "visibility": null,
559 | "align_self": null,
560 | "height": null,
561 | "min_height": null,
562 | "padding": null,
563 | "grid_auto_rows": null,
564 | "grid_gap": null,
565 | "max_width": null,
566 | "order": null,
567 | "_view_module_version": "1.2.0",
568 | "grid_template_areas": null,
569 | "object_position": null,
570 | "object_fit": null,
571 | "grid_auto_columns": null,
572 | "margin": null,
573 | "display": null,
574 | "left": null
575 | }
576 | },
577 | "aaf7468202274d45b189bfca4be3e78d": {
578 | "model_module": "@jupyter-widgets/controls",
579 | "model_name": "FloatProgressModel",
580 | "state": {
581 | "_view_name": "ProgressView",
582 | "style": "IPY_MODEL_7b7388d66d8746ad89c170b537d75798",
583 | "_dom_classes": [],
584 | "description": "Downloading: 100%",
585 | "_model_name": "FloatProgressModel",
586 | "bar_style": "success",
587 | "max": 1246629736,
588 | "_view_module": "@jupyter-widgets/controls",
589 | "_model_module_version": "1.5.0",
590 | "value": 1246629736,
591 | "_view_count": null,
592 | "_view_module_version": "1.5.0",
593 | "orientation": "horizontal",
594 | "min": 0,
595 | "description_tooltip": null,
596 | "_model_module": "@jupyter-widgets/controls",
597 | "layout": "IPY_MODEL_81083dc746b74780b970871df7a806c5"
598 | }
599 | },
600 | "2c84b9cac1634532b19a3cd3630ddc41": {
601 | "model_module": "@jupyter-widgets/controls",
602 | "model_name": "HTMLModel",
603 | "state": {
604 | "_view_name": "HTMLView",
605 | "style": "IPY_MODEL_97fccff63dca4eab96e24397046242ef",
606 | "_dom_classes": [],
607 | "description": "",
608 | "_model_name": "HTMLModel",
609 | "placeholder": "",
610 | "_view_module": "@jupyter-widgets/controls",
611 | "_model_module_version": "1.5.0",
612 | "value": " 1.25G/1.25G [00:43<00:00, 28.9MB/s]",
613 | "_view_count": null,
614 | "_view_module_version": "1.5.0",
615 | "description_tooltip": null,
616 | "_model_module": "@jupyter-widgets/controls",
617 | "layout": "IPY_MODEL_b4136fcefb43464598940b995259d91b"
618 | }
619 | },
620 | "7b7388d66d8746ad89c170b537d75798": {
621 | "model_module": "@jupyter-widgets/controls",
622 | "model_name": "ProgressStyleModel",
623 | "state": {
624 | "_view_name": "StyleView",
625 | "_model_name": "ProgressStyleModel",
626 | "description_width": "initial",
627 | "_view_module": "@jupyter-widgets/base",
628 | "_model_module_version": "1.5.0",
629 | "_view_count": null,
630 | "_view_module_version": "1.2.0",
631 | "bar_color": null,
632 | "_model_module": "@jupyter-widgets/controls"
633 | }
634 | },
635 | "81083dc746b74780b970871df7a806c5": {
636 | "model_module": "@jupyter-widgets/base",
637 | "model_name": "LayoutModel",
638 | "state": {
639 | "_view_name": "LayoutView",
640 | "grid_template_rows": null,
641 | "right": null,
642 | "justify_content": null,
643 | "_view_module": "@jupyter-widgets/base",
644 | "overflow": null,
645 | "_model_module_version": "1.2.0",
646 | "_view_count": null,
647 | "flex_flow": null,
648 | "width": null,
649 | "min_width": null,
650 | "border": null,
651 | "align_items": null,
652 | "bottom": null,
653 | "_model_module": "@jupyter-widgets/base",
654 | "top": null,
655 | "grid_column": null,
656 | "overflow_y": null,
657 | "overflow_x": null,
658 | "grid_auto_flow": null,
659 | "grid_area": null,
660 | "grid_template_columns": null,
661 | "flex": null,
662 | "_model_name": "LayoutModel",
663 | "justify_items": null,
664 | "grid_row": null,
665 | "max_height": null,
666 | "align_content": null,
667 | "visibility": null,
668 | "align_self": null,
669 | "height": null,
670 | "min_height": null,
671 | "padding": null,
672 | "grid_auto_rows": null,
673 | "grid_gap": null,
674 | "max_width": null,
675 | "order": null,
676 | "_view_module_version": "1.2.0",
677 | "grid_template_areas": null,
678 | "object_position": null,
679 | "object_fit": null,
680 | "grid_auto_columns": null,
681 | "margin": null,
682 | "display": null,
683 | "left": null
684 | }
685 | },
686 | "97fccff63dca4eab96e24397046242ef": {
687 | "model_module": "@jupyter-widgets/controls",
688 | "model_name": "DescriptionStyleModel",
689 | "state": {
690 | "_view_name": "StyleView",
691 | "_model_name": "DescriptionStyleModel",
692 | "description_width": "",
693 | "_view_module": "@jupyter-widgets/base",
694 | "_model_module_version": "1.5.0",
695 | "_view_count": null,
696 | "_view_module_version": "1.2.0",
697 | "_model_module": "@jupyter-widgets/controls"
698 | }
699 | },
700 | "b4136fcefb43464598940b995259d91b": {
701 | "model_module": "@jupyter-widgets/base",
702 | "model_name": "LayoutModel",
703 | "state": {
704 | "_view_name": "LayoutView",
705 | "grid_template_rows": null,
706 | "right": null,
707 | "justify_content": null,
708 | "_view_module": "@jupyter-widgets/base",
709 | "overflow": null,
710 | "_model_module_version": "1.2.0",
711 | "_view_count": null,
712 | "flex_flow": null,
713 | "width": null,
714 | "min_width": null,
715 | "border": null,
716 | "align_items": null,
717 | "bottom": null,
718 | "_model_module": "@jupyter-widgets/base",
719 | "top": null,
720 | "grid_column": null,
721 | "overflow_y": null,
722 | "overflow_x": null,
723 | "grid_auto_flow": null,
724 | "grid_area": null,
725 | "grid_template_columns": null,
726 | "flex": null,
727 | "_model_name": "LayoutModel",
728 | "justify_items": null,
729 | "grid_row": null,
730 | "max_height": null,
731 | "align_content": null,
732 | "visibility": null,
733 | "align_self": null,
734 | "height": null,
735 | "min_height": null,
736 | "padding": null,
737 | "grid_auto_rows": null,
738 | "grid_gap": null,
739 | "max_width": null,
740 | "order": null,
741 | "_view_module_version": "1.2.0",
742 | "grid_template_areas": null,
743 | "object_position": null,
744 | "object_fit": null,
745 | "grid_auto_columns": null,
746 | "margin": null,
747 | "display": null,
748 | "left": null
749 | }
750 | }
751 | }
752 | }
753 | },
754 | "cells": [
755 | {
756 | "cell_type": "code",
757 | "metadata": {
758 | "id": "-irg8cbucTA6",
759 | "outputId": "3b8ec82c-0bfe-478c-a7db-eb8753b0b426",
760 | "colab": {
761 | "base_uri": "https://localhost:8080/"
762 | }
763 | },
764 | "source": [
765 | "!pip install transformers"
766 | ],
767 | "execution_count": null,
768 | "outputs": []
769 | },
770 | {
771 | "cell_type": "code",
772 | "metadata": {
773 | "id": "4ba0k0hkcNks",
774 | "outputId": "901340c2-04df-4c2b-c510-68c1ad1cb5b3",
775 | "colab": {
776 | "base_uri": "https://localhost:8080/",
777 | "height": 232,
778 | "referenced_widgets": [
779 | "64e7fa9092d64e759e3754ac8f87bb74",
780 | "13241f65231f4e3cb9726ec72a60a458",
781 | "4bced51e28214e4da179bb977aba565a",
782 | "756213dc534144d484a294aa6fa96719",
783 | "8c4f1cf0e91d4b43ab590c52f3b15a31",
784 | "82d0150d639b47adaf29fbd8935aa5a3",
785 | "3bf33e08b650446282581f38856f7a9d",
786 | "acf3afce64f1476380ad9bae2da3979b",
787 | "dc5e861e2018456dab1952ed4866fbb5",
788 | "f54bdf7663114aeeba19d659501f84da",
789 | "591d323671ef40e98b86470090465ebe",
790 | "9a71e13f98f84c84ac91e4c64f5cd949",
791 | "26d7ab47ef944127963bf23ceaa9db69",
792 | "30759c36349240229a60499721a458da",
793 | "f911485dcd324e5e8064ca5198b13b49",
794 | "f515c31b840141fdbbaf14c272cd1c8d",
795 | "c70eb9c425b242d9afd22368720f85aa",
796 | "8bcd569c425d433daa6cc2f6b0ee5049",
797 | "aaf7468202274d45b189bfca4be3e78d",
798 | "2c84b9cac1634532b19a3cd3630ddc41",
799 | "7b7388d66d8746ad89c170b537d75798",
800 | "81083dc746b74780b970871df7a806c5",
801 | "97fccff63dca4eab96e24397046242ef",
802 | "b4136fcefb43464598940b995259d91b"
803 | ]
804 | }
805 | },
806 | "source": [
807 | "from transformers import BertTokenizer, TFGPT2LMHeadModel\n",
808 | "\n",
809 | "tokenizer = BertTokenizer.from_pretrained(\"mymusise/gpt2-medium-chinese\")\n",
810 | "\n",
811 | "model = TFGPT2LMHeadModel.from_pretrained(\"mymusise/gpt2-medium-chinese\")"
812 | ],
813 | "execution_count": null,
814 | "outputs": []
815 | },
816 | {
817 | "cell_type": "code",
818 | "metadata": {
819 | "id": "AHfpYLsPdEgU"
820 | },
821 | "source": [
822 | "from transformers import TextGenerationPipeline\n",
823 | "\n",
824 | "text_generator = TextGenerationPipeline(model, tokenizer)\n",
825 | "print(text_generator(\"走向森林\", max_length=64, do_sample=True, repetition_penalty=1.3, top_k=10, eos_token_id=tokenizer.get_vocab().get(\"】\", 0)))\n",
826 | "print(text_generator(\"拿出一本秘籍\", max_length=64, do_sample=True, repetition_penalty=1.3, top_k=10, eos_token_id=tokenizer.get_vocab().get(\"】\", 0)))\n",
827 | "print(text_generator(\"跨越山丘\", max_length=64, do_sample=True, repetition_penalty=1.3, top_k=10, eos_token_id=tokenizer.get_vocab().get(\"】\", 0)))"
828 | ],
829 | "execution_count": null,
830 | "outputs": []
831 | }
832 | ]
833 | }
--------------------------------------------------------------------------------
/examples/mixed_precision_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "orig_nbformat": 2,
6 | "colab": {
7 | "name": "demo-fp16.ipynb",
8 | "provenance": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "accelerator": "GPU",
15 | "widgets": {
16 | "application/vnd.jupyter.widget-state+json": {
17 | "524dc030edf24fd294390f9da5a360e0": {
18 | "model_module": "@jupyter-widgets/controls",
19 | "model_name": "HBoxModel",
20 | "state": {
21 | "_view_name": "HBoxView",
22 | "_dom_classes": [],
23 | "_model_name": "HBoxModel",
24 | "_view_module": "@jupyter-widgets/controls",
25 | "_model_module_version": "1.5.0",
26 | "_view_count": null,
27 | "_view_module_version": "1.5.0",
28 | "box_style": "",
29 | "layout": "IPY_MODEL_1574aa76516e43b9b7772615c2a679f1",
30 | "_model_module": "@jupyter-widgets/controls",
31 | "children": [
32 | "IPY_MODEL_1f6934ca621f4bf4a87272798721ff8d",
33 | "IPY_MODEL_c0b1c0b0c4d8424bbca3471347c7fdaa"
34 | ]
35 | }
36 | },
37 | "1574aa76516e43b9b7772615c2a679f1": {
38 | "model_module": "@jupyter-widgets/base",
39 | "model_name": "LayoutModel",
40 | "state": {
41 | "_view_name": "LayoutView",
42 | "grid_template_rows": null,
43 | "right": null,
44 | "justify_content": null,
45 | "_view_module": "@jupyter-widgets/base",
46 | "overflow": null,
47 | "_model_module_version": "1.2.0",
48 | "_view_count": null,
49 | "flex_flow": null,
50 | "width": null,
51 | "min_width": null,
52 | "border": null,
53 | "align_items": null,
54 | "bottom": null,
55 | "_model_module": "@jupyter-widgets/base",
56 | "top": null,
57 | "grid_column": null,
58 | "overflow_y": null,
59 | "overflow_x": null,
60 | "grid_auto_flow": null,
61 | "grid_area": null,
62 | "grid_template_columns": null,
63 | "flex": null,
64 | "_model_name": "LayoutModel",
65 | "justify_items": null,
66 | "grid_row": null,
67 | "max_height": null,
68 | "align_content": null,
69 | "visibility": null,
70 | "align_self": null,
71 | "height": null,
72 | "min_height": null,
73 | "padding": null,
74 | "grid_auto_rows": null,
75 | "grid_gap": null,
76 | "max_width": null,
77 | "order": null,
78 | "_view_module_version": "1.2.0",
79 | "grid_template_areas": null,
80 | "object_position": null,
81 | "object_fit": null,
82 | "grid_auto_columns": null,
83 | "margin": null,
84 | "display": null,
85 | "left": null
86 | }
87 | },
88 | "1f6934ca621f4bf4a87272798721ff8d": {
89 | "model_module": "@jupyter-widgets/controls",
90 | "model_name": "FloatProgressModel",
91 | "state": {
92 | "_view_name": "ProgressView",
93 | "style": "IPY_MODEL_ff848af2f9e249919192db4f81210802",
94 | "_dom_classes": [],
95 | "description": "Downloading: 100%",
96 | "_model_name": "FloatProgressModel",
97 | "bar_style": "success",
98 | "max": 1042301,
99 | "_view_module": "@jupyter-widgets/controls",
100 | "_model_module_version": "1.5.0",
101 | "value": 1042301,
102 | "_view_count": null,
103 | "_view_module_version": "1.5.0",
104 | "orientation": "horizontal",
105 | "min": 0,
106 | "description_tooltip": null,
107 | "_model_module": "@jupyter-widgets/controls",
108 | "layout": "IPY_MODEL_a49942edd1514293a65af568a3d8227c"
109 | }
110 | },
111 | "c0b1c0b0c4d8424bbca3471347c7fdaa": {
112 | "model_module": "@jupyter-widgets/controls",
113 | "model_name": "HTMLModel",
114 | "state": {
115 | "_view_name": "HTMLView",
116 | "style": "IPY_MODEL_58cd1ceac3da46e79928efd3b6af8a3d",
117 | "_dom_classes": [],
118 | "description": "",
119 | "_model_name": "HTMLModel",
120 | "placeholder": "",
121 | "_view_module": "@jupyter-widgets/controls",
122 | "_model_module_version": "1.5.0",
123 | "value": " 1.04M/1.04M [00:00<00:00, 9.13MB/s]",
124 | "_view_count": null,
125 | "_view_module_version": "1.5.0",
126 | "description_tooltip": null,
127 | "_model_module": "@jupyter-widgets/controls",
128 | "layout": "IPY_MODEL_0c1f67bac76848598762a0da1fff0032"
129 | }
130 | },
131 | "ff848af2f9e249919192db4f81210802": {
132 | "model_module": "@jupyter-widgets/controls",
133 | "model_name": "ProgressStyleModel",
134 | "state": {
135 | "_view_name": "StyleView",
136 | "_model_name": "ProgressStyleModel",
137 | "description_width": "initial",
138 | "_view_module": "@jupyter-widgets/base",
139 | "_model_module_version": "1.5.0",
140 | "_view_count": null,
141 | "_view_module_version": "1.2.0",
142 | "bar_color": null,
143 | "_model_module": "@jupyter-widgets/controls"
144 | }
145 | },
146 | "a49942edd1514293a65af568a3d8227c": {
147 | "model_module": "@jupyter-widgets/base",
148 | "model_name": "LayoutModel",
149 | "state": {
150 | "_view_name": "LayoutView",
151 | "grid_template_rows": null,
152 | "right": null,
153 | "justify_content": null,
154 | "_view_module": "@jupyter-widgets/base",
155 | "overflow": null,
156 | "_model_module_version": "1.2.0",
157 | "_view_count": null,
158 | "flex_flow": null,
159 | "width": null,
160 | "min_width": null,
161 | "border": null,
162 | "align_items": null,
163 | "bottom": null,
164 | "_model_module": "@jupyter-widgets/base",
165 | "top": null,
166 | "grid_column": null,
167 | "overflow_y": null,
168 | "overflow_x": null,
169 | "grid_auto_flow": null,
170 | "grid_area": null,
171 | "grid_template_columns": null,
172 | "flex": null,
173 | "_model_name": "LayoutModel",
174 | "justify_items": null,
175 | "grid_row": null,
176 | "max_height": null,
177 | "align_content": null,
178 | "visibility": null,
179 | "align_self": null,
180 | "height": null,
181 | "min_height": null,
182 | "padding": null,
183 | "grid_auto_rows": null,
184 | "grid_gap": null,
185 | "max_width": null,
186 | "order": null,
187 | "_view_module_version": "1.2.0",
188 | "grid_template_areas": null,
189 | "object_position": null,
190 | "object_fit": null,
191 | "grid_auto_columns": null,
192 | "margin": null,
193 | "display": null,
194 | "left": null
195 | }
196 | },
197 | "58cd1ceac3da46e79928efd3b6af8a3d": {
198 | "model_module": "@jupyter-widgets/controls",
199 | "model_name": "DescriptionStyleModel",
200 | "state": {
201 | "_view_name": "StyleView",
202 | "_model_name": "DescriptionStyleModel",
203 | "description_width": "",
204 | "_view_module": "@jupyter-widgets/base",
205 | "_model_module_version": "1.5.0",
206 | "_view_count": null,
207 | "_view_module_version": "1.2.0",
208 | "_model_module": "@jupyter-widgets/controls"
209 | }
210 | },
211 | "0c1f67bac76848598762a0da1fff0032": {
212 | "model_module": "@jupyter-widgets/base",
213 | "model_name": "LayoutModel",
214 | "state": {
215 | "_view_name": "LayoutView",
216 | "grid_template_rows": null,
217 | "right": null,
218 | "justify_content": null,
219 | "_view_module": "@jupyter-widgets/base",
220 | "overflow": null,
221 | "_model_module_version": "1.2.0",
222 | "_view_count": null,
223 | "flex_flow": null,
224 | "width": null,
225 | "min_width": null,
226 | "border": null,
227 | "align_items": null,
228 | "bottom": null,
229 | "_model_module": "@jupyter-widgets/base",
230 | "top": null,
231 | "grid_column": null,
232 | "overflow_y": null,
233 | "overflow_x": null,
234 | "grid_auto_flow": null,
235 | "grid_area": null,
236 | "grid_template_columns": null,
237 | "flex": null,
238 | "_model_name": "LayoutModel",
239 | "justify_items": null,
240 | "grid_row": null,
241 | "max_height": null,
242 | "align_content": null,
243 | "visibility": null,
244 | "align_self": null,
245 | "height": null,
246 | "min_height": null,
247 | "padding": null,
248 | "grid_auto_rows": null,
249 | "grid_gap": null,
250 | "max_width": null,
251 | "order": null,
252 | "_view_module_version": "1.2.0",
253 | "grid_template_areas": null,
254 | "object_position": null,
255 | "object_fit": null,
256 | "grid_auto_columns": null,
257 | "margin": null,
258 | "display": null,
259 | "left": null
260 | }
261 | },
262 | "7010786313364b3f80713de08b571528": {
263 | "model_module": "@jupyter-widgets/controls",
264 | "model_name": "HBoxModel",
265 | "state": {
266 | "_view_name": "HBoxView",
267 | "_dom_classes": [],
268 | "_model_name": "HBoxModel",
269 | "_view_module": "@jupyter-widgets/controls",
270 | "_model_module_version": "1.5.0",
271 | "_view_count": null,
272 | "_view_module_version": "1.5.0",
273 | "box_style": "",
274 | "layout": "IPY_MODEL_8276e2dff64f4a0aaf68a6f7ad9a573b",
275 | "_model_module": "@jupyter-widgets/controls",
276 | "children": [
277 | "IPY_MODEL_04a2a76ac10545c9801420c353dd33d8",
278 | "IPY_MODEL_b74984ac0551484c95a0c6381f664e88"
279 | ]
280 | }
281 | },
282 | "8276e2dff64f4a0aaf68a6f7ad9a573b": {
283 | "model_module": "@jupyter-widgets/base",
284 | "model_name": "LayoutModel",
285 | "state": {
286 | "_view_name": "LayoutView",
287 | "grid_template_rows": null,
288 | "right": null,
289 | "justify_content": null,
290 | "_view_module": "@jupyter-widgets/base",
291 | "overflow": null,
292 | "_model_module_version": "1.2.0",
293 | "_view_count": null,
294 | "flex_flow": null,
295 | "width": null,
296 | "min_width": null,
297 | "border": null,
298 | "align_items": null,
299 | "bottom": null,
300 | "_model_module": "@jupyter-widgets/base",
301 | "top": null,
302 | "grid_column": null,
303 | "overflow_y": null,
304 | "overflow_x": null,
305 | "grid_auto_flow": null,
306 | "grid_area": null,
307 | "grid_template_columns": null,
308 | "flex": null,
309 | "_model_name": "LayoutModel",
310 | "justify_items": null,
311 | "grid_row": null,
312 | "max_height": null,
313 | "align_content": null,
314 | "visibility": null,
315 | "align_self": null,
316 | "height": null,
317 | "min_height": null,
318 | "padding": null,
319 | "grid_auto_rows": null,
320 | "grid_gap": null,
321 | "max_width": null,
322 | "order": null,
323 | "_view_module_version": "1.2.0",
324 | "grid_template_areas": null,
325 | "object_position": null,
326 | "object_fit": null,
327 | "grid_auto_columns": null,
328 | "margin": null,
329 | "display": null,
330 | "left": null
331 | }
332 | },
333 | "04a2a76ac10545c9801420c353dd33d8": {
334 | "model_module": "@jupyter-widgets/controls",
335 | "model_name": "FloatProgressModel",
336 | "state": {
337 | "_view_name": "ProgressView",
338 | "style": "IPY_MODEL_9faaade8dbdd4115b4b5b0b354d2b657",
339 | "_dom_classes": [],
340 | "description": "Downloading: 100%",
341 | "_model_name": "FloatProgressModel",
342 | "bar_style": "success",
343 | "max": 456318,
344 | "_view_module": "@jupyter-widgets/controls",
345 | "_model_module_version": "1.5.0",
346 | "value": 456318,
347 | "_view_count": null,
348 | "_view_module_version": "1.5.0",
349 | "orientation": "horizontal",
350 | "min": 0,
351 | "description_tooltip": null,
352 | "_model_module": "@jupyter-widgets/controls",
353 | "layout": "IPY_MODEL_2c00d7cf2a9f4c389a045cffe8d3430d"
354 | }
355 | },
356 | "b74984ac0551484c95a0c6381f664e88": {
357 | "model_module": "@jupyter-widgets/controls",
358 | "model_name": "HTMLModel",
359 | "state": {
360 | "_view_name": "HTMLView",
361 | "style": "IPY_MODEL_ad07e052ea2947219dadb7a1ee738b0f",
362 | "_dom_classes": [],
363 | "description": "",
364 | "_model_name": "HTMLModel",
365 | "placeholder": "",
366 | "_view_module": "@jupyter-widgets/controls",
367 | "_model_module_version": "1.5.0",
368 | "value": " 456k/456k [00:00<00:00, 4.23MB/s]",
369 | "_view_count": null,
370 | "_view_module_version": "1.5.0",
371 | "description_tooltip": null,
372 | "_model_module": "@jupyter-widgets/controls",
373 | "layout": "IPY_MODEL_7cc3926286d54a418f42f97967f7325c"
374 | }
375 | },
376 | "9faaade8dbdd4115b4b5b0b354d2b657": {
377 | "model_module": "@jupyter-widgets/controls",
378 | "model_name": "ProgressStyleModel",
379 | "state": {
380 | "_view_name": "StyleView",
381 | "_model_name": "ProgressStyleModel",
382 | "description_width": "initial",
383 | "_view_module": "@jupyter-widgets/base",
384 | "_model_module_version": "1.5.0",
385 | "_view_count": null,
386 | "_view_module_version": "1.2.0",
387 | "bar_color": null,
388 | "_model_module": "@jupyter-widgets/controls"
389 | }
390 | },
391 | "2c00d7cf2a9f4c389a045cffe8d3430d": {
392 | "model_module": "@jupyter-widgets/base",
393 | "model_name": "LayoutModel",
394 | "state": {
395 | "_view_name": "LayoutView",
396 | "grid_template_rows": null,
397 | "right": null,
398 | "justify_content": null,
399 | "_view_module": "@jupyter-widgets/base",
400 | "overflow": null,
401 | "_model_module_version": "1.2.0",
402 | "_view_count": null,
403 | "flex_flow": null,
404 | "width": null,
405 | "min_width": null,
406 | "border": null,
407 | "align_items": null,
408 | "bottom": null,
409 | "_model_module": "@jupyter-widgets/base",
410 | "top": null,
411 | "grid_column": null,
412 | "overflow_y": null,
413 | "overflow_x": null,
414 | "grid_auto_flow": null,
415 | "grid_area": null,
416 | "grid_template_columns": null,
417 | "flex": null,
418 | "_model_name": "LayoutModel",
419 | "justify_items": null,
420 | "grid_row": null,
421 | "max_height": null,
422 | "align_content": null,
423 | "visibility": null,
424 | "align_self": null,
425 | "height": null,
426 | "min_height": null,
427 | "padding": null,
428 | "grid_auto_rows": null,
429 | "grid_gap": null,
430 | "max_width": null,
431 | "order": null,
432 | "_view_module_version": "1.2.0",
433 | "grid_template_areas": null,
434 | "object_position": null,
435 | "object_fit": null,
436 | "grid_auto_columns": null,
437 | "margin": null,
438 | "display": null,
439 | "left": null
440 | }
441 | },
442 | "ad07e052ea2947219dadb7a1ee738b0f": {
443 | "model_module": "@jupyter-widgets/controls",
444 | "model_name": "DescriptionStyleModel",
445 | "state": {
446 | "_view_name": "StyleView",
447 | "_model_name": "DescriptionStyleModel",
448 | "description_width": "",
449 | "_view_module": "@jupyter-widgets/base",
450 | "_model_module_version": "1.5.0",
451 | "_view_count": null,
452 | "_view_module_version": "1.2.0",
453 | "_model_module": "@jupyter-widgets/controls"
454 | }
455 | },
456 | "7cc3926286d54a418f42f97967f7325c": {
457 | "model_module": "@jupyter-widgets/base",
458 | "model_name": "LayoutModel",
459 | "state": {
460 | "_view_name": "LayoutView",
461 | "grid_template_rows": null,
462 | "right": null,
463 | "justify_content": null,
464 | "_view_module": "@jupyter-widgets/base",
465 | "overflow": null,
466 | "_model_module_version": "1.2.0",
467 | "_view_count": null,
468 | "flex_flow": null,
469 | "width": null,
470 | "min_width": null,
471 | "border": null,
472 | "align_items": null,
473 | "bottom": null,
474 | "_model_module": "@jupyter-widgets/base",
475 | "top": null,
476 | "grid_column": null,
477 | "overflow_y": null,
478 | "overflow_x": null,
479 | "grid_auto_flow": null,
480 | "grid_area": null,
481 | "grid_template_columns": null,
482 | "flex": null,
483 | "_model_name": "LayoutModel",
484 | "justify_items": null,
485 | "grid_row": null,
486 | "max_height": null,
487 | "align_content": null,
488 | "visibility": null,
489 | "align_self": null,
490 | "height": null,
491 | "min_height": null,
492 | "padding": null,
493 | "grid_auto_rows": null,
494 | "grid_gap": null,
495 | "max_width": null,
496 | "order": null,
497 | "_view_module_version": "1.2.0",
498 | "grid_template_areas": null,
499 | "object_position": null,
500 | "object_fit": null,
501 | "grid_auto_columns": null,
502 | "margin": null,
503 | "display": null,
504 | "left": null
505 | }
506 | }
507 | }
508 | }
509 | },
510 | "cells": [
511 | {
512 | "cell_type": "code",
513 | "metadata": {
514 | "id": "i5bGic4pR5Pm"
515 | },
516 | "source": [
517 | "!pip install transformers==4.3"
518 | ],
519 | "execution_count": null,
520 | "outputs": []
521 | },
522 | {
523 | "cell_type": "code",
524 | "metadata": {
525 | "colab": {
526 | "base_uri": "https://localhost:8080/",
527 | "height": 132,
528 | "referenced_widgets": [
529 | "524dc030edf24fd294390f9da5a360e0",
530 | "1574aa76516e43b9b7772615c2a679f1",
531 | "1f6934ca621f4bf4a87272798721ff8d",
532 | "c0b1c0b0c4d8424bbca3471347c7fdaa",
533 | "ff848af2f9e249919192db4f81210802",
534 | "a49942edd1514293a65af568a3d8227c",
535 | "58cd1ceac3da46e79928efd3b6af8a3d",
536 | "0c1f67bac76848598762a0da1fff0032",
537 | "7010786313364b3f80713de08b571528",
538 | "8276e2dff64f4a0aaf68a6f7ad9a573b",
539 | "04a2a76ac10545c9801420c353dd33d8",
540 | "b74984ac0551484c95a0c6381f664e88",
541 | "9faaade8dbdd4115b4b5b0b354d2b657",
542 | "2c00d7cf2a9f4c389a045cffe8d3430d",
543 | "ad07e052ea2947219dadb7a1ee738b0f",
544 | "7cc3926286d54a418f42f97967f7325c"
545 | ]
546 | },
547 | "id": "MoFW4TwYz5QN",
548 | "outputId": "2498de45-ec66-4c90-ff2f-975c411d93f0"
549 | },
550 | "source": [
551 | "import tensorflow as tf\n",
552 | "from transformers import GPT2Tokenizer\n",
553 | "\n",
554 | "tokenizer = GPT2Tokenizer.from_pretrained(\"distilgpt2\")\n",
555 | "\n",
556 | "text = \"\"\"\n",
557 | "A SQUAT grey building of only thirty-four stories. Over the main entrance the\n",
558 | "words, CENTRAL LONDON HATCHERY AND CONDITIONING CENTRE,\n",
559 | "and, in a shield, the World State’s motto, COMMUNITY, IDENTITY, STABILITY.\n",
560 | "The enormous room on the ground floor faced towards the north. Cold for\n",
561 | "all the summer beyond the panes, for all the tropical heat of the room itself,\n",
562 | "a harsh thin light glared through the windows, hungrily seeking some draped\n",
563 | "lay figure, some pallid shape of academic goose-flesh, but finding only the glass\n",
564 | "and nickel and bleakly shining porcelain of a laboratory. Wintriness responded\n",
565 | "to wintriness. The overalls of the workers were white, their hands gloved with\n",
566 | "a pale corpse-coloured rubber. The light was frozen, dead, a ghost. Only from\n",
567 | "the yellow barrels of the microscopes did it borrow a certain rich and living\n",
568 | "substance, lying along the polished tubes like butter, streak after luscious streak\n",
569 | "in long recession down the work tables.\n",
570 | "“And this,” said the Director opening the door, “is the Fertilizing Room.”\n",
571 | "Bent over their instruments, three hundred Fertilizers were plunged, as the Director of Hatcheries and Conditioning entered the room, in the scarcely breathing silence, the absent-minded, soliloquizing hum or whistle, of absorbed\n",
572 | "concentration. A troop of newly arrived students, very young, pink and callow,\n",
573 | "followed nervously, rather abjectly, at the Director’s heels. Each of them carried\n",
574 | "a notebook, in which, whenever the great man spoke, he desperately scribbled.\n",
575 | "Straight from the horse’s mouth. It was a rare privilege. The D. H. C. for Central\n",
576 | "London always made a point of personally conducting his new students round\n",
577 | "the various departments.\n",
578 | "“Just to give you a general idea,” he would explain to them. For of course some\n",
579 | "sort of general idea they must have, if they were to do their work intelligentlythough as little of one, if they were to be good and happy members of society, as\n",
580 | "possible. For particulars, as every one knows, make for virture and happiness;\n",
581 | "generalities are intellectually necessary evils. Not philosophers but fretsawyers\n",
582 | "\"\"\" * 100\n",
583 | "\n",
584 | "tokenized_text = tokenizer.encode(text)\n",
585 | "\n",
586 | "examples = []\n",
587 | "block_size = 512\n",
588 | "for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size\n",
589 | " examples.append(tokenized_text[i:i + block_size])\n",
590 | "\n",
591 | "inputs, labels = [], []\n",
592 | "for ex in examples:\n",
593 | " inputs.append(ex[:-1])\n",
594 | " labels.append(ex[1:])\n",
595 | "\n",
596 | "dataset = tf.data.Dataset.from_tensor_slices((inputs, labels))\n",
597 | "\n",
598 | "\n",
599 | "\n",
600 | "BATCH_SIZE = 16\n",
601 | "BUFFER_SIZE = 10000\n",
602 | "\n",
603 | "dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)"
604 | ],
605 | "execution_count": 2,
606 | "outputs": [
607 | {
608 | "output_type": "display_data",
609 | "data": {
610 | "application/vnd.jupyter.widget-view+json": {
611 | "model_id": "524dc030edf24fd294390f9da5a360e0",
612 | "version_minor": 0,
613 | "version_major": 2
614 | },
615 | "text/plain": [
616 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…"
617 | ]
618 | },
619 | "metadata": {
620 | "tags": []
621 | }
622 | },
623 | {
624 | "output_type": "stream",
625 | "text": [
626 | "\n"
627 | ],
628 | "name": "stdout"
629 | },
630 | {
631 | "output_type": "display_data",
632 | "data": {
633 | "application/vnd.jupyter.widget-view+json": {
634 | "model_id": "7010786313364b3f80713de08b571528",
635 | "version_minor": 0,
636 | "version_major": 2
637 | },
638 | "text/plain": [
639 | "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…"
640 | ]
641 | },
642 | "metadata": {
643 | "tags": []
644 | }
645 | },
646 | {
647 | "output_type": "stream",
648 | "text": [
649 | "\n"
650 | ],
651 | "name": "stdout"
652 | },
653 | {
654 | "output_type": "stream",
655 | "text": [
656 | "Token indices sequence length is longer than the specified maximum sequence length for this model (51200 > 1024). Running this sequence through the model will result in indexing errors\n"
657 | ],
658 | "name": "stderr"
659 | }
660 | ]
661 | },
662 | {
663 | "cell_type": "code",
664 | "metadata": {
665 | "colab": {
666 | "base_uri": "https://localhost:8080/"
667 | },
668 | "id": "YthP1ab-R5Pr",
669 | "outputId": "eba49bf5-7109-491c-a216-c834d88eefc7"
670 | },
671 | "source": [
672 | "#@title works well without mixed_precision\n",
673 | "\n",
674 | "from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, GPT2Config\n",
675 | "from tensorflow.keras.mixed_precision import experimental, global_policy\n",
676 | "\n",
677 | "current_policy = global_policy()\n",
678 | "print(f\"current_policy: {current_policy}\")\n",
679 | "\n",
680 | "config = GPT2Config(\n",
681 | " vocab_size=tokenizer.vocab_size,\n",
682 | " n_positions=512,\n",
683 | " n_ctx=512,\n",
684 | " n_embd=512,\n",
685 | " n_layer=4,\n",
686 | " n_head=4,\n",
687 | " pad_token_id=tokenizer.pad_token_id,\n",
688 | " use_cache=False,\n",
689 | " )\n",
690 | "model = TFGPT2LMHeadModel(config)\n",
691 | "\n",
692 | "optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)\n",
693 | "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
694 | "metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')\n",
695 | "model.compile(optimizer=optimizer, loss=[loss, *[None] * model.config.n_layer], metrics=[metric])\n",
696 | "model.fit(dataset, epochs=20)"
697 | ],
698 | "execution_count": 6,
699 | "outputs": [
700 | {
701 | "output_type": "stream",
702 | "text": [
703 | "current_policy: \n",
704 | "WARNING:tensorflow:tf.keras.mixed_precision.experimental.LossScaleOptimizer is deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer instead. Note that the non-experimental LossScaleOptimizer does not take a DynamicLossScale but instead takes the dynamic configuration directly in the constructor. For example:\n",
705 | " opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt)\n",
706 | "\n",
707 | "Epoch 1/20\n",
708 | "WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
709 | "WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
710 | "WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
711 | "WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
712 | "6/6 [==============================] - 18s 2s/step - loss: 10.5879 - accuracy: 0.0317\n",
713 | "Epoch 2/20\n",
714 | "6/6 [==============================] - 11s 2s/step - loss: 9.7930 - accuracy: 0.0772\n",
715 | "Epoch 3/20\n",
716 | "6/6 [==============================] - 11s 2s/step - loss: 9.4057 - accuracy: 0.0779\n",
717 | "Epoch 4/20\n",
718 | "6/6 [==============================] - 11s 2s/step - loss: 9.0126 - accuracy: 0.1084\n",
719 | "Epoch 5/20\n",
720 | "6/6 [==============================] - 11s 2s/step - loss: 8.5122 - accuracy: 0.2280\n",
721 | "Epoch 6/20\n",
722 | "6/6 [==============================] - 11s 2s/step - loss: 7.9687 - accuracy: 0.3677\n",
723 | "Epoch 7/20\n",
724 | "6/6 [==============================] - 11s 2s/step - loss: 7.4126 - accuracy: 0.4819\n",
725 | "Epoch 8/20\n",
726 | "6/6 [==============================] - 11s 2s/step - loss: 6.8892 - accuracy: 0.5841\n",
727 | "Epoch 9/20\n",
728 | "6/6 [==============================] - 11s 2s/step - loss: 6.4235 - accuracy: 0.6719\n",
729 | "Epoch 10/20\n",
730 | "6/6 [==============================] - 11s 2s/step - loss: 6.0039 - accuracy: 0.7604\n",
731 | "Epoch 11/20\n",
732 | "6/6 [==============================] - 11s 2s/step - loss: 5.6319 - accuracy: 0.8393\n",
733 | "Epoch 12/20\n",
734 | "6/6 [==============================] - 11s 2s/step - loss: 5.2969 - accuracy: 0.8953\n",
735 | "Epoch 13/20\n",
736 | "6/6 [==============================] - 11s 2s/step - loss: 4.9904 - accuracy: 0.9289\n",
737 | "Epoch 14/20\n",
738 | "6/6 [==============================] - 11s 2s/step - loss: 4.7078 - accuracy: 0.9534\n",
739 | "Epoch 15/20\n",
740 | "6/6 [==============================] - 11s 2s/step - loss: 4.4475 - accuracy: 0.9708\n",
741 | "Epoch 16/20\n",
742 | "6/6 [==============================] - 11s 2s/step - loss: 4.2041 - accuracy: 0.9823\n",
743 | "Epoch 17/20\n",
744 | "6/6 [==============================] - 11s 2s/step - loss: 3.9783 - accuracy: 0.9875\n",
745 | "Epoch 18/20\n",
746 | "6/6 [==============================] - 11s 2s/step - loss: 3.7639 - accuracy: 0.9924\n",
747 | "Epoch 19/20\n",
748 | "6/6 [==============================] - 11s 2s/step - loss: 3.5620 - accuracy: 0.9947\n",
749 | "Epoch 20/20\n",
750 | "6/6 [==============================] - 11s 2s/step - loss: 3.3701 - accuracy: 0.9966\n"
751 | ],
752 | "name": "stdout"
753 | },
754 | {
755 | "output_type": "execute_result",
756 | "data": {
757 | "text/plain": [
758 | ""
759 | ]
760 | },
761 | "metadata": {
762 | "tags": []
763 | },
764 | "execution_count": 6
765 | }
766 | ]
767 | },
768 | {
769 | "cell_type": "code",
770 | "metadata": {
771 | "colab": {
772 | "base_uri": "https://localhost:8080/"
773 | },
774 | "id": "FAeMx5gubydW",
775 | "outputId": "4098f674-bf82-4dc0-c5ff-54b88810790e"
776 | },
777 | "source": [
778 | "#@title current mixed_precision with a bad result\n",
779 | "from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, GPT2Config\n",
780 | "from tensorflow.keras.mixed_precision import experimental, global_policy\n",
781 | "\n",
782 | "policy = experimental.Policy('mixed_float16')\n",
783 | "experimental.set_policy(policy)\n",
784 | "current_policy = global_policy()\n",
785 | "print(f\"current_policy: {current_policy}\")\n",
786 | "\n",
787 | "\n",
788 | "config = GPT2Config(\n",
789 | " vocab_size=tokenizer.vocab_size,\n",
790 | " n_positions=512,\n",
791 | " n_ctx=512,\n",
792 | " n_embd=512,\n",
793 | " n_layer=4,\n",
794 | " n_head=4,\n",
795 | " pad_token_id=tokenizer.pad_token_id,\n",
796 | " use_cache=False,\n",
797 | " )\n",
798 | "model = TFGPT2LMHeadModel(config)\n",
799 | "\n",
800 | "optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)\n",
801 | "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
802 | "metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')\n",
803 | "model.compile(optimizer=optimizer, loss=[loss, *[None] * model.config.n_layer], metrics=[metric])\n",
804 | "model.fit(dataset, epochs=100)"
805 | ],
806 | "execution_count": 4,
807 | "outputs": [
808 | {
809 | "output_type": "stream",
810 | "text": [
811 | "INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK\n",
812 | "Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: Tesla T4, compute capability 7.5\n",
813 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/mixed_precision/loss_scale.py:56: DynamicLossScale.__init__ (from tensorflow.python.training.experimental.loss_scale) is deprecated and will be removed in a future version.\n",
814 | "Instructions for updating:\n",
815 | "Use tf.keras.mixed_precision.LossScaleOptimizer instead. LossScaleOptimizer now has all the functionality of DynamicLossScale\n",
816 | "current_policy: \n",
817 | "WARNING:tensorflow:tf.keras.mixed_precision.experimental.LossScaleOptimizer is deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer instead. Note that the non-experimental LossScaleOptimizer does not take a DynamicLossScale but instead takes the dynamic configuration directly in the constructor. For example:\n",
818 | " opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt)\n",
819 | "\n",
820 | "Epoch 1/100\n",
821 | "WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
822 | "WARNING:tensorflow:AutoGraph could not transform > and will run it as-is.\n",
823 | "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
824 | "Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method\n",
825 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
826 | "WARNING: AutoGraph could not transform > and will run it as-is.\n",
827 | "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
828 | "Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method\n",
829 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
830 | "WARNING:tensorflow:AutoGraph could not transform and will run it as-is.\n",
831 | "Cause: while/else statement not yet supported\n",
832 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
833 | "WARNING: AutoGraph could not transform and will run it as-is.\n",
834 | "Cause: while/else statement not yet supported\n",
835 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
836 | "WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
837 | "WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
838 | "WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
839 | "6/6 [==============================] - 49s 442ms/step - loss: 10.5922 - accuracy: 0.0294\n",
840 | "Epoch 2/100\n",
841 | "6/6 [==============================] - 3s 445ms/step - loss: 9.7974 - accuracy: 0.0781\n",
842 | "Epoch 3/100\n",
843 | "6/6 [==============================] - 3s 447ms/step - loss: 9.4078 - accuracy: 0.0786\n",
844 | "Epoch 4/100\n",
845 | "6/6 [==============================] - 3s 449ms/step - loss: 9.0169 - accuracy: 0.1165\n",
846 | "Epoch 5/100\n",
847 | "6/6 [==============================] - 3s 449ms/step - loss: 8.5138 - accuracy: 0.2169\n",
848 | "Epoch 6/100\n",
849 | "6/6 [==============================] - 3s 452ms/step - loss: 7.9683 - accuracy: 0.3333\n",
850 | "Epoch 7/100\n",
851 | "6/6 [==============================] - 3s 452ms/step - loss: 7.4077 - accuracy: 0.4327\n",
852 | "Epoch 8/100\n",
853 | "6/6 [==============================] - 3s 454ms/step - loss: 6.8792 - accuracy: 0.4937\n",
854 | "Epoch 9/100\n",
855 | "6/6 [==============================] - 3s 455ms/step - loss: 6.4030 - accuracy: 0.5516\n",
856 | "Epoch 10/100\n",
857 | "6/6 [==============================] - 3s 456ms/step - loss: 5.9803 - accuracy: 0.6093\n",
858 | "Epoch 11/100\n",
859 | "6/6 [==============================] - 3s 461ms/step - loss: 5.6006 - accuracy: 0.6531\n",
860 | "Epoch 12/100\n",
861 | "6/6 [==============================] - 3s 463ms/step - loss: 5.2593 - accuracy: 0.6836\n",
862 | "Epoch 13/100\n",
863 | "6/6 [==============================] - 3s 459ms/step - loss: 4.9552 - accuracy: 0.7006\n",
864 | "Epoch 14/100\n",
865 | "6/6 [==============================] - 3s 463ms/step - loss: 4.6699 - accuracy: 0.7121\n",
866 | "Epoch 15/100\n",
867 | "6/6 [==============================] - 3s 458ms/step - loss: 4.4067 - accuracy: 0.7179\n",
868 | "Epoch 16/100\n",
869 | "6/6 [==============================] - 3s 459ms/step - loss: 4.1586 - accuracy: 0.7237\n",
870 | "Epoch 17/100\n",
871 | "6/6 [==============================] - 3s 454ms/step - loss: 3.9303 - accuracy: 0.7264\n",
872 | "Epoch 18/100\n",
873 | "6/6 [==============================] - 3s 456ms/step - loss: 3.7146 - accuracy: 0.7287\n",
874 | "Epoch 19/100\n",
875 | "6/6 [==============================] - 3s 453ms/step - loss: 3.5078 - accuracy: 0.7305\n",
876 | "Epoch 20/100\n",
877 | "6/6 [==============================] - 3s 452ms/step - loss: 3.3146 - accuracy: 0.7317\n",
878 | "Epoch 21/100\n",
879 | "6/6 [==============================] - 3s 449ms/step - loss: 3.1277 - accuracy: 0.7326\n",
880 | "Epoch 22/100\n",
881 | "6/6 [==============================] - 3s 451ms/step - loss: 2.9503 - accuracy: 0.7331\n",
882 | "Epoch 23/100\n",
883 | "6/6 [==============================] - 3s 450ms/step - loss: 2.7794 - accuracy: 0.7335\n",
884 | "Epoch 24/100\n",
885 | "6/6 [==============================] - 3s 449ms/step - loss: 2.6134 - accuracy: 0.7337\n",
886 | "Epoch 25/100\n",
887 | "6/6 [==============================] - 3s 449ms/step - loss: 2.4565 - accuracy: 0.7337\n",
888 | "Epoch 26/100\n",
889 | "6/6 [==============================] - 3s 449ms/step - loss: 2.3034 - accuracy: 0.7338\n",
890 | "Epoch 27/100\n",
891 | "6/6 [==============================] - 3s 449ms/step - loss: 2.1594 - accuracy: 0.7338\n",
892 | "Epoch 28/100\n",
893 | "6/6 [==============================] - 3s 446ms/step - loss: 2.0156 - accuracy: 0.7339\n",
894 | "Epoch 29/100\n",
895 | "6/6 [==============================] - 3s 448ms/step - loss: 1.8817 - accuracy: 0.7339\n",
896 | "Epoch 30/100\n",
897 | "6/6 [==============================] - 3s 449ms/step - loss: 1.7470 - accuracy: 0.7339\n",
898 | "Epoch 31/100\n",
899 | "6/6 [==============================] - 3s 447ms/step - loss: 1.6178 - accuracy: 0.7339\n",
900 | "Epoch 32/100\n",
901 | "6/6 [==============================] - 3s 449ms/step - loss: 1.4951 - accuracy: 0.7338\n",
902 | "Epoch 33/100\n",
903 | "6/6 [==============================] - 3s 447ms/step - loss: 1.3766 - accuracy: 0.7339\n",
904 | "Epoch 34/100\n",
905 | "6/6 [==============================] - 3s 448ms/step - loss: 1.2640 - accuracy: 0.7339\n",
906 | "Epoch 35/100\n",
907 | "6/6 [==============================] - 3s 448ms/step - loss: 1.1555 - accuracy: 0.7339\n",
908 | "Epoch 36/100\n",
909 | "6/6 [==============================] - 3s 447ms/step - loss: 1.0530 - accuracy: 0.7339\n",
910 | "Epoch 37/100\n",
911 | "6/6 [==============================] - 3s 450ms/step - loss: 0.9544 - accuracy: 0.7339\n",
912 | "Epoch 38/100\n",
913 | "6/6 [==============================] - 3s 450ms/step - loss: 0.8634 - accuracy: 0.7338\n",
914 | "Epoch 39/100\n",
915 | "6/6 [==============================] - 3s 448ms/step - loss: 0.7794 - accuracy: 0.7339\n",
916 | "Epoch 40/100\n",
917 | "6/6 [==============================] - 3s 453ms/step - loss: 0.7049 - accuracy: 0.7339\n",
918 | "Epoch 41/100\n",
919 | "6/6 [==============================] - 3s 451ms/step - loss: 0.6380 - accuracy: 0.7338\n",
920 | "Epoch 42/100\n",
921 | "6/6 [==============================] - 3s 451ms/step - loss: 0.5792 - accuracy: 0.7339\n",
922 | "Epoch 43/100\n",
923 | "6/6 [==============================] - 3s 453ms/step - loss: 0.5275 - accuracy: 0.7339\n",
924 | "Epoch 44/100\n",
925 | "6/6 [==============================] - 3s 451ms/step - loss: 0.4823 - accuracy: 0.7339\n",
926 | "Epoch 45/100\n",
927 | "6/6 [==============================] - 3s 453ms/step - loss: 0.4416 - accuracy: 0.7339\n",
928 | "Epoch 46/100\n",
929 | "6/6 [==============================] - 3s 452ms/step - loss: 0.4056 - accuracy: 0.7339\n",
930 | "Epoch 47/100\n",
931 | "6/6 [==============================] - 3s 452ms/step - loss: 0.3733 - accuracy: 0.7339\n",
932 | "Epoch 48/100\n",
933 | "6/6 [==============================] - 3s 451ms/step - loss: 0.3453 - accuracy: 0.7339\n",
934 | "Epoch 49/100\n",
935 | "6/6 [==============================] - 3s 456ms/step - loss: 0.3206 - accuracy: 0.7339\n",
936 | "Epoch 50/100\n",
937 | "6/6 [==============================] - 3s 449ms/step - loss: 0.2987 - accuracy: 0.7339\n",
938 | "Epoch 51/100\n",
939 | "6/6 [==============================] - 3s 452ms/step - loss: 0.2786 - accuracy: 0.7339\n",
940 | "Epoch 52/100\n",
941 | "6/6 [==============================] - 3s 452ms/step - loss: 0.2605 - accuracy: 0.7339\n",
942 | "Epoch 53/100\n",
943 | "6/6 [==============================] - 3s 452ms/step - loss: 0.2442 - accuracy: 0.7338\n",
944 | "Epoch 54/100\n",
945 | "6/6 [==============================] - 3s 453ms/step - loss: 0.2299 - accuracy: 0.7339\n",
946 | "Epoch 55/100\n",
947 | "6/6 [==============================] - 3s 451ms/step - loss: 0.2170 - accuracy: 0.7339\n",
948 | "Epoch 56/100\n",
949 | "6/6 [==============================] - 3s 450ms/step - loss: 0.2047 - accuracy: 0.7339\n",
950 | "Epoch 57/100\n",
951 | "6/6 [==============================] - 3s 451ms/step - loss: 0.1942 - accuracy: 0.7339\n",
952 | "Epoch 58/100\n",
953 | "6/6 [==============================] - 3s 448ms/step - loss: 0.1838 - accuracy: 0.7339\n",
954 | "Epoch 59/100\n",
955 | "6/6 [==============================] - 3s 452ms/step - loss: 0.1749 - accuracy: 0.7339\n",
956 | "Epoch 60/100\n",
957 | "6/6 [==============================] - 3s 447ms/step - loss: 0.1669 - accuracy: 0.7338\n",
958 | "Epoch 61/100\n",
959 | "6/6 [==============================] - 3s 448ms/step - loss: 0.1589 - accuracy: 0.7339\n",
960 | "Epoch 62/100\n",
961 | "6/6 [==============================] - 3s 449ms/step - loss: 0.1516 - accuracy: 0.7339\n",
962 | "Epoch 63/100\n",
963 | "6/6 [==============================] - 3s 449ms/step - loss: 0.1454 - accuracy: 0.7339\n",
964 | "Epoch 64/100\n",
965 | "6/6 [==============================] - 3s 448ms/step - loss: 0.1393 - accuracy: 0.7339\n",
966 | "Epoch 65/100\n",
967 | "6/6 [==============================] - 3s 451ms/step - loss: 0.1336 - accuracy: 0.7339\n",
968 | "Epoch 66/100\n",
969 | "6/6 [==============================] - 3s 446ms/step - loss: 0.1279 - accuracy: 0.7339\n",
970 | "Epoch 67/100\n",
971 | "6/6 [==============================] - 3s 449ms/step - loss: 0.1235 - accuracy: 0.7339\n",
972 | "Epoch 68/100\n",
973 | "6/6 [==============================] - 3s 453ms/step - loss: 0.1189 - accuracy: 0.7339\n",
974 | "Epoch 69/100\n",
975 | "6/6 [==============================] - 3s 447ms/step - loss: 0.1144 - accuracy: 0.7339\n",
976 | "Epoch 70/100\n",
977 | "6/6 [==============================] - 3s 448ms/step - loss: 0.1104 - accuracy: 0.7339\n",
978 | "Epoch 71/100\n",
979 | "6/6 [==============================] - 3s 451ms/step - loss: 0.1067 - accuracy: 0.7339\n",
980 | "Epoch 72/100\n",
981 | "6/6 [==============================] - 3s 450ms/step - loss: 0.1031 - accuracy: 0.7339\n",
982 | "Epoch 73/100\n",
983 | "6/6 [==============================] - 3s 450ms/step - loss: 0.0996 - accuracy: 0.7339\n",
984 | "Epoch 74/100\n",
985 | "6/6 [==============================] - 3s 447ms/step - loss: 0.0965 - accuracy: 0.7339\n",
986 | "Epoch 75/100\n",
987 | "6/6 [==============================] - 3s 449ms/step - loss: 0.0938 - accuracy: 0.7339\n",
988 | "Epoch 76/100\n",
989 | "6/6 [==============================] - 3s 448ms/step - loss: 0.0907 - accuracy: 0.7339\n",
990 | "Epoch 77/100\n",
991 | "6/6 [==============================] - 3s 448ms/step - loss: 0.0879 - accuracy: 0.7339\n",
992 | "Epoch 78/100\n",
993 | "6/6 [==============================] - 3s 450ms/step - loss: 0.0853 - accuracy: 0.7339\n",
994 | "Epoch 79/100\n",
995 | "6/6 [==============================] - 3s 449ms/step - loss: 0.0829 - accuracy: 0.7339\n",
996 | "Epoch 80/100\n",
997 | "6/6 [==============================] - 3s 449ms/step - loss: 0.0805 - accuracy: 0.7339\n",
998 | "Epoch 81/100\n",
999 | "6/6 [==============================] - 3s 451ms/step - loss: 0.0784 - accuracy: 0.7339\n",
1000 | "Epoch 82/100\n",
1001 | "6/6 [==============================] - 3s 450ms/step - loss: 0.0762 - accuracy: 0.7339\n",
1002 | "Epoch 83/100\n",
1003 | "6/6 [==============================] - 3s 448ms/step - loss: 0.0742 - accuracy: 0.7339\n",
1004 | "Epoch 84/100\n",
1005 | "6/6 [==============================] - 3s 448ms/step - loss: 0.0722 - accuracy: 0.7339\n",
1006 | "Epoch 85/100\n",
1007 | "6/6 [==============================] - 3s 451ms/step - loss: 0.0703 - accuracy: 0.7339\n",
1008 | "Epoch 86/100\n",
1009 | "6/6 [==============================] - 3s 451ms/step - loss: 0.0686 - accuracy: 0.7339\n",
1010 | "Epoch 87/100\n",
1011 | "6/6 [==============================] - 3s 450ms/step - loss: 0.0669 - accuracy: 0.7339\n",
1012 | "Epoch 88/100\n",
1013 | "6/6 [==============================] - 3s 450ms/step - loss: 0.0654 - accuracy: 0.7339\n",
1014 | "Epoch 89/100\n",
1015 | "6/6 [==============================] - 3s 449ms/step - loss: 0.0638 - accuracy: 0.7339\n",
1016 | "Epoch 90/100\n",
1017 | "6/6 [==============================] - 3s 449ms/step - loss: 0.0622 - accuracy: 0.7339\n",
1018 | "Epoch 91/100\n",
1019 | "6/6 [==============================] - 3s 448ms/step - loss: 0.0609 - accuracy: 0.7339\n",
1020 | "Epoch 92/100\n",
1021 | "6/6 [==============================] - 3s 448ms/step - loss: 0.0595 - accuracy: 0.7339\n",
1022 | "Epoch 93/100\n",
1023 | "6/6 [==============================] - 3s 449ms/step - loss: 0.0581 - accuracy: 0.7339\n",
1024 | "Epoch 94/100\n",
1025 | "6/6 [==============================] - 3s 447ms/step - loss: 0.0568 - accuracy: 0.7339\n",
1026 | "Epoch 95/100\n",
1027 | "6/6 [==============================] - 3s 450ms/step - loss: 0.0558 - accuracy: 0.7338\n",
1028 | "Epoch 96/100\n",
1029 | "6/6 [==============================] - 3s 448ms/step - loss: 0.0545 - accuracy: 0.7339\n",
1030 | "Epoch 97/100\n",
1031 | "6/6 [==============================] - 3s 446ms/step - loss: 0.0533 - accuracy: 0.7339\n",
1032 | "Epoch 98/100\n",
1033 | "6/6 [==============================] - 3s 449ms/step - loss: 0.0524 - accuracy: 0.7339\n",
1034 | "Epoch 99/100\n",
1035 | "6/6 [==============================] - 3s 447ms/step - loss: 0.0512 - accuracy: 0.7339\n",
1036 | "Epoch 100/100\n",
1037 | "6/6 [==============================] - 3s 448ms/step - loss: 0.0503 - accuracy: 0.7339\n"
1038 | ],
1039 | "name": "stdout"
1040 | },
1041 | {
1042 | "output_type": "execute_result",
1043 | "data": {
1044 | "text/plain": [
1045 | ""
1046 | ]
1047 | },
1048 | "metadata": {
1049 | "tags": []
1050 | },
1051 | "execution_count": 4
1052 | }
1053 | ]
1054 | },
1055 | {
1056 | "cell_type": "code",
1057 | "metadata": {
1058 | "id": "EVnKp6TB3AJI"
1059 | },
1060 | "source": [
1061 | "!pip install https://github.com/mymusise/transformers/archive/fix_mixed_precision4gpt2.zip"
1062 | ],
1063 | "execution_count": null,
1064 | "outputs": []
1065 | },
1066 | {
1067 | "cell_type": "code",
1068 | "metadata": {
1069 | "colab": {
1070 | "base_uri": "https://localhost:8080/"
1071 | },
1072 | "id": "Eau9SymJ4LF2",
1073 | "outputId": "578c4299-b8a2-4c71-d879-4e14407f436e"
1074 | },
1075 | "source": [
1076 | "#@title mixed_precision after train\n",
1077 | "\n",
1078 | "from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, GPT2Config\n",
1079 | "from tensorflow.keras.mixed_precision import experimental, global_policy\n",
1080 | "\n",
1081 | "\n",
1082 | "policy = experimental.Policy('mixed_float16')\n",
1083 | "experimental.set_policy(policy)\n",
1084 | "current_policy = global_policy()\n",
1085 | "print(f\"current_policy: {current_policy}\")\n",
1086 | "\n",
1087 | "\n",
1088 | "config = GPT2Config(\n",
1089 | " vocab_size=tokenizer.vocab_size,\n",
1090 | " n_positions=512,\n",
1091 | " n_ctx=512,\n",
1092 | " n_embd=512,\n",
1093 | " n_layer=4,\n",
1094 | " n_head=4,\n",
1095 | " pad_token_id=tokenizer.pad_token_id,\n",
1096 | " use_cache=False,\n",
1097 | " )\n",
1098 | "model = TFGPT2LMHeadModel(config)\n",
1099 | "\n",
1100 | "optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)\n",
1101 | "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
1102 | "metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')\n",
1103 | "model.compile(optimizer=optimizer, loss=[loss, *[None] * model.config.n_layer], metrics=[metric])\n",
1104 | "model.fit(dataset, epochs=20)"
1105 | ],
1106 | "execution_count": 3,
1107 | "outputs": [
1108 | {
1109 | "output_type": "stream",
1110 | "text": [
1111 | "WARNING:tensorflow:Mixed precision compatibility check (mixed_float16): WARNING\n",
1112 | "Your GPU may run slowly with dtype policy mixed_float16 because it does not have compute capability of at least 7.0. Your GPU:\n",
1113 | " Tesla K80, compute capability 3.7\n",
1114 | "See https://developer.nvidia.com/cuda-gpus for a list of GPUs and their compute capabilities.\n",
1115 | "If you will use compatible GPU(s) not attached to this host, e.g. by running a multi-worker model, you can ignore this warning. This message will only be logged once\n",
1116 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/mixed_precision/loss_scale.py:56: DynamicLossScale.__init__ (from tensorflow.python.training.experimental.loss_scale) is deprecated and will be removed in a future version.\n",
1117 | "Instructions for updating:\n",
1118 | "Use tf.keras.mixed_precision.LossScaleOptimizer instead. LossScaleOptimizer now has all the functionality of DynamicLossScale\n",
1119 | "current_policy: \n",
1120 | "WARNING:tensorflow:tf.keras.mixed_precision.experimental.LossScaleOptimizer is deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer instead. Note that the non-experimental LossScaleOptimizer does not take a DynamicLossScale but instead takes the dynamic configuration directly in the constructor. For example:\n",
1121 | " opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt)\n",
1122 | "\n",
1123 | "Epoch 1/20\n",
1124 | "WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
1125 | "WARNING:tensorflow:AutoGraph could not transform > and will run it as-is.\n",
1126 | "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
1127 | "Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method\n",
1128 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
1129 | "WARNING: AutoGraph could not transform > and will run it as-is.\n",
1130 | "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
1131 | "Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method\n",
1132 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
1133 | "WARNING:tensorflow:AutoGraph could not transform and will run it as-is.\n",
1134 | "Cause: while/else statement not yet supported\n",
1135 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
1136 | "WARNING: AutoGraph could not transform and will run it as-is.\n",
1137 | "Cause: while/else statement not yet supported\n",
1138 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convertWARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
1139 | "\n",
1140 | "WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
1141 | "WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
1142 | "6/6 [==============================] - 59s 2s/step - loss: 10.5825 - accuracy: 0.0305\n",
1143 | "Epoch 2/20\n",
1144 | "6/6 [==============================] - 11s 2s/step - loss: 9.7798 - accuracy: 0.0813\n",
1145 | "Epoch 3/20\n",
1146 | "6/6 [==============================] - 11s 2s/step - loss: 9.4015 - accuracy: 0.0863\n",
1147 | "Epoch 4/20\n",
1148 | "6/6 [==============================] - 11s 2s/step - loss: 9.0140 - accuracy: 0.1259\n",
1149 | "Epoch 5/20\n",
1150 | "6/6 [==============================] - 11s 2s/step - loss: 8.5197 - accuracy: 0.2167\n",
1151 | "Epoch 6/20\n",
1152 | "6/6 [==============================] - 11s 2s/step - loss: 7.9820 - accuracy: 0.3299\n",
1153 | "Epoch 7/20\n",
1154 | "6/6 [==============================] - 11s 2s/step - loss: 7.4249 - accuracy: 0.4427\n",
1155 | "Epoch 8/20\n",
1156 | "6/6 [==============================] - 11s 2s/step - loss: 6.9011 - accuracy: 0.5430\n",
1157 | "Epoch 9/20\n",
1158 | "6/6 [==============================] - 11s 2s/step - loss: 6.4375 - accuracy: 0.6279\n",
1159 | "Epoch 10/20\n",
1160 | "6/6 [==============================] - 11s 2s/step - loss: 6.0271 - accuracy: 0.7055\n",
1161 | "Epoch 11/20\n",
1162 | "6/6 [==============================] - 11s 2s/step - loss: 5.6643 - accuracy: 0.7885\n",
1163 | "Epoch 12/20\n",
1164 | "6/6 [==============================] - 11s 2s/step - loss: 5.3339 - accuracy: 0.8529\n",
1165 | "Epoch 13/20\n",
1166 | "6/6 [==============================] - 11s 2s/step - loss: 5.0323 - accuracy: 0.9020\n",
1167 | "Epoch 14/20\n",
1168 | "6/6 [==============================] - 11s 2s/step - loss: 4.7540 - accuracy: 0.9327\n",
1169 | "Epoch 15/20\n",
1170 | "6/6 [==============================] - 11s 2s/step - loss: 4.4923 - accuracy: 0.9497\n",
1171 | "Epoch 16/20\n",
1172 | "6/6 [==============================] - 11s 2s/step - loss: 4.2491 - accuracy: 0.9646\n",
1173 | "Epoch 17/20\n",
1174 | "6/6 [==============================] - 11s 2s/step - loss: 4.0191 - accuracy: 0.9744\n",
1175 | "Epoch 18/20\n",
1176 | "6/6 [==============================] - 11s 2s/step - loss: 3.8010 - accuracy: 0.9819\n",
1177 | "Epoch 19/20\n",
1178 | "6/6 [==============================] - 11s 2s/step - loss: 3.5975 - accuracy: 0.9866\n",
1179 | "Epoch 20/20\n",
1180 | "6/6 [==============================] - 11s 2s/step - loss: 3.3968 - accuracy: 0.9904\n"
1181 | ],
1182 | "name": "stdout"
1183 | },
1184 | {
1185 | "output_type": "execute_result",
1186 | "data": {
1187 | "text/plain": [
1188 | ""
1189 | ]
1190 | },
1191 | "metadata": {
1192 | "tags": []
1193 | },
1194 | "execution_count": 3
1195 | }
1196 | ]
1197 | }
1198 | ]
1199 | }
1200 |
--------------------------------------------------------------------------------