├── requirements.txt
├── readme.md
├── deepspeed.yaml
├── config.yaml
├── sentiments.py
├── scripts
    ├── graph_plot.py
    └── graph_plot.svg
├── captions.py
├── utils.py
├── randomwalks.py
├── carps.py
├── ilql.py
└── models.py

/requirements.txt:
--------------------------------------------------------------------------------
torch==1.12.1+cu113
transformers==4.21.1
accelerate==0.12.0
deepspeed==0.7.0
datasets==2.4.0
tokenizers==0.12.1
tqdm==4.64.0
wandb==0.13.2
networkx==2.8.6

--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
Simplified implementation of [Implicit Language Q Learning (Snell et al. 2022)](https://sea-snell.github.io/ILQL_site/) ([official](https://github.com/Sea-Snell/Implicit-Language-Q-Learning/), [paper](https://arxiv.org/abs/2206.11871))

Evaluated on the graph shortest-path task from [Decision Transformer (Lili Chen et al. 2021)](https://arxiv.org/abs/2106.01345):

![Number of steps to the goal: shortest path vs. ILQL vs. random walk](scripts/graph_plot.svg)

where, for each random graph, a transformer is trained to find optimal trajectories using only 1000 random walks.

--------------------------------------------------------------------------------
/deepspeed.yaml:
--------------------------------------------------------------------------------
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: zero3.json
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  gradient_clipping: 1.0
  zero3_init_flag: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: no
fsdp_config: {}
machine_rank: ???
main_process_ip: ???
main_process_port: 1234
main_training_function: main
mixed_precision: bf16
num_machines: ???
num_processes: ???
21 | rdzv_backend: static 22 | same_network: true 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | RandomWalks: 2 | lr: 0.001 3 | opt_betas: [0.9, 0.95] 4 | batch_size: 500 5 | tau: 0.7 6 | gamma: 0.99 7 | cql_scale: 0.1 8 | awac_scale: 1 9 | alpha: 0.1 10 | steps_for_target_q_sync: 10 11 | steps_for_eval: 5 12 | inference_betas: [1] 13 | n_epochs: 100 14 | seed: 1000 15 | n_layers_unfrozen: 0 16 | two_qs: true 17 | gptconfig: 18 | n_embd: 144 19 | n_layer: 6 20 | n_head: 1 21 | 22 | Sentiments: 23 | lr: 0.0006 24 | opt_betas: [0.9, 0.95] 25 | batch_size: 6 26 | tau: 0.5 27 | gamma: 0.99 28 | cql_scale: 0.1 29 | awac_scale: 1 30 | alpha: 1 31 | steps_for_target_q_sync: 40 32 | steps_for_eval: 100 33 | inference_betas: [1] 34 | model: EleutherAI/gpt-j-6B 35 | n_layers_unfrozen: 0 36 | two_qs: true 37 | n_epochs: 1 38 | 39 | Carps: 40 | lr: 0.00004 41 | opt_betas: [0.9, 0.95] 42 | batch_size: 72 43 | tau: 0.5 44 | cql_scale: 0.1 45 | alpha: 1 46 | steps_for_target_q_sync: 100 47 | inference_betas: [0, 1, 2, 4] 48 | n_epochs: 10 49 | model: gpt2-large 50 | n_layers_unfrozen: 2 51 | max_length: 48 52 | diff_reward: true 53 | 54 | Captions: 55 | lr: 0.0006 56 | opt_betas: [0.9, 0.95] 57 | batch_size: 16 58 | tau: 0.7 59 | gamma: 0.99 60 | cql_scale: 0.1 61 | awac_scale: 1 62 | alpha: 1 63 | steps_for_target_q_sync: 50 64 | steps_for_eval: 100 65 | inference_betas: [0, 1] 66 | model: gpt2-large 67 | n_layers_unfrozen: 0 68 | two_qs: true 69 | n_epochs: 1 70 | -------------------------------------------------------------------------------- /sentiments.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch as th 4 | from torch import tensor 5 | import torch.nn.functional as F 6 | from torch.utils.data import TensorDataset 7 | from 
functools import partial, reduce 8 | 9 | from transformers import AutoTokenizer, pipeline 10 | from datasets import load_dataset 11 | import wandb 12 | from tqdm import tqdm 13 | import math 14 | 15 | from utils import batch_map, tohuman, load_tensors 16 | 17 | def get_reward(sentiment_pipe, texts): 18 | sentiments = batch_map(lambda batch: sentiment_pipe(batch), texts, bsize=1024, desc='Sentiments') 19 | return tensor([-s['score'] if s['label'] == 'NEGATIVE' else s['score'] for s in sentiments]) 20 | 21 | class Sentiments: 22 | def __init__(self, tokenizer: AutoTokenizer, max_length=50, n_samples=64, needs_reward_model=False): 23 | self.max_length = max_length 24 | self.tokenizer = tokenizer 25 | 26 | if needs_reward_model: 27 | self.sentiment_pipe = pipeline('sentiment-analysis', 'lvwerra/distilbert-imdb', device=th.device(0)) 28 | else: 29 | self.sentiment_pipe = None 30 | 31 | texts = load_dataset('imdb', split='train+test')['text'] 32 | tensors = load_tensors( 33 | 'sentiments', 34 | texts=texts, 35 | reward_model=partial(get_reward, self.sentiment_pipe), 36 | tokenizer=self.tokenizer, 37 | max_length=max_length, 38 | use_cache=True 39 | ) 40 | 41 | query = tensor([self.tokenizer.bos_token_id] * n_samples).view(n_samples, 1) 42 | self.logit_mask = None 43 | 44 | self.dataset = TensorDataset(tensors['input_ids'], tensors['attention_mask'], tensors['rewards']) 45 | self.eval_dataset = TensorDataset(query) 46 | 47 | def eval(self, samples, beta): 48 | reviews = self.tokenizer.batch_decode(samples, skip_special_tokens=True) 49 | 50 | rewards = [1-s['score'] if s['label'] == 'NEGATIVE' else s['score'] for s in self.sentiment_pipe(reviews)] 51 | reward = np.mean(rewards) 52 | 53 | rows = list(zip(reviews, rewards)) 54 | print(f'\n{beta=} {reward=:.2f}\n' + '\n'.join([f'[{sent:.2f}] {text}' for text, sent in rows[:8]])) 55 | 56 | stats = { f'reward/{beta}': reward, 57 | f'responses/{beta}': wandb.Table(columns=['response', 'sentiment'], rows=rows[:32]) } 58 | 59 | 
return reward, stats 60 | 61 | if __name__ == '__main__': 62 | import sys 63 | from rich import print 64 | tokenizer = AutoTokenizer.from_pretrained(sys.argv[1]) 65 | tokenizer.pad_token = tokenizer.eos_token 66 | ds = Sentiments(tokenizer, needs_reward_model=True).dataset 67 | print(f'{next(iter(ds))=}') 68 | -------------------------------------------------------------------------------- /scripts/graph_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | from ilql import main 5 | from torch import tensor 6 | import torch as th 7 | import networkx as nx 8 | import numpy as np 9 | 10 | optimal_lengths = [] 11 | sampled_lengths = [] 12 | iql_lengths = [] 13 | 14 | for seed in range(10): 15 | model, data = main(seed=seed, debug=True) 16 | model.eval() 17 | 18 | g = nx.from_numpy_array(data.adj, create_using=nx.DiGraph) 19 | 20 | # optimal 21 | for start in set(range(data.n_nodes)) - {data.goal}: 22 | try: 23 | shortest_path = nx.shortest_path(g, start, data.goal)[:data.walk_size] 24 | optimal_lengths.append(len(shortest_path)-1) 25 | except: 26 | optimal_lengths.append(data.walk_size) 27 | 28 | # ilql 29 | starts = th.arange(1, data.n_nodes).unsqueeze(1).to(model.device) 30 | paths, _ = model.sample(starts, max_length=data.walk_size, logit_mask=tensor(~data.adj), beta=10) # argmax 31 | for path in paths: 32 | length = data.walk_size 33 | for ind, node in enumerate(path): 34 | if node == data.goal: 35 | length = ind 36 | break 37 | 38 | iql_lengths.append(length) 39 | 40 | # all samples 41 | for path in data.tensors[0]: 42 | length = data.walk_size 43 | for ind, node in enumerate(path): 44 | if node == data.goal: 45 | length = ind 46 | break 47 | 48 | sampled_lengths.append(length) 49 | # ■ ~ 50 | 51 | from matplotlib import pyplot 52 | import matplotlib 53 | 54 | fontcolor = '#444' 55 | 
matplotlib.rcParams['text.color'] = fontcolor 56 | matplotlib.rcParams['axes.labelcolor'] = fontcolor 57 | matplotlib.rcParams['xtick.color'] = fontcolor 58 | matplotlib.rcParams['ytick.color'] = fontcolor 59 | matplotlib.rcParams['xtick.labelcolor'] = fontcolor 60 | matplotlib.rcParams['ytick.labelcolor'] = fontcolor 61 | matplotlib.rcParams['xtick.labelcolor'] = fontcolor 62 | 63 | matplotlib.rcParams["font.family"] = "Futura" 64 | matplotlib.rcParams["font.size"] = 15 65 | matplotlib.rcParams["xtick.labelsize"] = 20 66 | matplotlib.rcParams["ytick.labelsize"] = 20 67 | matplotlib.rcParams["figure.titlesize"] = 12 68 | matplotlib.rcParams["figure.figsize"] = 15, 8 69 | 70 | matplotlib.style.use('ggplot') 71 | matplotlib.rcParams['figure.dpi'] = 70 72 | 73 | ax = pyplot.gca() 74 | ax.set_facecolor('#fff') 75 | ax.grid(color='lightgray', alpha=0.4, axis='y') 76 | ax.tick_params(top=False, labeltop=False, bottom=False, labelbottom=True, left=False, labelleft=True) 77 | 78 | optimal_hist = np.histogram(optimal_lengths, bins=np.arange(1, data.walk_size+2), density=True)[0] 79 | sampled_hist = np.histogram(sampled_lengths, bins=np.arange(1, data.walk_size+2), density=True)[0] 80 | iql_hist = np.histogram(iql_lengths, bins=np.arange(1, data.walk_size+2), density=True)[0] 81 | 82 | barsize = 0.36 83 | iql_color = '#99a3fd' 84 | opt_color = '#f2ad48' 85 | random_color='lightgray' 86 | 87 | pyplot.bar(np.arange(1, data.walk_size+1)-barsize/1.5, optimal_hist, width=barsize, label='shortest path', color=opt_color, zorder=2) 88 | pyplot.bar(np.arange(1, data.walk_size+1), iql_hist, width=barsize, label='ILQL', color=iql_color, zorder=3) 89 | pyplot.bar(np.arange(1, data.walk_size+1)+barsize/1.5, sampled_hist, width=barsize, label='random walk', color=random_color, zorder=1) 90 | 91 | pyplot.legend(fontsize=16) 92 | pyplot.xticks(np.arange(1, data.walk_size+1), list(np.arange(1, data.walk_size)) + ['∞']) 93 | 94 | pyplot.xlabel('# of steps to goal', fontsize=22, 
color=fontcolor, labelpad=20) 95 | pyplot.ylabel('proportion of paths', fontsize=22, color=fontcolor, labelpad=20) 96 | 97 | pyplot.savefig('scripts/graph_plot.svg') 98 | -------------------------------------------------------------------------------- /captions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch as th 4 | from torch import tensor 5 | import torch.nn.functional as F 6 | from torch.utils.data import TensorDataset 7 | 8 | import wandb 9 | import sqlite3 10 | 11 | from utils import batch_map, tohuman 12 | 13 | def load_tensors(cache_path, tokenizer, use_cache=True): 14 | cache_path = f'{cache_path}_tokenizer={tokenizer.name_or_path}.pt' 15 | 16 | if use_cache and os.path.exists(cache_path): 17 | out = th.load(cache_path) 18 | else: 19 | conn = sqlite3.connect('sac_public_2022_06_29.sqlite') 20 | c = conn.cursor() 21 | c.execute("SELECT prompt, rating FROM ratings " 22 | "JOIN images ON images.id=ratings.iid " 23 | "JOIN generations ON images.gid=generations.id " 24 | "WHERE rating IS NOT NULL;") 25 | 26 | prompts, ratings = tuple(map(list, zip(*filter(lambda x: len(x[0]) > 10, c.fetchall())))) 27 | 28 | out = tokenizer(prompts, padding=True, return_tensors='pt') 29 | input_ids, attention_mask = out['input_ids'], out['attention_mask'] 30 | 31 | # append eos 32 | input_ids = F.pad(input_ids, (0, 1), value=tokenizer.eos_token_id) 33 | attention_mask = F.pad(attention_mask, (0, 1), value=0) 34 | 35 | # figure out sentences' endings 36 | diff_padding = th.zeros(input_ids.shape[0], 1, dtype=th.long) 37 | endings = input_ids.eq(tokenizer.eos_token_id).diff(prepend=diff_padding, dim=-1).nonzero(as_tuple=True) 38 | 39 | ratings = tensor(ratings, dtype=th.float32).view(-1, 1) 40 | ratings = (ratings - ratings.mean()) / (ratings.std() + 1e-100) 41 | rewards = ratings.repeat(1, input_ids.shape[1]) 42 | 43 | # zero padding 44 | 
rewards[input_ids.eq(tokenizer.pad_token_id).nonzero(as_tuple=True)] = 0 45 | # refill rewards for the actual eos in case pad == eos 46 | rewards[endings] = ratings.view(-1) 47 | 48 | # prepend bos 49 | input_ids = F.pad(input_ids, (1, 0), value=tokenizer.eos_token_id) 50 | attention_mask = F.pad(attention_mask, (1, 0), value=1) 51 | 52 | out = {'input_ids': input_ids, 'attention_mask': attention_mask, 'rewards': rewards} 53 | 54 | if not os.path.exists(os.path.dirname(cache_path)): 55 | os.mkdir(os.path.dirname(cache_path)) 56 | 57 | th.save(out, cache_path) 58 | 59 | print(f"Total {tohuman(np.prod(out['input_ids'].shape))} tokens") 60 | return out 61 | 62 | class AestheticCaptions(TensorDataset): 63 | def __init__(self, tokenizer, max_length=77, n_samples=32, batch_size=1, use_cache=True): 64 | self.max_length = max_length 65 | self.n_samples = n_samples 66 | self.batch_size = batch_size 67 | self.tokenizer = tokenizer 68 | 69 | tensors = load_tensors('cache/aesthetic-captions', self.tokenizer, use_cache=use_cache) 70 | super().__init__(tensors['input_ids'], tensors['attention_mask'], tensors['rewards']) 71 | 72 | def eval(self, logs, model, betas=[1]): 73 | query = tensor([self.tokenizer.eos_token_id] * self.n_samples, device=model.device).view(self.n_samples, 1) 74 | 75 | for beta in betas: 76 | responses = batch_map( 77 | lambda batch: model.sample(query, beta=beta, max_length=self.max_length, 78 | eos_token_id=self.tokenizer.eos_token_id)[0], 79 | query, bsize=self.batch_size, desc='Generating') 80 | 81 | responses = self.tokenizer.batch_decode(responses, skip_special_tokens=True) 82 | 83 | print(f'\n{beta=}\n' + '\n'.join([f'{text}' for text in responses[:16]])) 84 | 85 | logs.update({f'responses/{beta}': wandb.Table(columns=['response'], data=[[r] for r in responses[:32]])}) 86 | 87 | return np.inf, {} 88 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch as th 4 | from tqdm import tqdm 5 | import math 6 | from contextlib import contextmanager 7 | import torch.nn.functional as F 8 | from time import time 9 | import deepspeed 10 | 11 | try: 12 | __IPYTHON__ 13 | run_from_ipython = True 14 | except NameError: 15 | run_from_ipython = False 16 | 17 | def randexclude(rng: np.random.RandomState, n: int, exclude: int) -> int: 18 | while True: 19 | x = rng.randint(n) 20 | if x != exclude: 21 | return x 22 | 23 | def tohuman(n: int) -> str: 24 | if n > 1e9: 25 | return f'{n / 1e9:.1f}B' 26 | elif n > 1e6: 27 | return f'{n / 1e6:.1f}M' 28 | elif n > 1e3: 29 | return f'{n / 1e3:.1f}K' 30 | return str(n) 31 | 32 | def logvars(name, logs, xs): 33 | xs = th.vstack(xs) 34 | logs.update({ f'{name}-mean': xs.mean(), 35 | f'{name}-std': xs.std(), 36 | f'{name}-min': xs.min(), 37 | f'{name}-max': xs.max() }) 38 | 39 | def batch_map(fn, xs, bsize: int, desc=None): 40 | out = [] 41 | for ind in tqdm(range(math.ceil(len(xs) / bsize)), desc=desc, disable=not desc): 42 | batch = xs[ind*bsize:min(len(xs), (ind+1)*bsize)] 43 | out.extend(fn(batch)) 44 | 45 | return out 46 | 47 | def load_tensors(name, texts, reward_model, tokenizer, max_length=64, use_cache=True): 48 | cache_path = f'cache/{name}_{max_length=}_tokenizer={tokenizer.name_or_path.split("/")[-1]}.pt' 49 | if use_cache and os.path.exists(cache_path): 50 | tensors = th.load(cache_path) 51 | else: 52 | tensors = tokenizer( 53 | [tokenizer.bos_token + x for x in texts], 54 | max_length=max_length, 55 | truncation=True, 56 | padding=True, 57 | return_tensors='pt' 58 | ) 59 | 60 | trimmed_texts = tokenizer.batch_decode(tensors['input_ids'], skip_special_tokens=True) 61 | rewards = th.as_tensor(reward_model(trimmed_texts)) 62 | rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-30) 63 | rewards = rewards.view(-1, 1).repeat(1, 
tensors['input_ids'].shape[1]) 64 | rewards[tensors['attention_mask'].eq(0)] = 0 65 | 66 | tensors['rewards'] = rewards 67 | tensors['attention_mask'] = F.pad(tensors['attention_mask'], (0, 1), value=0) 68 | tensors['input_ids'] = F.pad(tensors['input_ids'], (0, 1), value=tokenizer.eos_token_id) 69 | 70 | if not os.path.exists(os.path.dirname(cache_path)): 71 | os.mkdir(os.path.dirname(cache_path)) 72 | 73 | th.save(tensors, cache_path) 74 | 75 | print(f"{tohuman(np.prod(tensors['input_ids'].shape))} tokens") 76 | return tensors 77 | 78 | 79 | def isdelim(c: str): 80 | return c == '?' or c == '!' or c == '.' or c == ';' 81 | 82 | def pprint(s): 83 | trig = False 84 | si = 0 85 | l = len(s)-1 86 | 87 | for i in range(len(s)): 88 | if i == l: 89 | print(s[si:].strip()) 90 | 91 | elif trig or isdelim(s[i]): 92 | trig = True 93 | 94 | if s[i].isspace(): 95 | print(s[si:i+1].strip()) 96 | si = i + 1 97 | trig = False 98 | 99 | @contextmanager 100 | def timeit(desc='something important'): 101 | print(f'{desc}...') 102 | stime = time() 103 | try: 104 | yield None 105 | finally: 106 | print(f'done with {desc.lower()} in {time() - stime:.1f}s') 107 | 108 | def check_weights(param): 109 | if os.environ.get('DEEPSPEED_ZERO_STAGE', '0') == '3': 110 | with deepspeed.zero.GatheredParameters(param[0].weight, modifier_rank=0): 111 | if deepspeed.comm.get_rank() == 0: 112 | return param[0].weight.sum() 113 | else: 114 | return param[0].weight.sum() 115 | -------------------------------------------------------------------------------- /randomwalks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | from torch import tensor 4 | from torch.utils.data import TensorDataset 5 | import torch.nn.functional as F 6 | from utils import randexclude 7 | import networkx as nx 8 | 9 | # Toy dataset from Decision Transformer (Chen et. 
al 2021) 10 | class RandomWalks: 11 | def __init__(self, n_nodes=20, max_length=10, n_walks=1000, p_edge=0.1, seed=1002): 12 | self.n_nodes = n_nodes 13 | self.n_walks = n_walks 14 | self.max_length = max_length 15 | rng = np.random.RandomState(seed) 16 | 17 | walks, rewards = [], [] 18 | while True: 19 | self.adj = rng.rand(n_nodes, n_nodes) > (1 - p_edge) 20 | np.fill_diagonal(self.adj, 0) 21 | if np.all(self.adj.sum(1)): break 22 | 23 | # terminal state 24 | self.adj[0, :] = 0 25 | self.adj[0, 0] = 1 26 | 27 | self.goal = 0 28 | for _ in range(n_walks): 29 | node = randexclude(rng, n_nodes, self.goal) 30 | walk = [node] 31 | 32 | for istep in range(max_length-1): 33 | node = rng.choice(np.nonzero(self.adj[node])[0]) 34 | walk.append(node) 35 | if node == self.goal: 36 | break 37 | 38 | r = th.zeros(max_length-1) 39 | r[:len(walk)-1] = -1 if walk[-1] == self.goal else -100 40 | 41 | rewards.append(r) 42 | walks.append(walk) 43 | 44 | states = [] 45 | attention_masks = [] 46 | 47 | for r, walk in zip(rewards, map(th.tensor, walks)): 48 | attention_mask = th.zeros(max_length, dtype=int) 49 | attention_mask[:len(walk)] = 1 50 | 51 | attention_masks.append(attention_mask) 52 | states.append(F.pad(walk, (0, max_length-len(walk)))) 53 | 54 | self.worstlen = self.max_length 55 | self.avglen = sum(map(len, walks)) / self.n_walks 56 | self.bestlen = 0 57 | g = nx.from_numpy_array(self.adj, create_using=nx.DiGraph) 58 | for start in set(range(self.n_nodes)) - {self.goal}: 59 | try: 60 | shortest_path = nx.shortest_path(g, start, self.goal)[:self.max_length] 61 | self.bestlen += len(shortest_path) 62 | except: 63 | self.bestlen += self.max_length 64 | 65 | self.bestlen /= self.n_nodes - 1 66 | 67 | print(f'{self.n_walks} walks of which {(np.array([r[0] for r in rewards])==-1).mean()*100:.0f}% arrived at destination') 68 | 69 | # disallows selecting unaccessible nodes in a graph 70 | self.logit_mask = tensor(~self.adj) 71 | 72 | self.dataset = TensorDataset(th.stack(states), 
th.stack(attention_masks), th.stack(rewards)) 73 | self.eval_dataset = TensorDataset(th.arange(1, self.n_nodes).unsqueeze(1)) 74 | 75 | def render(self): 76 | from matplotlib import pyplot 77 | 78 | g = nx.from_numpy_array(self.adj, create_using=nx.DiGraph) 79 | pos = nx.spring_layout(g, seed=7357) 80 | 81 | pyplot.figure(figsize=(10, 8)) 82 | nx.draw_networkx_edges(g, pos=pos, alpha=0.5, width=1, edge_color='#d3d3d3') 83 | nx.draw_networkx_nodes(g, nodelist=set(range(len(self.adj))) - {self.goal}, pos=pos, node_size=300, node_color='orange') 84 | nx.draw_networkx_nodes(g, nodelist=[self.goal], pos=pos, node_size=300, node_color='darkblue') 85 | pyplot.show() 86 | 87 | def eval(self, samples, beta): 88 | narrived = 0 89 | actlen = 0 90 | for node in range(self.n_nodes-1): 91 | for istep in range(self.max_length): 92 | if samples[node, istep] == self.goal: 93 | narrived += 1 94 | break 95 | 96 | actlen += (istep + 1) / (self.n_nodes - 1) 97 | 98 | current = (self.worstlen - actlen)/(self.worstlen - self.bestlen) 99 | average = (self.worstlen - self.avglen)/(self.worstlen - self.bestlen) 100 | 101 | stats = { 'actlen': actlen, 102 | 'avglen': self.avglen, 103 | 'bestlen': self.bestlen, 104 | 'worstlen': self.worstlen, 105 | 'arrived': f'{narrived / (self.n_nodes-1) * 100:.0f}%', 106 | 'optimal': f'{current*100:.0f}% > {average*100:.0f}%' } 107 | 108 | return -actlen, stats 109 | -------------------------------------------------------------------------------- /carps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch as th 4 | from torch import tensor 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | from functools import partial, reduce 8 | from transformers import AutoTokenizer 9 | from datasets import load_dataset 10 | from util.carp_util import load_carp, scorer 11 | import wandb 12 | 13 | tokenizer = AutoTokenizer.from_pretrained('gpt2') 14 | 15 | 
carp = load_carp( 16 | model_type='coop', 17 | config_path='ControlledCarp/magiCARP/configs/coop/alignment_coop.yml', 18 | ckpt_path='New_Alignment_CoOp_Carp_L/' 19 | ).to('cuda') 20 | carp.eval() 21 | 22 | def clean_text(text): 23 | return '. '.join(map( 24 | lambda x: x.strip(), 25 | text.replace(' . ', '. ').replace(' , ', ', ').replace(" '", "'").replace(" n't", "n't").split('. ') 26 | )) 27 | 28 | def sizesplit(size: int, xs): 29 | for ind in range(len(xs) // size + int((len(xs) % size) > 0)): 30 | yield xs[ind*size:min(len(xs), (ind+1)*size)] 31 | 32 | def topk_mask(xs, k): 33 | mintop = th.topk(xs, k)[0][:, -1].unsqueeze(-1) 34 | return th.where(xs < mintop, -np.inf * th.ones_like(xs, dtype=xs.dtype), xs) 35 | 36 | def tokenize(max_length, diff_reward, offset_reward, review, sample): 37 | text = clean_text(sample['text']) 38 | tokens = tokenizer.encode(text, return_tensors='pt')[:, :max_length-1] 39 | tokens = F.pad(tokens, (0, max_length-tokens.shape[1]-1), value=tokenizer.eos_token_id) 40 | 41 | if diff_reward: 42 | substrings = [] 43 | newtext = "" 44 | for token in tokens[0]: 45 | newtext += tokenizer.decode(token) 46 | substrings.append(newtext) 47 | 48 | rewards = carp_score(substrings, review).cpu() 49 | rewards = th.hstack((tensor([offset_reward]), rewards)).diff() 50 | rewards = th.where(tokens[0] == tokenizer.eos_token_id, 0, rewards) 51 | 52 | else: 53 | r = carp_score(text, review).item() 54 | rewards = th.empty(max_length-1) 55 | rewards.fill_(r) 56 | rewards[tokens[0] == tokenizer.eos_token_id] = 0 57 | 58 | attn = [1] * max_length 59 | attn[-1] = 0 60 | sample['text'] = text 61 | sample['tokens'] = th.hstack((tensor([[tokenizer.eos_token_id]]), tokens)) 62 | sample['attention'] = attn 63 | sample['rewards'] = rewards 64 | return sample 65 | 66 | @th.inference_mode() 67 | def carp_score(texts, review): 68 | return scorer(texts, [review], carp, mode='coop').view(-1) 69 | 70 | @th.inference_mode() 71 | def sample(model, query=None, n_samples=128, 
beta=1, max_length=32, temperature=0.8, top_k=20): 72 | if query is None: 73 | query = tensor([tokenizer.bos_token_id] * n_samples, device=model.device).view(n_samples, 1) 74 | 75 | for _ in range(max_length): 76 | logits, qs, _, vs = model(input_ids=query) 77 | logits = logits[:, -1, :] 78 | qs = qs[:, -1, :] 79 | vs = vs[:, -1, :] 80 | 81 | adv = qs - vs 82 | pi = F.log_softmax(logits, -1) 83 | modpi = topk_mask(pi + beta * adv, top_k) 84 | ps = F.softmax(modpi / temperature, -1) 85 | 86 | tokens = th.multinomial(ps, 1) 87 | query = th.hstack((query, tokens)) 88 | 89 | return query 90 | 91 | class Carps(Dataset): 92 | def __init__(self, review='good', max_length=48, diff_reward=True, n_samples=64): 93 | self.review = review 94 | self.max_length = max_length 95 | self.n_samples = n_samples 96 | 97 | cache_path = f'stash/carps-{max_length}l-{diff_reward}d.pt' 98 | 99 | if os.path.exists(cache_path): 100 | cache = th.load(cache_path) 101 | self.tokens = cache['tokens'] 102 | self.rewards = cache['rewards'] 103 | self.attention_masks = cache['attention_masks'] 104 | self.validation_queries = cache['validation_queries'] 105 | else: 106 | ds, valid = load_dataset( 107 | 'text', 108 | data_files={'train': 'roc_train_all.txt', 'valid': 'roc_valid.txt'}, 109 | split=['train', f'valid[:{n_samples}]']) 110 | 111 | if diff_reward: 112 | vocab = list(tokenizer.get_vocab().keys()) 113 | offset = th.hstack([carp_score(words, review) for words in sizesplit(32, vocab)]).mean() 114 | else: 115 | offset = 0 116 | 117 | ds = ds.map(partial(tokenize, max_length, diff_reward, offset, review)) 118 | valid = valid.map(partial(tokenize, max_length, diff_reward, offset, review)) 119 | 120 | self.tokens = th.tensor(ds['tokens']).squeeze(1) 121 | self.rewards = tensor(ds['rewards']) 122 | self.attention_masks = tensor(ds['attention']) 123 | self.validation_queries = tensor(valid['tokens']).squeeze(1)[:n_samples, :6] 124 | 125 | th.save({ 'tokens': self.tokens, 126 | 'rewards': self.rewards, 
127 | 'attention_masks': self.attention_masks, 128 | 'validation_queries': self.validation_queries }, cache_path) 129 | 130 | def __len__(self): 131 | return self.tokens.shape[0] 132 | 133 | def __getitem__(self, ind): 134 | return self.tokens[ind], self.attention_masks[ind], self.rewards[ind] 135 | 136 | def eval(self, logs, model, betas=[1]): 137 | model.eval() 138 | queries = self.validation_queries.to(model.device) 139 | 140 | for beta in betas: 141 | responses = sample(model, query=queries, beta=beta, max_length=self.max_length, n_samples=self.n_samples) 142 | texts = [tokenizer.decode(response[1:]) for response in responses] 143 | 144 | rewards = th.hstack([carp_score(ts, self.review) for ts in sizesplit(8, texts)]) 145 | reward = rewards.mean().item() 146 | rows = list(zip(texts, rewards.tolist())) 147 | 148 | print(f'\n{beta=} {reward=:.2f}\n' + '\n'.join([f'[{r:.2f}] {text}' for text, r in rows[:8]])) 149 | 150 | logs[f'reward/beta{beta}'] = reward 151 | logs.update({f'responses/beta{beta}': wandb.Table(columns=['response', 'reward'], rows=rows[:32])}) 152 | 153 | stats = {'reward': f'{reward:.2f}'} 154 | model.train() 155 | return reward, stats 156 | -------------------------------------------------------------------------------- /ilql.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import yaml 4 | from time import time 5 | 6 | import torch as th 7 | from torch import tensor, nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from transformers import GPT2Config, AutoTokenizer 11 | from accelerate import Accelerator 12 | import numpy as np 13 | 14 | import wandb 15 | from tqdm import tqdm, trange 16 | from utils import run_from_ipython, timeit, check_weights 17 | from models import QVModel 18 | from copy import deepcopy 19 | import accelerate 20 | import deepspeed 21 | 22 | th.set_printoptions(sci_mode=False) 23 | 24 | WORLD_SIZE = 
int(os.environ.get('WORLD_SIZE', 1)) 25 | WORLD_RANK = int(os.environ.get('RANK', 0)) 26 | LOCAL_RANK = int(os.environ.get('LOCAL_RANK', 0)) 27 | 28 | def main(**args): 29 | task = args['task'] if 'task' in args else 'RandomWalks' 30 | config = yaml.safe_load(open('config.yaml'))[task] 31 | config.update(args) 32 | 33 | accelerator = Accelerator(log_with='wandb') 34 | accelerator.print(os.environ) 35 | device = accelerator.device 36 | 37 | if WORLD_SIZE > 1: 38 | th.distributed.barrier(device_ids=[LOCAL_RANK]) 39 | else: 40 | th.random.manual_seed(1000) 41 | 42 | if not config.get('debug', False) and accelerator.is_main_process: 43 | modelname = config.get('model', '') 44 | accelerator.init_trackers(project_name='test-ilql', init_kwargs={ 45 | 'wandb': {'name': f'ilql-{task}-{modelname}', 46 | 'mode': 'disabled' if args.get('debug', False) else 'online'}}, config=config) 47 | 48 | config = wandb.config 49 | 50 | if task == 'RandomWalks': 51 | from randomwalks import RandomWalks 52 | data = RandomWalks(seed=config['seed']) 53 | gptconfig = GPT2Config(**config['gptconfig'], vocab_size=data.n_nodes) 54 | model = QVModel(gptconfig, config) 55 | 56 | elif task == 'Sentiments': 57 | from sentiments import Sentiments 58 | 59 | with accelerator.main_process_first(): 60 | tokenizer = AutoTokenizer.from_pretrained(config['model']) 61 | tokenizer.pad_token_id = tokenizer.eos_token_id 62 | data = Sentiments(tokenizer, needs_reward_model=accelerator.is_main_process) 63 | 64 | with timeit('init model'): 65 | model = QVModel(config['model'], config) 66 | 67 | elif task == 'Carps': 68 | from carps import Carps 69 | data = Carps(max_length=config['max_length'], diff_reward=config['diff_reward']) 70 | model = QVModel(config['model'], two_qs=config['two_qs']).to(device) 71 | 72 | elif task == 'Captions': 73 | from captions import AestheticCaptions 74 | 75 | tokenizer = AutoTokenizer.from_pretrained(config['model']) 76 | tokenizer.pad_token = tokenizer.eos_token_id 77 | with 
accelerator.main_process_first(): 78 | data = AestheticCaptions(tokenizer, batch_size=config['batch_size'], n_samples=16) 79 | 80 | model = QVModel(config['model'], two_qs=config['two_qs']).to(device) 81 | else: 82 | raise ValueError(f'nonexistent {task=}') 83 | 84 | if hasattr(model.gpt, 'gpt_neox'): 85 | gpt_blocks = list(model.gpt.gpt_neox.layers)[:-config['n_layers_unfrozen']] 86 | else: 87 | gpt_blocks = list(model.gpt.transformer.h)[:-config['n_layers_unfrozen']] 88 | 89 | for m in gpt_blocks: 90 | m.requires_grad_(False) 91 | 92 | train_dataloader = DataLoader(data.dataset, batch_size=config['batch_size']) 93 | 94 | eval_batch_size = max(1, len(data.eval_dataset) // WORLD_SIZE) 95 | eval_dataloader = DataLoader(data.eval_dataset, eval_batch_size) 96 | 97 | opt_cls = ( 98 | th.optim.AdamW 99 | if accelerator.state.deepspeed_plugin is None 100 | or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config 101 | else accelerate.utils.DummyOptim 102 | ) 103 | opt = opt_cls([p for p in model.parameters() if p.requires_grad], lr=config['lr'], betas=config['opt_betas']) 104 | 105 | total_steps = int(config['n_epochs'] * (len(data.dataset) // (config['batch_size'] * WORLD_SIZE))) 106 | n_opt_steps = 0 107 | 108 | with timeit('prepare'): 109 | model, opt, train_dataloader, eval_dataloader = accelerator.prepare( 110 | model, opt, train_dataloader, eval_dataloader 111 | ) 112 | 113 | print(f'{WORLD_RANK=}: {model(**accelerator.unwrap_model(model).dummy_inputs)[0].device}') 114 | 115 | model.train() 116 | tbar = trange(total_steps, disable=not accelerator.is_local_main_process) 117 | 118 | for iepoch in range(config['n_epochs']): 119 | for batch in train_dataloader: 120 | logs = {} 121 | 122 | if n_opt_steps % config['steps_for_eval'] == 0: 123 | model.eval() 124 | beta = config['inference_betas'][0] 125 | 126 | all_samples = [] 127 | for tokens in eval_dataloader: 128 | tokens = tokens[0].to(device) 129 | with th.no_grad(): 130 | samples, stats = 
accelerator.unwrap_model(model).sample( 131 | tokens, 132 | beta=beta, 133 | max_length=data.max_length, 134 | logit_mask=data.logit_mask 135 | ) 136 | 137 | all_samples.append(samples) 138 | logs.update(stats) 139 | 140 | samples = accelerator.gather(th.vstack(all_samples)) 141 | 142 | if accelerator.is_main_process: 143 | reward, stats = data.eval(samples, beta) 144 | logs.update(stats) 145 | tbar.set_postfix(stats) 146 | 147 | model.train() 148 | 149 | for ix in range(len(batch)): 150 | batch[ix] = batch[ix].to(device) 151 | 152 | batch_time = time() 153 | forward_time = time() 154 | loss, stats = model.loss(batch) 155 | forward_time = time() - forward_time 156 | 157 | backward_time = time() 158 | accelerator.backward(loss) 159 | backward_time = time() - backward_time 160 | 161 | opt.step() 162 | opt.zero_grad() 163 | n_opt_steps += 1 164 | 165 | batch_time = time() - batch_time 166 | tokens_per_sec = batch[0].numel() * WORLD_SIZE / batch_time 167 | tbar.set_description(f'{tokens_per_sec=:.2f} {batch_time=:.2f}') 168 | tbar.update() 169 | 170 | if (n_opt_steps + 1) % config['steps_for_target_q_sync'] == 0: 171 | accelerator.unwrap_model(model).sync_target_q_heads() 172 | 173 | logs.update(stats) 174 | logs['target_sum'] = check_weights(accelerator.unwrap_model(model).target_q1_head) 175 | logs['batch_time'] = batch_time 176 | 177 | if not config.get('debug', False): 178 | accelerator.log(logs) 179 | 180 | return model, data 181 | 182 | if __name__ == '__main__': 183 | if os.environ.get('LOCAL_RANK'): 184 | os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = os.environ['LOCAL_RANK'] 185 | 186 | if run_from_ipython: 187 | args = {'debug': True} 188 | else: 189 | # poor man's argparse 190 | args = {a[2:]: eval(v) for a, v in map(lambda s: s.split('='), sys.argv[1:])} 191 | 192 | main(**args) 193 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | 
import os
import torch as th
import numpy as np
from torch import tensor, nn
import torch.nn.functional as F
import transformers
from transformers import AutoModelForCausalLM, PretrainedConfig, AutoConfig

from typing import NamedTuple, Tuple, Union
from copy import deepcopy
from collections import defaultdict
from accelerate.utils import compute_module_sizes
from itertools import chain

import accelerate
import deepspeed


def topk_mask(xs: th.FloatTensor, k: int):
    """Return `xs` with every entry below its row's k-th largest value set to -inf.

    Expects a 2-D tensor: the per-row threshold is read from column -1 of
    `th.topk(xs, k)` values. Entries equal to the threshold are kept, so ties
    can leave more than `k` finite entries in a row.
    """
    mintop = th.topk(xs, k)[0][:, -1].unsqueeze(-1)
    return th.where(xs < mintop, -np.inf * th.ones_like(xs, dtype=xs.dtype), xs)


class QVOutput(Tuple):
    """Positional output of `QVModel.forward`.

    A plain tuple subclass built from a single 5-tuple; the annotations below
    are documentation only and callers unpack the elements in this order.
    """
    logits: th.FloatTensor  # LM-head logits
    qs: th.FloatTensor  # Q-head output; a (q1, q2) pair when `two_qs` is enabled
    target_qs: th.FloatTensor  # target Q-head output; a pair when `two_qs` is enabled
    vs: th.FloatTensor  # value-head output (last dim of size 1)
    past_key_values: Tuple[th.FloatTensor]  # backbone cache passed back in by `sample`


def make_head(n_embd: int, out: int):
    """Build an MLP head: Linear(n_embd, 2*n_embd) -> ReLU -> Linear(2*n_embd, out)."""
    return nn.Sequential(
        nn.Linear(n_embd, n_embd * 2),
        nn.ReLU(),
        nn.Linear(n_embd * 2, out)
    )


class QVModel(nn.Module):
    """Causal-LM backbone with ILQL value, Q and target-Q heads on its last hidden state.

    Args:
        config: a `PretrainedConfig` (model built from scratch with
            `AutoModelForCausalLM.from_config`) or a name/path handed to
            `AutoModelForCausalLM.from_pretrained`.
        params: mapping that must contain 'tau', 'alpha', 'gamma',
            'awac_scale', 'cql_scale' and 'two_qs'.
    """

    def __init__(self, config: Union[PretrainedConfig, str], params):
        super().__init__()

        # enable zero3 init within from_pretrained
        if os.environ.get('DEEPSPEED_ZERO_STAGE', '0') == '3':
            config_path = os.environ.get('DEEPSPEED_CONFIG_FILE', '')
            if config_path:
                # NOTE(review): bound to a name, presumably so the object stays
                # alive while from_pretrained runs below (HF requires this for zero3 init)
                _hfconfig = transformers.deepspeed.HfDeepSpeedConfig(config_path)

        if isinstance(config, PretrainedConfig):
            self.gpt = AutoModelForCausalLM.from_config(config)
        else:
            self.gpt = AutoModelForCausalLM.from_pretrained(config)

        # different HF configs name the hidden width differently
        if hasattr(self.gpt.config, 'hidden_size'):
            self.n_embd = self.gpt.config.hidden_size
        else:
            self.n_embd = self.gpt.config.n_embd
        self.vocab_size = self.gpt.config.vocab_size

        self.v_head = make_head(self.n_embd, 1)
        self.q1_head = make_head(self.n_embd, self.vocab_size)
        # target head starts as a frozen copy; it only changes via sync_target_q_heads
        self.target_q1_head = deepcopy(self.q1_head)
        self.target_q1_head.requires_grad_(False)

        self.tau = params['tau']
        self.alpha = params['alpha']
        self.gamma = params['gamma']
        self.awac_scale = params['awac_scale']
        self.cql_scale = params['cql_scale']
        self.two_qs = params['two_qs']

        if self.two_qs:
            self.q2_head = make_head(self.n_embd, self.vocab_size)
            self.target_q2_head = deepcopy(self.q2_head)
            self.target_q2_head.requires_grad_(False)

    def forward(self, **x):
        """Run the base transformer on `**x` and apply the LM, Q, target-Q and V heads.

        `**x` goes straight to the backbone (`gpt_neox` for GPT-NeoX models,
        otherwise `transformer`), e.g. `input_ids`, `attention_mask`,
        `past_key_values`. Returns a `QVOutput`.
        """
        if hasattr(self.gpt, 'gpt_neox'):
            out = self.gpt.gpt_neox(**x)
        else:
            out = self.gpt.transformer(**x)

        hs = out.last_hidden_state

        if self.two_qs:
            qs = (self.q1_head(hs), self.q2_head(hs))
            target_qs = (self.target_q1_head(hs), self.target_q2_head(hs))
        else:
            qs = self.q1_head(hs)
            target_qs = self.target_q1_head(hs)

        if hasattr(self.gpt, 'gpt_neox'):
            logits = self.gpt.embed_out(hs)
        else:
            logits = self.gpt.lm_head(hs)

        return QVOutput((logits, qs, target_qs, self.v_head(hs), out.past_key_values))

    def loss(self, batch):
        """Compute the ILQL training loss on one batch.

        Args:
            batch: `(tokens, attn, rewards)`. Position t (t < T-1) is the
                transition from prefix `tokens[:, :t+1]` taking action
                `tokens[:, t+1]`. `attn[:, :-1]` masks which transitions count,
                and `rewards` is assumed to be aligned with those T-1
                transitions (TODO confirm against the data pipeline).

        Returns:
            `(loss, stats)`, where `stats` holds the total loss and its
            components: 'loss', 'loss_v', 'loss_q', 'loss_cql', 'loss_awac'.
        """
        tokens, attn, rewards = batch
        actions = tokens[:, 1:, None]
        # despite the name, this is 1 where a transition contributes to the loss
        isterminal = attn[:, :-1]

        logits, qs, target_qs, vs, _ = self(input_ids=tokens, attention_mask=attn)
        bsize, ntokens, dsize = logits.shape
        n_nonterminal = max(1, isterminal.sum())

        def masked_mean(x):
            # average of `x` over the transitions selected by `isterminal`
            return (x * isterminal).sum() / n_nonterminal

        def action_ce(scores):
            # per-transition cross-entropy of the taken action under `scores`
            return F.cross_entropy(
                scores[:, :-1].reshape(-1, dsize), actions.reshape(-1), reduction='none'
            ).reshape(bsize, ntokens - 1)

        if self.two_qs:
            Q1 = qs[0][:, :-1].gather(-1, actions).squeeze(-1)
            Q2 = qs[1][:, :-1].gather(-1, actions).squeeze(-1)

            targetQ1 = target_qs[0][:, :-1].gather(-1, actions).squeeze(-1).detach()
            targetQ2 = target_qs[1][:, :-1].gather(-1, actions).squeeze(-1).detach()
            targetQ = th.minimum(targetQ1, targetQ2)
        else:
            Q = qs[:, :-1].gather(-1, actions).squeeze(-1)
            targetQ = target_qs[:, :-1].gather(-1, actions).squeeze(-1).detach()

        # V(s_t): value of the state the action is taken in -- same positions as Q(s_t, a_t).
        # squeeze(-1) (not squeeze()) keeps (B, T-1) shapes when B == 1 or T == 2.
        V = vs[:, :-1].squeeze(-1)
        # V(s_{t+1}) for the Bellman target
        Vnext = vs[:, 1:].squeeze(-1) * isterminal
        Q_ = (rewards + self.gamma * Vnext).detach()

        if self.two_qs:
            loss_q1 = ((Q1 - Q_) * isterminal).pow(2).sum() / n_nonterminal
            loss_q2 = ((Q2 - Q_) * isterminal).pow(2).sum() / n_nonterminal
            loss_q = loss_q1 + loss_q2
        else:
            loss_q = ((Q - Q_) * isterminal).pow(2).sum() / n_nonterminal

        # expectile regression of V(s_t) towards the (min) target Q(s_t, a_t)
        loss_v = masked_mean(
            (targetQ >= V).int() * self.tau * (targetQ - V).pow(2)
            + (targetQ < V).int() * (1 - self.tau) * (targetQ - V).pow(2)
        )

        if self.two_qs:
            loss_cql = masked_mean(action_ce(qs[0])) + masked_mean(action_ce(qs[1]))
        else:
            loss_cql = masked_mean(action_ce(qs))

        loss_awac = masked_mean(action_ce(logits))

        loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac
        stats = {
            'loss_q': loss_q,
            'loss_v': loss_v,
            'loss_cql': loss_cql,
            'loss_awac': loss_awac,
            'loss': loss,
        }

        return loss, stats

    def _sync_target_q_heads(self, alpha):
        """Polyak-update the target Q heads: target <- alpha * online + (1 - alpha) * target."""
        for target_param, copy_param in zip(self.target_q1_head.parameters(), self.q1_head.parameters()):
            target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)

        if self.two_qs:
            for target_param, copy_param in zip(self.target_q2_head.parameters(), self.q2_head.parameters()):
                target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)

    def sync_target_q_heads(self):
        """Move the target Q heads towards the online heads by `self.alpha`.

        Under ZeRO stage 3 (env `DEEPSPEED_ZERO_STAGE == '3'`), the head
        parameters are gathered first and the update runs on rank 0 only, with
        `modifier_rank=0` so that rank's modification is the one kept.
        """
        if os.environ.get('DEEPSPEED_ZERO_STAGE', '0') == '3':
            params = chain(self.q1_head.parameters(),
                           self.target_q1_head.parameters(),
                           self.q2_head.parameters() if self.two_qs else [],
                           self.target_q2_head.parameters() if self.two_qs else [])

            with deepspeed.zero.GatheredParameters(list(params), modifier_rank=0):
                if deepspeed.comm.get_rank() == 0:
                    self._sync_target_q_heads(self.alpha)
        else:
            self._sync_target_q_heads(self.alpha)

    @th.inference_mode()
    def sample(self, query, beta=1, max_length=32, temperature=1, top_k=20, logit_mask=None, logs=True, eos_token_id=50256):
        """Extend each row of `query` by `max_length - 1` sampled tokens.

        Each step samples from softmax(topk_mask(log_softmax(logits) + beta * adv, top_k) / temperature),
        where adv = (min of the) target Q minus V at the last position. Once a
        row emits `eos_token_id`, every later token for that row is forced to it.
        `eos_token_id` defaults to 50256, presumably GPT-2's eos id; confirm for
        other tokenizers.

        Args:
            logit_mask: optional table indexed by the previous token. Entries
                where it is truthy get -inf logits.
            logs: when True, collect min/max/std/avg of the Q, V and advantage
                values seen at each step.

        Returns:
            `(tokens, stats)`: `query` with the sampled tokens appended, and the
            stats dict (empty when `logs` is False).
        """
        step_input = query.clone()
        past_key_values = None
        tensors = defaultdict(list)

        finished = th.zeros(step_input.shape[0], 1, dtype=th.long, device=query.device)

        for _ in range(max_length-1):
            logits, _, target_qs, vs, past_key_values = self.forward(input_ids=step_input, past_key_values=past_key_values)

            if self.two_qs:
                qs = th.minimum(target_qs[0][:, -1], target_qs[1][:, -1])
            else:
                qs = target_qs[:, -1]

            logits = logits[:, -1]

            if logit_mask is not None:
                logits[th.where(logit_mask[step_input[:, -1]])] = -np.inf

            # value of the current (last) state only; on the first step `vs`
            # covers the whole prefix, later steps a single position
            v_last = vs[:, -1, :]
            adv = qs - v_last
            pi = F.log_softmax(logits, -1)
            modpi = topk_mask(pi + beta * adv, top_k)
            ps = F.softmax(modpi / temperature, -1)

            tokens = th.multinomial(ps, 1)
            tokens = (1 - finished) * tokens + finished * eos_token_id

            query = th.hstack((query, tokens))

            step_input = tokens
            finished = (tokens == eos_token_id).long()

            if logs:
                tensors['qs'].append(qs)
                # log last-position values so every step has the same shape for vstack
                tensors['vs'].append(v_last)
                tensors['adv'].append(adv)

        stats = {}
        for name, xs in tensors.items():
            xs = th.vstack(xs)
            stats.update({
                f'{name}-min': xs.min(),
                f'{name}-max': xs.max(),
                f'{name}-std': xs.std(),
                f'{name}-avg': xs.mean(),
            })

        return query, stats

    @property
    def dummy_inputs(self):
        """A single-token batch (`input_ids` of ones) on the backbone's device, for smoke-testing forward."""
        return {'input_ids': th.ones(1, 1, device=self.gpt.device, dtype=th.long)}

    @property
    def device(self):
        """Device of the wrapped causal LM."""
        return self.gpt.device

# -------------------------------------------------------------------------------- /scripts/graph_plot.svg: -------------------------------------------------------------------------------- 1 | 2 | 18 | 38 | 40 | 41 | 42 | 44 | 2022-09-04T18:32:28.332967 45 | image/svg+xml 46 | 47 | 48 | Matplotlib v3.5.3, https://matplotlib.org/ 49 | 50 | 51 | 52 | 53 | 54 | 56 | 59 | 60 | 63 | 66 | 71 | 72 | 75 | 77 | 81 | 82 | 84 | 86 | 88 | 93 | 94 | 96 | 97 | 101 | 103 | 107 | 108 | 111 | 112 | 113 | 114 | 116 | 118 | 123 | 124 | 126 | 127 | 131 | 133 | 137 | 138 | 141 | 142 | 143 | 144 | 146 | 148 | 153 | 154 | 156 | 157 | 161 | 163 | 167 | 168 | 171 | 172 | 173 | 174 | 176 | 178 | 183 | 184 | 186 | 187 | 191 | 193 | 197 | 198 | 201 | 202 | 203 | 204 | 206 | 208 | 213 | 214 | 216 | 217 | 221 | 223 | 227 | 228 | 231 | 232 | 233 | 234 | 236 | 238 | 243 | 244 | 246 | 247 | 251 | 253 | 257 | 258 | 261 | 262 | 263 | 264 | 266 | 268 | 273 | 274 | 276 | 277 | 281 | 283 | 287 | 288 | 291 | 292 | 293 | 294 | 296 | 298 | 303 | 304 | 306 | 307 | 311 | 313 | 317 | 318 | 321 | 322 | 323 | 324 | 326 | 328 | 333 | 334 | 336 | 337 | 341 | 343 | 347 | 348 | 351 | 352 | 353 | 354 | 356 | 358 | 363 | 364 | 366 | 367 | 371 | 373 | 377 | 378 | 381 | 382 | 383 | 384 | 386 | 387 | 391 | 393 | 397 | 401 | 405 | 409 | 413 | 417 | 421 | 425 | 429 | 433 | 437 | 438 | 441 | 445 | 449 | 453 | 457 | 461 | 465 | 469 | 473 | 477 | 481 | 485 | 489 | 493 | 497 | 501 | 505 | 509 | 510 | 511 | 512 | 514 | 516 | 518 | 523 | 524 | 526 | 527 | 531 | 533 | 537 | 541 | 542 | 545 | 549 | 553 | 554 | 555 | 556 | 558 | 560 | 565 | 566 | 568 | 569 | 573 | 576 | 580 | 584 | 585 | 586 | 587 | 589 | 591 | 596 | 597 | 599 | 600 | 604 | 607 | 611 | 615 | 616 | 617 | 618 | 620 | 622 |
627 | 628 | 630 | 631 | 635 | 638 | 642 | 646 | 647 | 648 | 649 | 651 | 653 | 658 | 659 | 661 | 662 | 666 | 669 | 673 | 677 | 678 | 679 | 680 | 682 | 684 | 689 | 690 | 692 | 693 | 697 | 700 | 704 | 708 | 709 | 710 | 711 | 713 | 714 | 718 | 720 | 724 | 728 | 732 | 736 | 737 | 740 | 744 | 748 | 752 | 756 | 760 | 764 | 768 | 772 | 776 | 780 | 784 | 788 | 792 | 796 | 800 | 804 | 808 | 812 | 813 | 814 | 815 | 817 | 822 | 823 | 825 | 830 | 831 | 833 | 838 | 839 | 841 | 846 | 847 | 849 | 854 | 855 | 857 | 862 | 863 | 865 | 870 | 871 | 873 | 878 | 879 | 881 | 886 | 887 | 889 | 894 | 895 | 897 | 902 | 903 | 905 | 910 | 911 | 913 | 918 | 919 | 921 | 926 | 927 | 929 | 934 | 935 | 937 | 942 | 943 | 945 | 950 | 951 | 953 | 958 | 959 | 961 | 966 | 967 | 969 | 974 | 975 | 977 | 981 | 982 | 984 | 988 | 989 | 991 | 995 | 996 | 998 | 1002 | 1003 | 1005 | 1010 | 1011 | 1013 | 1018 | 1019 | 1021 | 1026 | 1027 | 1029 | 1034 | 1035 | 1037 | 1042 | 1043 | 1045 | 1050 | 1051 | 1053 | 1058 | 1059 | 1061 | 1066 | 1067 | 1069 | 1074 | 1075 | 1077 | 1082 | 1083 | 1085 | 1087 | 1091 | 1092 | 1094 | 1098 | 1099 | 1101 | 1102 | 1106 | 1109 | 1113 | 1117 | 1121 | 1125 | 1129 | 1133 | 1137 | 1141 | 1145 | 1149 | 1153 | 1157 | 1158 | 1159 | 1161 | 1165 | 1166 | 1168 | 1169 | 1173 | 1175 | 1179 | 1183 | 1187 | 1188 | 1191 | 1195 | 1199 | 1203 | 1204 | 1205 | 1207 | 1211 | 1212 | 1214 | 1215 | 1219 | 1221 | 1225 | 1229 | 1233 | 1237 | 1238 | 1241 | 1245 | 1249 | 1253 | 1257 | 1261 | 1265 | 1269 | 1273 | 1277 | 1281 | 1282 | 1283 | 1284 | 1285 | 1286 | 1288 | 1290 | 1296 | 1297 | 1298 | 1299 | --------------------------------------------------------------------------------