├── requirements.txt
├── readme.md
├── deepspeed.yaml
├── config.yaml
├── sentiments.py
├── scripts
│   ├── graph_plot.py
│   └── graph_plot.svg
├── captions.py
├── utils.py
├── randomwalks.py
├── carps.py
├── ilql.py
└── models.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.12.1+cu113
2 | transformers==4.21.1
3 | accelerate==0.12.0
4 | deepspeed==0.7.0
5 | datasets==2.4.0
6 | tokenizers==0.12.1
7 | tqdm==4.64.0
8 | wandb==0.13.2
9 | networkx==2.8.6
10 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | Simplified implementation of [Implicit Language Q Learning (Snell et al. 2022)](https://sea-snell.github.io/ILQL_site/) ([official](https://github.com/Sea-Snell/Implicit-Language-Q-Learning/), [paper](https://arxiv.org/abs/2206.11871))
2 |
3 | Evaluating on the Graph Shortest Path task from [Decision Transformer (Chen et al., 2021)](https://arxiv.org/abs/2106.01345):
4 |
5 | ![](scripts/graph_plot.svg)
6 |
7 | where, for each random graph, a transformer is trained to find optimal trajectories using only 1000 random walks.
8 |
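9 | One way to run training (a minimal sketch, mirroring how `scripts/graph_plot.py` calls into `ilql.py`):
10 |
11 | ```python
12 | from ilql import main
13 |
14 | # debug=True skips wandb tracker initialization (see ilql.py)
15 | model, data = main(task='RandomWalks', debug=True)
16 | ```
17 |
18 | Note that `ilql.py` evaluates its `--key=value` command-line arguments with `eval`, so string values need inner quotes (e.g. `python ilql.py --task="'Sentiments'"`); multi-GPU runs presumably go through `accelerate launch --config_file deepspeed.yaml ilql.py ...` once the `???` fields in `deepspeed.yaml` are filled in.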
--------------------------------------------------------------------------------
/deepspeed.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | deepspeed_config_file: zero3.json
4 | deepspeed_multinode_launcher: standard
5 | gradient_accumulation_steps: 1
6 | offload_optimizer_device: none
7 | offload_param_device: none
8 | gradient_clipping: 1.0
9 | zero3_init_flag: false
10 | zero_stage: 3
11 | distributed_type: DEEPSPEED
12 | downcast_bf16: no
13 | fsdp_config: {}
14 | machine_rank: ???
15 | main_process_ip: ???
16 | main_process_port: 1234
17 | main_training_function: main
18 | mixed_precision: bf16
19 | num_machines: ???
20 | num_processes: ???
21 | rdzv_backend: static
22 | same_network: true
23 | use_cpu: false
24 |
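25 | # machine_rank, main_process_ip, num_machines and num_processes are placeholders (???) to be
26 | # filled in per cluster; zero3.json is the DeepSpeed config referenced by deepspeed_config_file
27 | # and is not included in this repo.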
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | RandomWalks:
2 | lr: 0.001
3 | opt_betas: [0.9, 0.95]
4 | batch_size: 500
5 | tau: 0.7
6 | gamma: 0.99
7 | cql_scale: 0.1
8 | awac_scale: 1
9 | alpha: 0.1
10 | steps_for_target_q_sync: 10
11 | steps_for_eval: 5
12 | inference_betas: [1]
13 | n_epochs: 100
14 | seed: 1000
15 | n_layers_unfrozen: 0
16 | two_qs: true
17 | gptconfig:
18 | n_embd: 144
19 | n_layer: 6
20 | n_head: 1
21 |
22 | Sentiments:
23 | lr: 0.0006
24 | opt_betas: [0.9, 0.95]
25 | batch_size: 6
26 | tau: 0.5
27 | gamma: 0.99
28 | cql_scale: 0.1
29 | awac_scale: 1
30 | alpha: 1
31 | steps_for_target_q_sync: 40
32 | steps_for_eval: 100
33 | inference_betas: [1]
34 | model: EleutherAI/gpt-j-6B
35 | n_layers_unfrozen: 0
36 | two_qs: true
37 | n_epochs: 1
38 |
39 | Carps:
40 | lr: 0.00004
41 | opt_betas: [0.9, 0.95]
42 | batch_size: 72
43 | tau: 0.5
44 | cql_scale: 0.1
45 | alpha: 1
46 | steps_for_target_q_sync: 100
47 | inference_betas: [0, 1, 2, 4]
48 | n_epochs: 10
49 | model: gpt2-large
50 | n_layers_unfrozen: 2
51 | max_length: 48
52 | diff_reward: true
53 |
54 | Captions:
55 | lr: 0.0006
56 | opt_betas: [0.9, 0.95]
57 | batch_size: 16
58 | tau: 0.7
59 | gamma: 0.99
60 | cql_scale: 0.1
61 | awac_scale: 1
62 | alpha: 1
63 | steps_for_target_q_sync: 50
64 | steps_for_eval: 100
65 | inference_betas: [0, 1]
66 | model: gpt2-large
67 | n_layers_unfrozen: 0
68 | two_qs: true
69 | n_epochs: 1
70 |
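71 | # tau: expectile used in the V loss; gamma: discount in the TD target; alpha: Polyak
72 | # coefficient for syncing the target Q heads (every steps_for_target_q_sync optimizer steps);
73 | # cql_scale / awac_scale: weights of the CQL and AWAC terms in QVModel.loss; two_qs: use two
74 | # Q heads and take their minimum; inference_betas: advantage scaling applied when sampling;
75 | # n_layers_unfrozen: transformer blocks left trainable (0 freezes nothing, due to the [:-0]
76 | # slice in ilql.py). See models.py and ilql.py.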
--------------------------------------------------------------------------------
/sentiments.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch as th
4 | from torch import tensor
5 | import torch.nn.functional as F
6 | from torch.utils.data import TensorDataset
7 | from functools import partial, reduce
8 |
9 | from transformers import AutoTokenizer, pipeline
10 | from datasets import load_dataset
11 | import wandb
12 | from tqdm import tqdm
13 | import math
14 |
15 | from utils import batch_map, tohuman, load_tensors
16 |
17 | def get_reward(sentiment_pipe, texts):
18 | sentiments = batch_map(lambda batch: sentiment_pipe(batch), texts, bsize=1024, desc='Sentiments')
19 | return tensor([-s['score'] if s['label'] == 'NEGATIVE' else s['score'] for s in sentiments])
20 |
21 | class Sentiments:
22 | def __init__(self, tokenizer: AutoTokenizer, max_length=50, n_samples=64, needs_reward_model=False):
23 | self.max_length = max_length
24 | self.tokenizer = tokenizer
25 |
26 | if needs_reward_model:
27 | self.sentiment_pipe = pipeline('sentiment-analysis', 'lvwerra/distilbert-imdb', device=th.device(0))
28 | else:
29 | self.sentiment_pipe = None
30 |
31 | texts = load_dataset('imdb', split='train+test')['text']
32 | tensors = load_tensors(
33 | 'sentiments',
34 | texts=texts,
35 | reward_model=partial(get_reward, self.sentiment_pipe),
36 | tokenizer=self.tokenizer,
37 | max_length=max_length,
38 | use_cache=True
39 | )
40 |
41 | query = tensor([self.tokenizer.bos_token_id] * n_samples).view(n_samples, 1)
42 | self.logit_mask = None
43 |
44 | self.dataset = TensorDataset(tensors['input_ids'], tensors['attention_mask'], tensors['rewards'])
45 | self.eval_dataset = TensorDataset(query)
46 |
47 | def eval(self, samples, beta):
48 | reviews = self.tokenizer.batch_decode(samples, skip_special_tokens=True)
49 |
50 | rewards = [1-s['score'] if s['label'] == 'NEGATIVE' else s['score'] for s in self.sentiment_pipe(reviews)]
51 | reward = np.mean(rewards)
52 |
53 | rows = list(zip(reviews, rewards))
54 | print(f'\n{beta=} {reward=:.2f}\n' + '\n'.join([f'[{sent:.2f}] {text}' for text, sent in rows[:8]]))
55 |
56 | stats = { f'reward/{beta}': reward,
57 | f'responses/{beta}': wandb.Table(columns=['response', 'sentiment'], rows=rows[:32]) }
58 |
59 | return reward, stats
60 |
61 | if __name__ == '__main__':
62 | import sys
63 | from rich import print
64 | tokenizer = AutoTokenizer.from_pretrained(sys.argv[1])
65 | tokenizer.pad_token = tokenizer.eos_token
66 | ds = Sentiments(tokenizer, needs_reward_model=True).dataset
67 | print(f'{next(iter(ds))=}')
68 |
--------------------------------------------------------------------------------
/scripts/graph_plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 | from ilql import main
5 | from torch import tensor
6 | import torch as th
7 | import networkx as nx
8 | import numpy as np
9 |
10 | optimal_lengths = []
11 | sampled_lengths = []
12 | iql_lengths = []
13 |
14 | for seed in range(10):
15 | model, data = main(seed=seed, debug=True)
16 | model.eval()
17 |
18 | g = nx.from_numpy_array(data.adj, create_using=nx.DiGraph)
19 |
20 | # optimal
21 | for start in set(range(data.n_nodes)) - {data.goal}:
22 | try:
23 | shortest_path = nx.shortest_path(g, start, data.goal)[:data.max_length]
24 | optimal_lengths.append(len(shortest_path)-1)
25 | except:
26 | optimal_lengths.append(data.max_length)
27 |
28 | # ilql
29 | starts = th.arange(1, data.n_nodes).unsqueeze(1).to(model.device)
30 | paths, _ = model.sample(starts, max_length=data.max_length, logit_mask=tensor(~data.adj), beta=10) # argmax
31 | for path in paths:
32 | length = data.max_length
33 | for ind, node in enumerate(path):
34 | if node == data.goal:
35 | length = ind
36 | break
37 |
38 | iql_lengths.append(length)
39 |
40 | # all samples
41 | for path in data.dataset.tensors[0]:
42 | length = data.max_length
43 | for ind, node in enumerate(path):
44 | if node == data.goal:
45 | length = ind
46 | break
47 |
48 | sampled_lengths.append(length)
49 | # ■ ~
50 |
51 | from matplotlib import pyplot
52 | import matplotlib
53 |
54 | fontcolor = '#444'
55 | matplotlib.rcParams['text.color'] = fontcolor
56 | matplotlib.rcParams['axes.labelcolor'] = fontcolor
57 | matplotlib.rcParams['xtick.color'] = fontcolor
58 | matplotlib.rcParams['ytick.color'] = fontcolor
59 | matplotlib.rcParams['xtick.labelcolor'] = fontcolor
60 | matplotlib.rcParams['ytick.labelcolor'] = fontcolor
61 | matplotlib.rcParams['xtick.labelcolor'] = fontcolor
62 |
63 | matplotlib.rcParams["font.family"] = "Futura"
64 | matplotlib.rcParams["font.size"] = 15
65 | matplotlib.rcParams["xtick.labelsize"] = 20
66 | matplotlib.rcParams["ytick.labelsize"] = 20
67 | matplotlib.rcParams["figure.titlesize"] = 12
68 | matplotlib.rcParams["figure.figsize"] = 15, 8
69 |
70 | matplotlib.style.use('ggplot')
71 | matplotlib.rcParams['figure.dpi'] = 70
72 |
73 | ax = pyplot.gca()
74 | ax.set_facecolor('#fff')
75 | ax.grid(color='lightgray', alpha=0.4, axis='y')
76 | ax.tick_params(top=False, labeltop=False, bottom=False, labelbottom=True, left=False, labelleft=True)
77 |
78 | optimal_hist = np.histogram(optimal_lengths, bins=np.arange(1, data.max_length+2), density=True)[0]
79 | sampled_hist = np.histogram(sampled_lengths, bins=np.arange(1, data.max_length+2), density=True)[0]
80 | iql_hist = np.histogram(iql_lengths, bins=np.arange(1, data.max_length+2), density=True)[0]
81 |
82 | barsize = 0.36
83 | iql_color = '#99a3fd'
84 | opt_color = '#f2ad48'
85 | random_color='lightgray'
86 |
87 | pyplot.bar(np.arange(1, data.max_length+1)-barsize/1.5, optimal_hist, width=barsize, label='shortest path', color=opt_color, zorder=2)
88 | pyplot.bar(np.arange(1, data.max_length+1), iql_hist, width=barsize, label='ILQL', color=iql_color, zorder=3)
89 | pyplot.bar(np.arange(1, data.max_length+1)+barsize/1.5, sampled_hist, width=barsize, label='random walk', color=random_color, zorder=1)
90 |
91 | pyplot.legend(fontsize=16)
92 | pyplot.xticks(np.arange(1, data.max_length+1), list(np.arange(1, data.max_length)) + ['∞'])
93 |
94 | pyplot.xlabel('# of steps to goal', fontsize=22, color=fontcolor, labelpad=20)
95 | pyplot.ylabel('proportion of paths', fontsize=22, color=fontcolor, labelpad=20)
96 |
97 | pyplot.savefig('scripts/graph_plot.svg')
98 |
--------------------------------------------------------------------------------
/captions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch as th
4 | from torch import tensor
5 | import torch.nn.functional as F
6 | from torch.utils.data import TensorDataset
7 |
8 | import wandb
9 | import sqlite3
10 |
11 | from utils import batch_map, tohuman
12 |
13 | def load_tensors(cache_path, tokenizer, use_cache=True):
14 | cache_path = f'{cache_path}_tokenizer={tokenizer.name_or_path}.pt'
15 |
16 | if use_cache and os.path.exists(cache_path):
17 | out = th.load(cache_path)
18 | else:
19 | conn = sqlite3.connect('sac_public_2022_06_29.sqlite')
20 | c = conn.cursor()
21 | c.execute("SELECT prompt, rating FROM ratings "
22 | "JOIN images ON images.id=ratings.iid "
23 | "JOIN generations ON images.gid=generations.id "
24 | "WHERE rating IS NOT NULL;")
25 |
26 | prompts, ratings = tuple(map(list, zip(*filter(lambda x: len(x[0]) > 10, c.fetchall()))))
27 |
28 | out = tokenizer(prompts, padding=True, return_tensors='pt')
29 | input_ids, attention_mask = out['input_ids'], out['attention_mask']
30 |
31 | # append eos
32 | input_ids = F.pad(input_ids, (0, 1), value=tokenizer.eos_token_id)
33 | attention_mask = F.pad(attention_mask, (0, 1), value=0)
34 |
35 | # figure out sentences' endings
36 | diff_padding = th.zeros(input_ids.shape[0], 1, dtype=th.long)
37 | endings = input_ids.eq(tokenizer.eos_token_id).diff(prepend=diff_padding, dim=-1).nonzero(as_tuple=True)
38 |
39 | ratings = tensor(ratings, dtype=th.float32).view(-1, 1)
40 | ratings = (ratings - ratings.mean()) / (ratings.std() + 1e-100)
41 | rewards = ratings.repeat(1, input_ids.shape[1])
42 |
43 | # zero padding
44 | rewards[input_ids.eq(tokenizer.pad_token_id).nonzero(as_tuple=True)] = 0
45 | # refill rewards for the actual eos in case pad == eos
46 | rewards[endings] = ratings.view(-1)
47 |
48 | # prepend bos
49 | input_ids = F.pad(input_ids, (1, 0), value=tokenizer.eos_token_id)
50 | attention_mask = F.pad(attention_mask, (1, 0), value=1)
51 |
52 | out = {'input_ids': input_ids, 'attention_mask': attention_mask, 'rewards': rewards}
53 |
54 | if not os.path.exists(os.path.dirname(cache_path)):
55 | os.mkdir(os.path.dirname(cache_path))
56 |
57 | th.save(out, cache_path)
58 |
59 | print(f"Total {tohuman(np.prod(out['input_ids'].shape))} tokens")
60 | return out
61 |
62 | class AestheticCaptions(TensorDataset):
63 | def __init__(self, tokenizer, max_length=77, n_samples=32, batch_size=1, use_cache=True):
64 | self.max_length = max_length
65 | self.n_samples = n_samples
66 | self.batch_size = batch_size
67 | self.tokenizer = tokenizer
68 |
69 | tensors = load_tensors('cache/aesthetic-captions', self.tokenizer, use_cache=use_cache)
70 | super().__init__(tensors['input_ids'], tensors['attention_mask'], tensors['rewards'])
71 |
72 | def eval(self, logs, model, betas=[1]):
73 | query = tensor([self.tokenizer.eos_token_id] * self.n_samples, device=model.device).view(self.n_samples, 1)
74 |
75 | for beta in betas:
76 | responses = batch_map(
77 | lambda batch: model.sample(query, beta=beta, max_length=self.max_length,
78 | eos_token_id=self.tokenizer.eos_token_id)[0],
79 | query, bsize=self.batch_size, desc='Generating')
80 |
81 | responses = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
82 |
83 | print(f'\n{beta=}\n' + '\n'.join([f'{text}' for text in responses[:16]]))
84 |
85 | logs.update({f'responses/{beta}': wandb.Table(columns=['response'], data=[[r] for r in responses[:32]])})
86 |
87 | return np.inf, {}
88 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch as th
4 | from tqdm import tqdm
5 | import math
6 | from contextlib import contextmanager
7 | import torch.nn.functional as F
8 | from time import time
9 | import deepspeed
10 |
11 | try:
12 | __IPYTHON__
13 | run_from_ipython = True
14 | except NameError:
15 | run_from_ipython = False
16 |
17 | def randexclude(rng: np.random.RandomState, n: int, exclude: int) -> int:
18 | while True:
19 | x = rng.randint(n)
20 | if x != exclude:
21 | return x
22 |
23 | def tohuman(n: int) -> str:
24 | if n > 1e9:
25 | return f'{n / 1e9:.1f}B'
26 | elif n > 1e6:
27 | return f'{n / 1e6:.1f}M'
28 | elif n > 1e3:
29 | return f'{n / 1e3:.1f}K'
30 | return str(n)
31 |
32 | def logvars(name, logs, xs):
33 | xs = th.vstack(xs)
34 | logs.update({ f'{name}-mean': xs.mean(),
35 | f'{name}-std': xs.std(),
36 | f'{name}-min': xs.min(),
37 | f'{name}-max': xs.max() })
38 |
39 | def batch_map(fn, xs, bsize: int, desc=None):
40 | out = []
41 | for ind in tqdm(range(math.ceil(len(xs) / bsize)), desc=desc, disable=not desc):
42 | batch = xs[ind*bsize:min(len(xs), (ind+1)*bsize)]
43 | out.extend(fn(batch))
44 |
45 | return out
46 |
47 | def load_tensors(name, texts, reward_model, tokenizer, max_length=64, use_cache=True):
48 | cache_path = f'cache/{name}_{max_length=}_tokenizer={tokenizer.name_or_path.split("/")[-1]}.pt'
49 | if use_cache and os.path.exists(cache_path):
50 | tensors = th.load(cache_path)
51 | else:
52 | tensors = tokenizer(
53 | [tokenizer.bos_token + x for x in texts],
54 | max_length=max_length,
55 | truncation=True,
56 | padding=True,
57 | return_tensors='pt'
58 | )
59 |
60 | trimmed_texts = tokenizer.batch_decode(tensors['input_ids'], skip_special_tokens=True)
61 | rewards = th.as_tensor(reward_model(trimmed_texts))
62 | rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-30)
63 | rewards = rewards.view(-1, 1).repeat(1, tensors['input_ids'].shape[1])
64 | rewards[tensors['attention_mask'].eq(0)] = 0
65 |
66 | tensors['rewards'] = rewards
67 | tensors['attention_mask'] = F.pad(tensors['attention_mask'], (0, 1), value=0)
68 | tensors['input_ids'] = F.pad(tensors['input_ids'], (0, 1), value=tokenizer.eos_token_id)
69 |
70 | if not os.path.exists(os.path.dirname(cache_path)):
71 | os.mkdir(os.path.dirname(cache_path))
72 |
73 | th.save(tensors, cache_path)
74 |
75 | print(f"{tohuman(np.prod(tensors['input_ids'].shape))} tokens")
76 | return tensors
77 |
78 |
79 | def isdelim(c: str):
80 | return c == '?' or c == '!' or c == '.' or c == ';'
81 |
82 | def pprint(s):
83 | trig = False
84 | si = 0
85 | l = len(s)-1
86 |
87 | for i in range(len(s)):
88 | if i == l:
89 | print(s[si:].strip())
90 |
91 | elif trig or isdelim(s[i]):
92 | trig = True
93 |
94 | if s[i].isspace():
95 | print(s[si:i+1].strip())
96 | si = i + 1
97 | trig = False
98 |
99 | @contextmanager
100 | def timeit(desc='something important'):
101 | print(f'{desc}...')
102 | stime = time()
103 | try:
104 | yield None
105 | finally:
106 | print(f'done with {desc.lower()} in {time() - stime:.1f}s')
107 |
108 | def check_weights(param):
109 | if os.environ.get('DEEPSPEED_ZERO_STAGE', '0') == '3':
110 | with deepspeed.zero.GatheredParameters(param[0].weight, modifier_rank=0):
111 | if deepspeed.comm.get_rank() == 0:
112 | return param[0].weight.sum()
113 | else:
114 | return param[0].weight.sum()
115 |
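116 | if __name__ == '__main__':
117 |     # Minimal sketch of what load_tensors produces, assuming a GPT-2 tokenizer can be loaded;
118 |     # the lambda below is a stand-in for a real reward model (e.g. the sentiment pipeline).
119 |     from transformers import AutoTokenizer
120 |     tokenizer = AutoTokenizer.from_pretrained('gpt2')
121 |     tokenizer.pad_token = tokenizer.eos_token
122 |     tensors = load_tensors('toy', texts=['a great movie', 'a terrible movie'],
123 |                            reward_model=lambda texts: [1.0, -1.0],
124 |                            tokenizer=tokenizer, max_length=8, use_cache=False)
125 |     # input_ids/attention_mask get one extra eos column appended; rewards holds each text's
126 |     # normalized score broadcast over its tokens, zeroed at padding positions
127 |     print({k: tuple(v.shape) for k, v in tensors.items()})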
--------------------------------------------------------------------------------
/randomwalks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 | from torch import tensor
4 | from torch.utils.data import TensorDataset
5 | import torch.nn.functional as F
6 | from utils import randexclude
7 | import networkx as nx
8 |
9 | # Toy dataset from Decision Transformer (Chen et al., 2021)
10 | class RandomWalks:
11 | def __init__(self, n_nodes=20, max_length=10, n_walks=1000, p_edge=0.1, seed=1002):
12 | self.n_nodes = n_nodes
13 | self.n_walks = n_walks
14 | self.max_length = max_length
15 | rng = np.random.RandomState(seed)
16 |
17 | walks, rewards = [], []
18 | while True:
19 | self.adj = rng.rand(n_nodes, n_nodes) > (1 - p_edge)
20 | np.fill_diagonal(self.adj, 0)
21 | if np.all(self.adj.sum(1)): break
22 |
23 | # terminal state
24 | self.adj[0, :] = 0
25 | self.adj[0, 0] = 1
26 |
27 | self.goal = 0
28 | for _ in range(n_walks):
29 | node = randexclude(rng, n_nodes, self.goal)
30 | walk = [node]
31 |
32 | for istep in range(max_length-1):
33 | node = rng.choice(np.nonzero(self.adj[node])[0])
34 | walk.append(node)
35 | if node == self.goal:
36 | break
37 |
38 | r = th.zeros(max_length-1)
39 | r[:len(walk)-1] = -1 if walk[-1] == self.goal else -100
40 |
41 | rewards.append(r)
42 | walks.append(walk)
43 |
44 | states = []
45 | attention_masks = []
46 |
47 | for r, walk in zip(rewards, map(th.tensor, walks)):
48 | attention_mask = th.zeros(max_length, dtype=int)
49 | attention_mask[:len(walk)] = 1
50 |
51 | attention_masks.append(attention_mask)
52 | states.append(F.pad(walk, (0, max_length-len(walk))))
53 |
54 | self.worstlen = self.max_length
55 | self.avglen = sum(map(len, walks)) / self.n_walks
56 | self.bestlen = 0
57 | g = nx.from_numpy_array(self.adj, create_using=nx.DiGraph)
58 | for start in set(range(self.n_nodes)) - {self.goal}:
59 | try:
60 | shortest_path = nx.shortest_path(g, start, self.goal)[:self.max_length]
61 | self.bestlen += len(shortest_path)
62 | except:
63 | self.bestlen += self.max_length
64 |
65 | self.bestlen /= self.n_nodes - 1
66 |
67 | print(f'{self.n_walks} walks of which {(np.array([r[0] for r in rewards])==-1).mean()*100:.0f}% arrived at destination')
68 |
69 | # disallows selecting inaccessible nodes in the graph
70 | self.logit_mask = tensor(~self.adj)
71 |
72 | self.dataset = TensorDataset(th.stack(states), th.stack(attention_masks), th.stack(rewards))
73 | self.eval_dataset = TensorDataset(th.arange(1, self.n_nodes).unsqueeze(1))
74 |
75 | def render(self):
76 | from matplotlib import pyplot
77 |
78 | g = nx.from_numpy_array(self.adj, create_using=nx.DiGraph)
79 | pos = nx.spring_layout(g, seed=7357)
80 |
81 | pyplot.figure(figsize=(10, 8))
82 | nx.draw_networkx_edges(g, pos=pos, alpha=0.5, width=1, edge_color='#d3d3d3')
83 | nx.draw_networkx_nodes(g, nodelist=set(range(len(self.adj))) - {self.goal}, pos=pos, node_size=300, node_color='orange')
84 | nx.draw_networkx_nodes(g, nodelist=[self.goal], pos=pos, node_size=300, node_color='darkblue')
85 | pyplot.show()
86 |
87 | def eval(self, samples, beta):
88 | narrived = 0
89 | actlen = 0
90 | for node in range(self.n_nodes-1):
91 | for istep in range(self.max_length):
92 | if samples[node, istep] == self.goal:
93 | narrived += 1
94 | break
95 |
96 | actlen += (istep + 1) / (self.n_nodes - 1)
97 |
98 | current = (self.worstlen - actlen)/(self.worstlen - self.bestlen)
99 | average = (self.worstlen - self.avglen)/(self.worstlen - self.bestlen)
100 |
101 | stats = { 'actlen': actlen,
102 | 'avglen': self.avglen,
103 | 'bestlen': self.bestlen,
104 | 'worstlen': self.worstlen,
105 | 'arrived': f'{narrived / (self.n_nodes-1) * 100:.0f}%',
106 | 'optimal': f'{current*100:.0f}% > {average*100:.0f}%' }
107 |
108 | return -actlen, stats
109 |
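110 | if __name__ == '__main__':
111 |     # Quick sanity check, mirroring how ilql.py and scripts/graph_plot.py use this class.
112 |     data = RandomWalks(seed=1000)
113 |     states, attention_masks, rewards = data.dataset.tensors
114 |     # 1000 walks padded to max_length nodes; rewards are -1 per step for walks that reach
115 |     # the goal node (node 0) and -100 per step otherwise
116 |     print(states.shape, attention_masks.shape, rewards.shape)
117 |     data.render()  # draws the graph (requires a display)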
--------------------------------------------------------------------------------
/carps.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch as th
4 | from torch import tensor
5 | import torch.nn.functional as F
6 | from torch.utils.data import Dataset
7 | from functools import partial, reduce
8 | from transformers import AutoTokenizer
9 | from datasets import load_dataset
10 | from util.carp_util import load_carp, scorer
11 | import wandb
12 |
13 | tokenizer = AutoTokenizer.from_pretrained('gpt2')
14 |
15 | carp = load_carp(
16 | model_type='coop',
17 | config_path='ControlledCarp/magiCARP/configs/coop/alignment_coop.yml',
18 | ckpt_path='New_Alignment_CoOp_Carp_L/'
19 | ).to('cuda')
20 | carp.eval()
21 |
22 | def clean_text(text):
23 | return '. '.join(map(
24 | lambda x: x.strip(),
25 | text.replace(' . ', '. ').replace(' , ', ', ').replace(" '", "'").replace(" n't", "n't").split('. ')
26 | ))
27 |
28 | def sizesplit(size: int, xs):
29 | for ind in range(len(xs) // size + int((len(xs) % size) > 0)):
30 | yield xs[ind*size:min(len(xs), (ind+1)*size)]
31 |
32 | def topk_mask(xs, k):
33 | mintop = th.topk(xs, k)[0][:, -1].unsqueeze(-1)
34 | return th.where(xs < mintop, -np.inf * th.ones_like(xs, dtype=xs.dtype), xs)
35 |
36 | def tokenize(max_length, diff_reward, offset_reward, review, sample):
37 | text = clean_text(sample['text'])
38 | tokens = tokenizer.encode(text, return_tensors='pt')[:, :max_length-1]
39 | tokens = F.pad(tokens, (0, max_length-tokens.shape[1]-1), value=tokenizer.eos_token_id)
40 |
41 | if diff_reward:
42 | substrings = []
43 | newtext = ""
44 | for token in tokens[0]:
45 | newtext += tokenizer.decode(token)
46 | substrings.append(newtext)
47 |
48 | rewards = carp_score(substrings, review).cpu()
49 | rewards = th.hstack((tensor([offset_reward]), rewards)).diff()
50 | rewards = th.where(tokens[0] == tokenizer.eos_token_id, 0, rewards)
51 |
52 | else:
53 | r = carp_score(text, review).item()
54 | rewards = th.empty(max_length-1)
55 | rewards.fill_(r)
56 | rewards[tokens[0] == tokenizer.eos_token_id] = 0
57 |
58 | attn = [1] * max_length
59 | attn[-1] = 0
60 | sample['text'] = text
61 | sample['tokens'] = th.hstack((tensor([[tokenizer.eos_token_id]]), tokens))
62 | sample['attention'] = attn
63 | sample['rewards'] = rewards
64 | return sample
65 |
66 | @th.inference_mode()
67 | def carp_score(texts, review):
68 | return scorer(texts, [review], carp, mode='coop').view(-1)
69 |
70 | @th.inference_mode()
71 | def sample(model, query=None, n_samples=128, beta=1, max_length=32, temperature=0.8, top_k=20):
72 | if query is None:
73 | query = tensor([tokenizer.bos_token_id] * n_samples, device=model.device).view(n_samples, 1)
74 |
75 | for _ in range(max_length):
76 | logits, qs, _, vs, _ = model(input_ids=query)
77 | logits = logits[:, -1, :]
78 | qs = qs[:, -1, :]
79 | vs = vs[:, -1, :]
80 |
81 | adv = qs - vs
82 | pi = F.log_softmax(logits, -1)
83 | modpi = topk_mask(pi + beta * adv, top_k)
84 | ps = F.softmax(modpi / temperature, -1)
85 |
86 | tokens = th.multinomial(ps, 1)
87 | query = th.hstack((query, tokens))
88 |
89 | return query
90 |
91 | class Carps(Dataset):
92 | def __init__(self, review='good', max_length=48, diff_reward=True, n_samples=64):
93 | self.review = review
94 | self.max_length = max_length
95 | self.n_samples = n_samples
96 |
97 | cache_path = f'stash/carps-{max_length}l-{diff_reward}d.pt'
98 |
99 | if os.path.exists(cache_path):
100 | cache = th.load(cache_path)
101 | self.tokens = cache['tokens']
102 | self.rewards = cache['rewards']
103 | self.attention_masks = cache['attention_masks']
104 | self.validation_queries = cache['validation_queries']
105 | else:
106 | ds, valid = load_dataset(
107 | 'text',
108 | data_files={'train': 'roc_train_all.txt', 'valid': 'roc_valid.txt'},
109 | split=['train', f'valid[:{n_samples}]'])
110 |
111 | if diff_reward:
112 | vocab = list(tokenizer.get_vocab().keys())
113 | offset = th.hstack([carp_score(words, review) for words in sizesplit(32, vocab)]).mean()
114 | else:
115 | offset = 0
116 |
117 | ds = ds.map(partial(tokenize, max_length, diff_reward, offset, review))
118 | valid = valid.map(partial(tokenize, max_length, diff_reward, offset, review))
119 |
120 | self.tokens = th.tensor(ds['tokens']).squeeze(1)
121 | self.rewards = tensor(ds['rewards'])
122 | self.attention_masks = tensor(ds['attention'])
123 | self.validation_queries = tensor(valid['tokens']).squeeze(1)[:n_samples, :6]
124 |
125 | th.save({ 'tokens': self.tokens,
126 | 'rewards': self.rewards,
127 | 'attention_masks': self.attention_masks,
128 | 'validation_queries': self.validation_queries }, cache_path)
129 |
130 | def __len__(self):
131 | return self.tokens.shape[0]
132 |
133 | def __getitem__(self, ind):
134 | return self.tokens[ind], self.attention_masks[ind], self.rewards[ind]
135 |
136 | def eval(self, logs, model, betas=[1]):
137 | model.eval()
138 | queries = self.validation_queries.to(model.device)
139 |
140 | for beta in betas:
141 | responses = sample(model, query=queries, beta=beta, max_length=self.max_length, n_samples=self.n_samples)
142 | texts = [tokenizer.decode(response[1:]) for response in responses]
143 |
144 | rewards = th.hstack([carp_score(ts, self.review) for ts in sizesplit(8, texts)])
145 | reward = rewards.mean().item()
146 | rows = list(zip(texts, rewards.tolist()))
147 |
148 | print(f'\n{beta=} {reward=:.2f}\n' + '\n'.join([f'[{r:.2f}] {text}' for text, r in rows[:8]]))
149 |
150 | logs[f'reward/beta{beta}'] = reward
151 | logs.update({f'responses/beta{beta}': wandb.Table(columns=['response', 'reward'], rows=rows[:32])})
152 |
153 | stats = {'reward': f'{reward:.2f}'}
154 | model.train()
155 | return reward, stats
156 |
--------------------------------------------------------------------------------
/ilql.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import yaml
4 | from time import time
5 |
6 | import torch as th
7 | from torch import tensor, nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import DataLoader
10 | from transformers import GPT2Config, AutoTokenizer
11 | from accelerate import Accelerator
12 | import numpy as np
13 |
14 | import wandb
15 | from tqdm import tqdm, trange
16 | from utils import run_from_ipython, timeit, check_weights
17 | from models import QVModel
18 | from copy import deepcopy
19 | import accelerate
20 | import deepspeed
21 |
22 | th.set_printoptions(sci_mode=False)
23 |
24 | WORLD_SIZE = int(os.environ.get('WORLD_SIZE', 1))
25 | WORLD_RANK = int(os.environ.get('RANK', 0))
26 | LOCAL_RANK = int(os.environ.get('LOCAL_RANK', 0))
27 |
28 | def main(**args):
29 | task = args['task'] if 'task' in args else 'RandomWalks'
30 | config = yaml.safe_load(open('config.yaml'))[task]
31 | config.update(args)
32 |
33 | accelerator = Accelerator(log_with='wandb')
34 | accelerator.print(os.environ)
35 | device = accelerator.device
36 |
37 | if WORLD_SIZE > 1:
38 | th.distributed.barrier(device_ids=[LOCAL_RANK])
39 | else:
40 | th.random.manual_seed(1000)
41 |
42 | if not config.get('debug', False) and accelerator.is_main_process:
43 | modelname = config.get('model', '')
44 | accelerator.init_trackers(project_name='test-ilql', init_kwargs={
45 | 'wandb': {'name': f'ilql-{task}-{modelname}',
46 | 'mode': 'disabled' if args.get('debug', False) else 'online'}}, config=config)
47 |
48 | config = wandb.config
49 |
50 | if task == 'RandomWalks':
51 | from randomwalks import RandomWalks
52 | data = RandomWalks(seed=config['seed'])
53 | gptconfig = GPT2Config(**config['gptconfig'], vocab_size=data.n_nodes)
54 | model = QVModel(gptconfig, config)
55 |
56 | elif task == 'Sentiments':
57 | from sentiments import Sentiments
58 |
59 | with accelerator.main_process_first():
60 | tokenizer = AutoTokenizer.from_pretrained(config['model'])
61 | tokenizer.pad_token_id = tokenizer.eos_token_id
62 | data = Sentiments(tokenizer, needs_reward_model=accelerator.is_main_process)
63 |
64 | with timeit('init model'):
65 | model = QVModel(config['model'], config)
66 |
67 | elif task == 'Carps':
68 | from carps import Carps
69 | data = Carps(max_length=config['max_length'], diff_reward=config['diff_reward'])
70 | model = QVModel(config['model'], config).to(device)
71 |
72 | elif task == 'Captions':
73 | from captions import AestheticCaptions
74 |
75 | tokenizer = AutoTokenizer.from_pretrained(config['model'])
76 | tokenizer.pad_token = tokenizer.eos_token
77 | with accelerator.main_process_first():
78 | data = AestheticCaptions(tokenizer, batch_size=config['batch_size'], n_samples=16)
79 |
80 | model = QVModel(config['model'], config).to(device)
81 | else:
82 | raise ValueError(f'nonexistent {task=}')
83 |
84 | if hasattr(model.gpt, 'gpt_neox'):
85 | gpt_blocks = list(model.gpt.gpt_neox.layers)[:-config['n_layers_unfrozen']]
86 | else:
87 | gpt_blocks = list(model.gpt.transformer.h)[:-config['n_layers_unfrozen']]
88 |
89 | for m in gpt_blocks:
90 | m.requires_grad_(False)
91 |
92 | train_dataloader = DataLoader(data.dataset, batch_size=config['batch_size'])
93 |
94 | eval_batch_size = max(1, len(data.eval_dataset) // WORLD_SIZE)
95 | eval_dataloader = DataLoader(data.eval_dataset, eval_batch_size)
96 |
97 | opt_cls = (
98 | th.optim.AdamW
99 | if accelerator.state.deepspeed_plugin is None
100 | or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
101 | else accelerate.utils.DummyOptim
102 | )
103 | opt = opt_cls([p for p in model.parameters() if p.requires_grad], lr=config['lr'], betas=config['opt_betas'])
104 |
105 | total_steps = int(config['n_epochs'] * (len(data.dataset) // (config['batch_size'] * WORLD_SIZE)))
106 | n_opt_steps = 0
107 |
108 | with timeit('prepare'):
109 | model, opt, train_dataloader, eval_dataloader = accelerator.prepare(
110 | model, opt, train_dataloader, eval_dataloader
111 | )
112 |
113 | print(f'{WORLD_RANK=}: {model(**accelerator.unwrap_model(model).dummy_inputs)[0].device}')
114 |
115 | model.train()
116 | tbar = trange(total_steps, disable=not accelerator.is_local_main_process)
117 |
118 | for iepoch in range(config['n_epochs']):
119 | for batch in train_dataloader:
120 | logs = {}
121 |
122 | if n_opt_steps % config['steps_for_eval'] == 0:
123 | model.eval()
124 | beta = config['inference_betas'][0]
125 |
126 | all_samples = []
127 | for tokens in eval_dataloader:
128 | tokens = tokens[0].to(device)
129 | with th.no_grad():
130 | samples, stats = accelerator.unwrap_model(model).sample(
131 | tokens,
132 | beta=beta,
133 | max_length=data.max_length,
134 | logit_mask=data.logit_mask
135 | )
136 |
137 | all_samples.append(samples)
138 | logs.update(stats)
139 |
140 | samples = accelerator.gather(th.vstack(all_samples))
141 |
142 | if accelerator.is_main_process:
143 | reward, stats = data.eval(samples, beta)
144 | logs.update(stats)
145 | tbar.set_postfix(stats)
146 |
147 | model.train()
148 |
149 | for ix in range(len(batch)):
150 | batch[ix] = batch[ix].to(device)
151 |
152 | batch_time = time()
153 | forward_time = time()
154 | loss, stats = model.loss(batch)
155 | forward_time = time() - forward_time
156 |
157 | backward_time = time()
158 | accelerator.backward(loss)
159 | backward_time = time() - backward_time
160 |
161 | opt.step()
162 | opt.zero_grad()
163 | n_opt_steps += 1
164 |
165 | batch_time = time() - batch_time
166 | tokens_per_sec = batch[0].numel() * WORLD_SIZE / batch_time
167 | tbar.set_description(f'{tokens_per_sec=:.2f} {batch_time=:.2f}')
168 | tbar.update()
169 |
170 | if (n_opt_steps + 1) % config['steps_for_target_q_sync'] == 0:
171 | accelerator.unwrap_model(model).sync_target_q_heads()
172 |
173 | logs.update(stats)
174 | logs['target_sum'] = check_weights(accelerator.unwrap_model(model).target_q1_head)
175 | logs['batch_time'] = batch_time
176 |
177 | if not config.get('debug', False):
178 | accelerator.log(logs)
179 |
180 | return model, data
181 |
182 | if __name__ == '__main__':
183 | if os.environ.get('LOCAL_RANK'):
184 | os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = os.environ['LOCAL_RANK']
185 |
186 | if run_from_ipython:
187 | args = {'debug': True}
188 | else:
189 | # poor man's argparse
190 | args = {a[2:]: eval(v) for a, v in map(lambda s: s.split('='), sys.argv[1:])}
191 |
192 | main(**args)
193 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch as th
3 | import numpy as np
4 | from torch import tensor, nn
5 | import torch.nn.functional as F
6 | import transformers
7 | from transformers import AutoModelForCausalLM, PretrainedConfig, AutoConfig
8 |
9 | from typing import NamedTuple, Tuple, Union
10 | from copy import deepcopy
11 | from collections import defaultdict
12 | from accelerate.utils import compute_module_sizes
13 | from itertools import chain
14 |
15 | import accelerate
16 | import deepspeed
17 |
18 | def topk_mask(xs: th.FloatTensor, k: int):
19 | mintop = th.topk(xs, k)[0][:, -1].unsqueeze(-1)
20 | return th.where(xs < mintop, -np.inf * th.ones_like(xs, dtype=xs.dtype), xs)
21 |
22 | class QVOutput(Tuple):
23 | logits: th.FloatTensor
24 | qs: th.FloatTensor
25 | target_qs: th.FloatTensor
26 | vs: th.FloatTensor
27 | past_key_values: Tuple[th.FloatTensor]
28 |
29 | def make_head(n_embd: int, out: int):
30 | return nn.Sequential(
31 | nn.Linear(n_embd, n_embd * 2),
32 | nn.ReLU(),
33 | nn.Linear(n_embd * 2, out)
34 | )
35 |
36 | class QVModel(nn.Module):
37 | def __init__(self, config: Union[PretrainedConfig, str], params):
38 | super().__init__()
39 |
40 | # enable zero3 init within from_pretrained
41 | if os.environ.get('DEEPSPEED_ZERO_STAGE', '0') == '3':
42 | config_path = os.environ.get('DEEPSPEED_CONFIG_FILE', '')
43 | if config_path:
44 | _hfconfig = transformers.deepspeed.HfDeepSpeedConfig(config_path)
45 |
46 | if isinstance(config, PretrainedConfig):
47 | self.gpt = AutoModelForCausalLM.from_config(config)
48 | else:
49 | self.gpt = AutoModelForCausalLM.from_pretrained(config)
50 |
51 | if hasattr(self.gpt.config, 'hidden_size'):
52 | self.n_embd = self.gpt.config.hidden_size
53 | else:
54 | self.n_embd = self.gpt.config.n_embd
55 | self.vocab_size = self.gpt.config.vocab_size
56 |
57 | self.v_head = make_head(self.n_embd, 1)
58 | self.q1_head = make_head(self.n_embd, self.vocab_size)
59 | self.target_q1_head = deepcopy(self.q1_head)
60 | self.target_q1_head.requires_grad_(False)
61 |
62 | self.tau = params['tau']
63 | self.alpha = params['alpha']
64 | self.gamma = params['gamma']
65 | self.awac_scale = params['awac_scale']
66 | self.cql_scale = params['cql_scale']
67 | self.two_qs = params['two_qs']
68 |
69 | if self.two_qs:
70 | self.q2_head = make_head(self.n_embd, self.vocab_size)
71 | self.target_q2_head = deepcopy(self.q2_head)
72 | self.target_q2_head.requires_grad_(False)
73 |
74 | def forward(self, **x):
75 | if hasattr(self.gpt, 'gpt_neox'):
76 | out = self.gpt.gpt_neox(**x)
77 | else:
78 | out = self.gpt.transformer(**x)
79 |
80 | hs = out.last_hidden_state
81 |
82 | if self.two_qs:
83 | qs = (self.q1_head(hs), self.q2_head(hs))
84 | target_qs = (self.target_q1_head(hs), self.target_q2_head(hs))
85 | else:
86 | qs = self.q1_head(hs)
87 | target_qs = self.target_q1_head(hs)
88 |
89 | if hasattr(self.gpt, 'gpt_neox'):
90 | logits = self.gpt.embed_out(hs)
91 | else:
92 | logits = self.gpt.lm_head(hs)
93 |
94 | return QVOutput((logits, qs, target_qs, self.v_head(hs), out.past_key_values))
95 |
96 | def loss(self, batch):
97 | tokens, attn, rewards = batch
98 | actions = tokens[:, 1:, None]
99 | isterminal = attn[:, :-1]
100 |
101 | logits, qs, target_qs, vs, _ = self(input_ids=tokens, attention_mask=attn)
102 | bsize, ntokens, dsize = logits.shape
103 |
104 | if self.two_qs:
105 | Q1 = qs[0][:, :-1].gather(-1, actions).squeeze(-1)
106 | Q2 = qs[1][:, :-1].gather(-1, actions).squeeze(-1)
107 |
108 | targetQ1 = target_qs[0][:, :-1].gather(-1, actions).squeeze(-1).detach()
109 | targetQ2 = target_qs[1][:, :-1].gather(-1, actions).squeeze(-1).detach()
110 | targetQ = th.minimum(targetQ1, targetQ2)
111 | else:
112 | Q = qs[:, :-1].gather(-1, actions).squeeze(-1)
113 | targetQ = target_qs[:, :-1].gather(-1, actions).squeeze(-1).detach()
114 |
115 | n_nonterminal = max(1, isterminal.sum())
116 | V = vs[:, 1:].squeeze() * isterminal
117 | Q_ = rewards + self.gamma * V
118 |
119 | if self.two_qs:
120 | loss_q1 = ((Q1 - Q_.detach()) * isterminal).pow(2).sum() / n_nonterminal
121 | loss_q2 = ((Q2 - Q_.detach()) * isterminal).pow(2).sum() / n_nonterminal
122 | loss_q = loss_q1 + loss_q2
123 | else:
124 | loss_q = ((Q - Q_.detach()) * isterminal).pow(2).sum() / n_nonterminal
125 |
126 | loss_v = (((targetQ >= V).int() * self.tau * (targetQ - V).pow(2) + (targetQ < V).int() * (1 - self.tau) * (targetQ - V).pow(2)) * isterminal).sum() / n_nonterminal
127 |
128 | if self.two_qs:
129 | loss_cql_q1 = (F.cross_entropy(qs[0][:, :-1].reshape(-1, dsize), actions.reshape(-1), reduction='none').reshape(bsize, ntokens-1) * isterminal).sum() / n_nonterminal
130 | loss_cql_q2 = (F.cross_entropy(qs[1][:, :-1].reshape(-1, dsize), actions.reshape(-1), reduction='none').reshape(bsize, ntokens-1) * isterminal).sum() / n_nonterminal
131 | loss_cql = loss_cql_q1 + loss_cql_q2
132 | else:
133 | loss_cql = (F.cross_entropy(qs[:, :-1].reshape(-1, dsize), actions.reshape(-1), reduction='none').reshape(bsize, ntokens-1) * isterminal).sum() / n_nonterminal
134 |
135 | loss_awac = (F.cross_entropy(logits[:, :-1].reshape(-1, dsize), actions.reshape(-1), reduction='none').reshape(bsize, ntokens-1) * isterminal).sum() / n_nonterminal
136 |
137 | loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac
138 | stats = {
139 | k: v for k, v in locals().items() if k in
140 | ['loss', 'loss_v', 'loss_q', 'loss_cql', 'loss_awac']
141 | }
142 |
143 | return loss, stats
144 |
145 | def _sync_target_q_heads(self, alpha):
146 | for target_param, copy_param in zip(self.target_q1_head.parameters(), self.q1_head.parameters()):
147 | target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)
148 |
149 | if self.two_qs:
150 | for target_param, copy_param in zip(self.target_q2_head.parameters(), self.q2_head.parameters()):
151 | target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)
152 |
153 | def sync_target_q_heads(self):
154 | if os.environ.get('DEEPSPEED_ZERO_STAGE', '0') == '3':
155 | params = chain(self.q1_head.parameters(),
156 | self.target_q1_head.parameters(),
157 | self.q2_head.parameters() if self.two_qs else [],
158 | self.target_q2_head.parameters() if self.two_qs else [])
159 |
160 | with deepspeed.zero.GatheredParameters(list(params), modifier_rank=0):
161 | if deepspeed.comm.get_rank() == 0:
162 | self._sync_target_q_heads(self.alpha)
163 | else:
164 | self._sync_target_q_heads(self.alpha)
165 |
166 | @th.inference_mode()
167 | def sample(self, query, beta=1, max_length=32, temperature=1, top_k=20, logit_mask=None, logs=True, eos_token_id=50256):
168 | input = query.clone()
169 | past_key_values = None
170 | tensors = defaultdict(list)
171 |
172 | finished = th.zeros(input.shape[0], 1, dtype=th.long, device=query.device)
173 |
174 | for _ in range(max_length-1):
175 | logits, _, target_qs, vs, past_key_values = self.forward(input_ids=input, past_key_values=past_key_values)
176 |
177 | if self.two_qs:
178 | qs = th.minimum(target_qs[0][:, -1], target_qs[1][:, -1])
179 | else:
180 | qs = target_qs[:, -1]
181 |
182 | logits = logits[:, -1]
183 |
184 | if logit_mask is not None:
185 | logits[th.where(logit_mask[input[:, -1]])] = -np.inf
186 |
187 | adv = qs - vs[:, -1, :]
188 | pi = F.log_softmax(logits, -1)
189 | modpi = topk_mask(pi + beta * adv, top_k)
190 | ps = F.softmax(modpi / temperature, -1)
191 |
192 | tokens = th.multinomial(ps, 1)
193 | tokens = (1 - finished) * tokens + finished * eos_token_id
194 |
195 | query = th.hstack((query, tokens))
196 |
197 | input = tokens
198 | finished = (tokens == eos_token_id).long()
199 |
200 | if logs:
201 | tensors['qs'].append(qs)
202 | tensors['vs'].append(vs)
203 | tensors['adv'].append(adv)
204 |
205 | stats = {}
206 | for name, xs in tensors.items():
207 | xs = th.vstack(xs)
208 | stats.update({
209 | f'{name}-min': xs.min(),
210 | f'{name}-max': xs.max(),
211 | f'{name}-std': xs.std(),
212 | f'{name}-avg': xs.mean(),
213 | })
214 |
215 | return query, stats
216 |
217 | @property
218 | def dummy_inputs(self):
219 | return {'input_ids': th.ones(1, 1, device=self.gpt.device, dtype=th.long)}
220 |
221 | @property
222 | def device(self):
223 | return self.gpt.device
224 |
--------------------------------------------------------------------------------
/scripts/graph_plot.svg:
--------------------------------------------------------------------------------
(SVG contents omitted — the plot written by scripts/graph_plot.py: histogram of steps-to-goal for the shortest-path, ILQL, and random-walk policies)
--------------------------------------------------------------------------------