├── web.py
├── Dockerfile
├── model.py
├── README.md
└── converttotorch.py

/web.py:
--------------------------------------------------------------------------------
from fastapi import FastAPI
from pydantic import BaseModel
import model

app = FastAPI()

class Input(BaseModel):
    generate_tokens_limit: int = 100
    top_p: float = 0.7
    top_k: int = 0  # 0 disables top-k filtering; must be an int for model.generate()
    temperature: float = 1.0
    text: str


@app.post("/generate/")
async def generate(input: Input):
    # We intentionally make a non-awaited, blocking call to the model: on the GPU
    # implementation generation can't be parallelized anyway.
    # For parallel generation please check running GPT-J on Google TPU:
    # https://github.com/kingoflolz/mesh-transformer-jax
    try:
        output = model.eval(input)
        return {"completion": output}
    except Exception as e:
        return {"error": str(e)}

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM nvidia/cuda:11.1-base
RUN apt update \
    && apt install -y python3 python3-pip wget git zstd curl \
    && DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt install -y nvidia-cuda-toolkit
# Download and unpack the GPT-J-6B slim checkpoint (JAX / mesh-transformer-jax format)
RUN wget -c https://mystic.the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd \
    && tar -I zstd -xf step_383500_slim.tar.zstd \
    && rm step_383500_slim.tar.zstd
RUN git clone https://github.com/kingoflolz/mesh-transformer-jax.git
RUN pip3 install -r mesh-transformer-jax/requirements.txt
RUN pip3 install torch mesh-transformer-jax/ jax==0.2.12 jaxlib==0.1.68 -f https://storage.googleapis.com/jax-releases/jax_releases.html
RUN mkdir gpt-j-6B && \
    curl https://gist.githubusercontent.com/finetuneanon/a55bdb3f5881e361faef0e96e1d41f09/raw/e5a38dad34ff42bbad188afd5e4fdb2ab2eacb6d/gpt-j-6b.json > gpt-j-6B/config.json
# Convert the JAX checkpoint into a PyTorch checkpoint usable by the transformers fork below
COPY converttotorch.py ./
RUN python3 converttotorch.py
RUN pip3 install fastapi pydantic uvicorn && pip3 install numpy --upgrade && pip3 install git+https://github.com/finetuneanon/transformers@gpt-j
COPY web.py ./
COPY model.py ./
CMD uvicorn web:app --port 8080 --host 0.0.0.0

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
from datetime import datetime

def format_timedelta(td):
    seconds = td.total_seconds()
    if seconds < 1:
        return "<1 sec"
    days, remainder = divmod(seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    return '{} {} {} {}'.format(
        "" if int(days) == 0 else str(int(days)) + ' days',
        "" if int(hours) == 0 else str(int(hours)) + ' hours',
        "" if int(minutes) == 0 else str(int(minutes)) + ' mins',
        "" if int(seconds) == 0 else str(int(seconds)) + ' secs'
    )


t1 = datetime.now()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
print('⌚ Model tokenizer created', format_timedelta(datetime.now() - t1))

t1 = datetime.now()
model = GPTNeoForCausalLM.from_pretrained('./gpt-j-6B')
print('⌚ Model loaded (.from_pretrained)', format_timedelta(datetime.now() - t1))

t1 = datetime.now()

model.half().cuda()

print('⌚ Model half().cuda()', format_timedelta(datetime.now() - t1))
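
# Note (added comment): .half() casts the weights to fp16 (roughly 12 GB for ~6B
# parameters), which is what lets the model fit on a 16 GB GPU; .cuda() then moves
# them into video RAM. The generate() call below is a one-off warm-up / smoke test:
# model.py is imported by web.py, so this runs once at startup, before the web
# server begins handling requests.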

t1 = datetime.now()
prompt = "Hello my name is Paul and"
input_ids = tokenizer.encode(str(prompt), return_tensors='pt').cuda()


output = model.generate(
    input_ids,
    do_sample=True,
    max_length=100,
    temperature=0.8,
    top_k=0,
    top_p=0.7,
)
print('⌚ Test response time', format_timedelta(datetime.now() - t1))
print('🤖 Test response', tokenizer.decode(output[0], skip_special_tokens=True))

def eval(input):
    t1 = datetime.now()

    input_ids = tokenizer.encode(str(input.text), return_tensors='pt').cuda()
    token_count = input_ids.size(dim=1)
    if token_count + input.generate_tokens_limit > 2048:
        raise Exception(
            f"This model can't handle more than 2048 tokens in total, you passed {token_count} "
            f"input tokens and requested to generate {input.generate_tokens_limit} tokens")
    output = model.generate(
        input_ids,
        do_sample=True,
        max_length=token_count + input.generate_tokens_limit,
        top_p=input.top_p,
        top_k=input.top_k,
        temperature=input.temperature,
    )
    resp = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f'⌚ Response time {format_timedelta(datetime.now() - t1)}, input len: {len(input.text)}, resp len: {len(resp)}')
    return resp
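
# Example (added, hypothetical): eval() can be exercised without the web server by
# passing any object that has the same fields as web.py's Input model, e.g.:
#
#   from types import SimpleNamespace
#   req = SimpleNamespace(text="Hello my name is Paul and",
#                         generate_tokens_limit=40,
#                         top_p=0.7, top_k=0, temperature=1.0)
#   print(eval(req))   # or model.eval(req) when called from another module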

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Run the GPT-J-6B model (an open-source text-generation analog of GPT-3) for inference on a server with a GPU, using a zero-dependency Docker image.

On startup the script loads the model into video RAM (this can take several minutes) and then runs an internal HTTP server listening on port 8080.

# Prerequisites to run GPT-J on GPU

You can run this image only on an instance with at least 16 GB of video memory running Linux (e.g. Ubuntu).

The server machine should have the NVIDIA driver installed and a Docker daemon with the NVIDIA Container Toolkit. See below.

> Tested on NVIDIA Titan RTX and NVIDIA Tesla P100.
> Not supported: NVIDIA RTX 3090, RTX A5000, RTX A6000. Reason: the CUDA + PyTorch combination used here:
> CUDA capability sm_86 is not supported; this PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70 (we use the latest PyTorch during the image build). [Match sm_x to your video card](https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/)

## Install NVIDIA drivers

You can skip this step if you already have `nvidia-smi` and it outputs a table with the CUDA version:

```
Mon Feb 14 14:28:16 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| ...

```

E.g. for Ubuntu 20.04:
```
apt purge *nvidia*
apt autoremove
add-apt-repository ppa:graphics-drivers/ppa
apt update
apt install -y ubuntu-drivers-common
ubuntu-drivers autoinstall
```

> Note: Unfortunately the NVIDIA driver installation process can sometimes be quite challenging, e.g. there are known issues like https://bugs.launchpad.net/ubuntu/+source/nvidia-graphics-drivers-390/+bug/1768050/comments/3; searching for the exact error message helps a lot.

After installing and rebooting, test it with `nvidia-smi`; you should see the table shown above.

## Install Dockerd with NVIDIA Container Toolkit

How to install it on Ubuntu:

```
distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \
  && curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | apt-key add - \
  && curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | tee /etc/apt/sources.list.d/nvidia-docker.list

apt update && apt -y upgrade
curl https://get.docker.com | sh && systemctl --now restart docker
apt install -y nvidia-docker2
```
Then reboot the server.

To test that CUDA works inside Docker, run:

```
docker run --rm --gpus all nvidia/cuda:11.1-base nvidia-smi
```

If everything was installed correctly, it should show the same table as `nvidia-smi` on the host.
If the NVIDIA Container Toolkit is missing or the server has not been rebooted yet, you will get `docker: Error response from daemon: could not select device driver "" with capabilities: [[gpu]]`.


# Docker command to run the image

```
docker run -p8080:8080 --gpus all --rm -it devforth/gpt-j-6b-gpu
```

> `--gpus all` passes the GPU into the Docker container, so the CUDA runtime bundled inside the image can use it.

> Though the API uses an async FastAPI web server, calls to the model that generate text are blocking, so you should not expect parallelism from this web server.

Then you can call the model using the REST API (see also the Python client example at the end of this README):

```
POST http://yourServerPublicIP:8080/generate/
Content-Type: application/json
Body:

{
    "text": "Client: Hi, who are you?\nAI: I am Vincent and I am barista!\nClient: What do you do every day?\nAI:",
    "generate_tokens_limit": 40,
    "top_p": 0.7,
    "top_k": 0,
    "temperature": 1.0
}
```


For development, clone the repository and run on the server:

```
docker run -p8080:8080 --gpus all --rm -it $(docker build -q .)
```
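
## Example: calling the API from Python

A minimal client sketch. It assumes the `requests` package is installed and that the container is reachable on `localhost:8080`; replace the host with your server's address, and note that the `timeout` value is arbitrary:

```
import requests

resp = requests.post(
    "http://localhost:8080/generate/",
    json={
        "text": "Client: Hi, who are you?\nAI: I am Vincent and I am barista!\nClient: What do you do every day?\nAI:",
        "generate_tokens_limit": 40,
        "top_p": 0.7,
        "top_k": 0,
        "temperature": 1.0,
    },
    timeout=300,  # generation on the GPU can take a while
)
data = resp.json()
# web.py returns {"completion": ...} on success and {"error": ...} on failure
print(data.get("completion") or data.get("error"))
```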

--------------------------------------------------------------------------------
/converttotorch.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import jax.numpy as jnp
import io
import os

torch.set_printoptions(linewidth=130, sci_mode=False)
np.set_printoptions(linewidth=130, suppress=True)

layers = 28
total_shards = 8
ckpt_dir = "step_383500/"
output_dir = "j6b_ckpt"

def reshard(x, old_shape):
    if len(x.shape) == 1:
        # print("epoch")
        # print(x)
        out = x[0:1]

    elif len(x.shape) == 2:
        #print(f"LN/bias {x.shape}")
        #print(x[:, :16])

        if (x[1:] == x[-1]).all():
            #print("LN")
            if (x[1:] == 0).all() or (x[1:] == 1).all():
                out = x[0:1]
            else:
                #print("shard bias")
                out = x[0:1] * 8  #* x.shape[0] / old_shape[0]
        else:
            #print("bias")
            out = x.reshape(old_shape)

        #print(out[:, :16])

    elif len(x.shape) == 3:
        #print(f"weight {x.shape}")
        if x.shape[0] * x.shape[2] == old_shape[2]:
            #print("case 1")
            out = jnp.transpose(x, (1, 0, 2)).reshape(old_shape)
        elif x.shape[0] * x.shape[1] == old_shape[1]:
            #print("case 2")
            out = x.reshape(old_shape)
        else:
            raise Exception(f"unimplemented, {x.shape}, {old_shape}")
    else:
        raise Exception(f"unimplemented, {x}")
    #flattened, structure = jax.tree_flatten(out)
    #return flattened
    return out

def get_old_shape(t, dim=2):
    if len(t.shape) == 3:
        shard_shape = t.shape
        if dim == 1:
            return (shard_shape[0] * shard_shape[1], shard_shape[2])
        elif dim == 2:
            return (shard_shape[1], shard_shape[0] * shard_shape[2])
        else:
            raise ValueError(f"unsupported dim {dim}")
    if len(t.shape) == 2:
        return (t.shape[1] * t.shape[0],)
    else:
        raise ValueError(f"unsupported shape {t.shape}")

def read_shard(ckpt_dir):
    global part
    out = []
    idx = part
    file_path = ckpt_dir + f"{idx}.npz"
    #print(f"-- {file_path}")
    with open(file_path, "rb") as f:
        buf = f.read()
        f_io = io.BytesIO(buf)
        deserialized = np.load(f_io)
        for i in deserialized:
            out.append(deserialized[i])
            #print(deserialized[i].shape)
    return out

def save(ckpt):
    try: os.mkdir(output_dir)
    except: pass
    checkpoint = {}
    for i, x in enumerate(ckpt.items()):
        checkpoint[x[0]] = f"{output_dir}/b{i}.pt"
        torch.save(x[1], f"{output_dir}/b{i}.pt")
    torch.save(checkpoint, f"{output_dir}/m.pt")

unshard = None
transforms = [("transformer.wte.bias", None, None), ("transformer.wte.weight", unshard, 1)]

checkpoint = {}

layer_names = sorted(map(str, range(layers)))
for layer in layer_names:
    checkpoint[f"transformer.h.{layer}.attn.attention.bias"] = torch.tril(torch.ones(1, 1, 2048, 2048))
    checkpoint[f"transformer.h.{layer}.attn.attention.masked_bias"] = torch.tensor(-1e9)
    transforms.extend([
        (f"transformer.h.{layer}.attn.attention.q_proj.weight", unshard, 2),
        (f"transformer.h.{layer}.attn.attention.v_proj.weight", unshard, 2),
        (f"transformer.h.{layer}.attn.attention.k_proj.weight", unshard, 2),
        (f"transformer.h.{layer}.attn.attention.out_proj.weight", unshard, 1),
        (f"transformer.h.{layer}.mlp.c_fc.bias", unshard, 1),
        (f"transformer.h.{layer}.mlp.c_fc.weight", unshard, 2),
(f"transformer.h.{layer}.mlp.c_proj.bias", None, None), 109 | (f"transformer.h.{layer}.mlp.c_proj.weight", unshard, 1), 110 | (f"transformer.h.{layer}.ln_1.bias", None, None), 111 | (f"transformer.h.{layer}.ln_1.weight", None, None), 112 | ]) 113 | transforms.extend([ 114 | ("lm_head.bias", unshard, 1), 115 | ("lm_head.weight", unshard, 2), 116 | ("transformer.ln_f.bias", None, None), 117 | ("transformer.ln_f.weight", None, None), 118 | ]) 119 | 120 | part = 0 121 | element = 0 122 | while len(transforms) > 0: 123 | print(f"loading shards for part {part}") 124 | shards = list(map(read_shard, [f"{ckpt_dir}shard_{i}/" for i in range(total_shards)])) 125 | print(f"read from checkpoint") 126 | 127 | unsharded = [] 128 | 129 | for all_shards in zip(*shards): 130 | x = np.stack(all_shards) 131 | # No idea why this is V2...? 132 | if x.dtype == np.dtype('V2'): 133 | x.dtype = jnp.bfloat16 134 | x = x.astype(np.float32) 135 | unsharded.append(x) 136 | #print(f"unsharded: {x.shape}") 137 | 138 | while len(transforms) > 0 and len(unsharded) > 0: 139 | transform = transforms.pop(0) 140 | params = unsharded.pop(0) 141 | if transform[2] is not None: 142 | old_shape = (1,) + get_old_shape(params, transform[2]) 143 | else: 144 | old_shape = (params.shape[1],) 145 | print(f"< {params.shape} to {old_shape}") 146 | params = reshard(params, old_shape).squeeze(0).T 147 | params = torch.tensor(params.copy()).half() 148 | if params.isnan().any() or params.isinf().any(): 149 | raise ValueError(f"fp16 over/underflow at {part} {element}") 150 | checkpoint[transform[0]] = params 151 | print(f"> {transform[0]} {params.shape}") 152 | element += 1 153 | part += 1 154 | 155 | checkpoint['transformer.wte.weight'] = (checkpoint['transformer.wte.weight'].T + checkpoint['transformer.wte.bias']) 156 | del checkpoint['transformer.wte.bias'] 157 | 158 | print(f"left over: {unsharded}") 159 | print("saving") 160 | torch.save(checkpoint, "./gpt-j-6B/pytorch_model.bin") # load as in: https://github.com/finetuneanon/misc/blob/main/SizeTest.ipynb 161 | print("done") 162 | --------------------------------------------------------------------------------