├── .gitignore ├── README.md ├── instant_apply_mlx.py ├── requirements.txt ├── sample_edit.txt └── sample_target.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .venvs 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Instant Apply Proof-of-Concept 2 | 3 | Implementation of Cursor's [Instant Apply](https://www.cursor.com/blog/instant-apply) feature. 4 | 5 | Achieves close to maximum parallelism during generation (generation tok/s is roughly 70% of prefill tok/s on my laptop). 6 | 7 | Only MLX (Apple silicon) is supported for now. 8 | 9 | ## Usage 10 | 11 | ### With `pip` 12 | ```sh 13 | pip install -r requirements.txt # preferably in a venv 14 | python instant_apply_mlx.py mlx-community/Meta-Llama-3.1-8B-8bit sample_target.txt sample_edit.txt 15 | ``` 16 | 17 | ### With `uv` 18 | ```sh 19 | ./instant_apply_mlx.py mlx-community/Meta-Llama-3.1-8B-8bit sample_target.txt sample_edit.txt 20 | ``` 21 | 22 | ## How it works 23 | 24 | The core logic is in these lines: 25 | ```python 26 | draft: list[int] = [] 27 | target_idx = target_edit_dist.index(min(target_edit_dist)) 28 | if target_idx > 0 and token == target_tokens[target_idx - 1]: 29 | draft = target_tokens[target_idx:] 30 | else: 31 | edit_idx = edit_edit_dist.index(min(edit_edit_dist)) 32 | if edit_idx > 0 and token == edit_tokens[edit_idx - 1]: 33 | draft = edit_tokens[edit_idx:] 34 | else: 35 | # to recover quickly from the LLM deleting a large chunk of text 36 | # (otherwise keeps drafting from pre-deletion position due to edit dist) 37 | target_idx = min( 38 | (i for i, t in enumerate(target_tokens) if t == token), 39 | default=0, 40 | key=lambda i: target_edit_dist[i + 1], 41 | ) 42 | draft = target_tokens[target_idx + 1 :] 43 | ``` 44 | 45 | This generally mispredicts only once per hunk, which enables turning up the speculative lookahead to the point of FLOPS saturation.
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
#   "mlx-lm~=0.17.1",
# ]
# ///
"""Instant Apply proof-of-concept on MLX.

Greedily rewrites a target file according to an edit snippet, using the target
and the edit themselves as speculative drafts so that most output tokens are
verified in parallel (see https://www.cursor.com/blog/instant-apply).
"""

import argparse
import time

import mlx_lm
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import KVCache


def main() -> None:
    """Run Instant Apply from the command line.

    Loads the model, builds the prompt from the target and edit files, then
    decodes greedily with drafts copied from those files, streaming the output
    to stdout and finally printing throughput figures.
    """
    parser = argparse.ArgumentParser(
        description="Instant Apply, from https://www.cursor.com/blog/instant-apply"
    )
    _ = parser.add_argument(
        "model", type=str, help="Example: mlx-community/Meta-Llama-3.1-8B-8bit"
    )
    _ = parser.add_argument("target", type=str, help="Example: sample_target.txt")
    _ = parser.add_argument("edit", type=str, help="Example: sample_edit.txt")
    _ = parser.add_argument(
        "--speculation-lookahead",
        type=int,
        default=64,
        help="Maximum number of draft tokens verified per forward pass",
    )
    _ = parser.add_argument(
        "--max-tokens",
        type=int,
        default=4096,
        help="Maximum number of tokens to generate",
    )
    args = parser.parse_args()

    model, tokenizer = mlx_lm.load(args.model)
    with (
        open(args.target, encoding="utf-8") as target_file,
        open(args.edit, encoding="utf-8") as edit_file,
    ):
        target, edit = target_file.read(), edit_file.read()
    target_tokens, edit_tokens = tokenizer.encode(target), tokenizer.encode(edit)
    # edit_dist[i] = edit distance between the output so far and tokens[:i];
    # with no output yet that is simply i.
    target_edit_dist = list(range(len(target_tokens) + 1))
    edit_edit_dist = list(range(len(edit_tokens) + 1))

    prompt = mx.array(_build_prompt(tokenizer, target, edit))[None]
    prompt_len = prompt.shape[1]
    cache = create_cache(model)
    detokenizer = tokenizer.detokenizer
    detokenizer.reset()
    tic = time.perf_counter()
    prompt_time = float("inf")
    # Last emitted token; -1 means "none yet" and cannot collide with a real
    # vocabulary id (0 could, which skewed the very first draft).
    token = -1
    n_tokens = 0
    n_first_pass_tokens = 0
    step = 0

    # Bound the number of emitted tokens (not forward passes): one pass may
    # emit up to `--speculation-lookahead` tokens.
    while n_tokens < args.max_tokens:
        draft = _propose_draft(
            token, target_tokens, target_edit_dist, edit_tokens, edit_edit_dist
        )
        # An empty draft still needs one (dummy) slot so the pass yields a token.
        draft = draft[: args.speculation_lookahead] or [0]
        draft_toks = mx.array(draft)[None]
        # Feed the pending input plus all but the last draft token, so the
        # logits at the last len(draft) positions predict draft[0], draft[1], ...
        input_toks = mx.concatenate([prompt, draft_toks[:, :-1]], axis=-1)
        logits = model(input_toks, cache=cache)
        logits = logits[:, prompt.shape[1] - 1 :, :]
        output_toks = logits.argmax(axis=-1)
        # Length of the longest draft prefix that greedy decoding agrees with.
        n_accepted = (output_toks == draft_toks).astype(mx.uint8).cummin().sum().item()
        # The accepted tokens plus the model's own next token, limited by the
        # number of predictions available and the remaining token budget.
        n_used = min(n_accepted + 1, len(draft), args.max_tokens - n_tokens)
        finished = False
        for i in range(n_used):
            prompt = output_toks[:, i : i + 1]
            token = prompt.item()
            # Check EOS before detokenizing so its text never reaches the output.
            if token == tokenizer.eos_token_id:
                finished = True
                break
            detokenizer.add_token(token)
            n_tokens += 1
            update_edit_dists(target_edit_dist, target_tokens, token)
            update_edit_dists(edit_edit_dist, edit_tokens, token)
        if not finished:
            # The cache now holds entries for every fed draft token; keep only
            # those before the last used token, which is fed on the next pass.
            for c in cache:
                drop_from_cache(c, len(draft) - n_used)
        print(detokenizer.last_segment, end="", flush=True)
        if step == 0:
            # The first pass processes the whole prompt (plus the first draft);
            # its tokens are excluded from the generation rate below.
            prompt_time = time.perf_counter() - tic
            n_first_pass_tokens = n_tokens
            tic = time.perf_counter()
        step += 1
        if finished:
            break

    gen_time = time.perf_counter() - tic
    # Flush any text the streaming detokenizer is still buffering.
    detokenizer.finalize()
    print(detokenizer.last_segment)
    n_gen_tokens = n_tokens - n_first_pass_tokens
    gen_tps = n_gen_tokens / gen_time if gen_time > 0 else 0.0
    print(f"Prompt processing: {prompt_len / prompt_time} tokens-per-second")
    print(f"Generation: {gen_tps} tokens-per-second")


def _build_prompt(tokenizer, target: str, edit: str) -> list[int]:
    """Return prompt token ids asking the model to apply ``edit`` to ``target``.

    Uses the tokenizer's chat template (instruct models). When
    ``apply_chat_template`` raises ``ValueError`` the model is treated as a base
    model and a plain completion-style prompt is encoded instead.
    """
    try:  # instruct model
        return tokenizer.apply_chat_template(
            [
                {
                    "role": "user",
                    "content": f"Apply to the following file:\n```\n{target}\n```\nthe following edit:\n```\n{edit}\n```\nRespond with only the full modified file (no omissions), Markdown fenced. The content from the edit MUST replace the content from the target where applicable.",
                }
            ],
            tokenize=True,
            add_generation_prompt=True,
        )
    except ValueError:  # base model
        return tokenizer.encode(
            f"The original source code was:\n```\n{target}\n```\nAfter applying the following edit:\n```\n{edit}\n```\nthe new code was the following, which differs from the original code where indicated by the edit:"
        )


def _propose_draft(
    token: int,
    target_tokens: list[int],
    target_edit_dist: list[int],
    edit_tokens: list[int],
    edit_edit_dist: list[int],
) -> list[int]:
    """Guess the tokens following ``token`` by copying from the target or the edit.

    ``*_edit_dist[i]`` is the edit distance between the output generated so far
    and the first ``i`` tokens of the matching sequence, so its argmin is the
    best-aligned position. The target is tried first, then the edit; if neither
    alignment ends on ``token``, the draft continues after the occurrence of
    ``token`` in the target that aligns best.
    """
    draft: list[int] = []
    target_idx = target_edit_dist.index(min(target_edit_dist))
    if target_idx > 0 and token == target_tokens[target_idx - 1]:
        draft = target_tokens[target_idx:]
    else:
        edit_idx = edit_edit_dist.index(min(edit_edit_dist))
        if edit_idx > 0 and token == edit_tokens[edit_idx - 1]:
            draft = edit_tokens[edit_idx:]
        else:
            # to recover quickly from the LLM deleting a large chunk of text
            # (otherwise keeps drafting from pre-deletion position due to edit dist)
            target_idx = min(
                (i for i, t in enumerate(target_tokens) if t == token),
                default=0,
                key=lambda i: target_edit_dist[i + 1],
            )
            draft = target_tokens[target_idx + 1 :]
    return draft


def create_cache(model: nn.Module) -> list[KVCache]:
    """Create the per-layer KV caches for ``model``.

    Prefers the model's own ``make_cache()``. Otherwise it builds plain
    ``KVCache`` objects from ``model.head_dim`` and ``model.n_kv_heads``, which
    is either one int shared by every layer or a per-layer sequence.
    """
    # NOTE(review): drop_from_cache rewinds caches by adjusting `offset`;
    # confirm that whatever make_cache() returns supports that.
    if hasattr(model, "make_cache"):
        return model.make_cache()
    kv_heads = (
        [model.n_kv_heads] * len(model.layers)
        if isinstance(model.n_kv_heads, int)
        else model.n_kv_heads
    )
    return [KVCache(model.head_dim, n) for n in kv_heads]


def drop_from_cache(cache: KVCache, n: int) -> None:
    """Discard the last ``n`` positions from ``cache``.

    Dropping at least everything releases the key/value buffers and resets
    ``offset`` to 0; otherwise only ``offset`` is rewound.
    """
    # NOTE(review): rewinding `offset` alone assumes KVCache ignores and later
    # overwrites entries past `offset`; confirm against mlx_lm's KVCache.
    if n >= cache.offset:
        cache.keys = cache.values = None
        cache.offset = 0
    elif n > 0:
        cache.offset -= n


def update_edit_dists(edit_dist: list[int], tokens: list[int], token: int) -> None:
    """Advance a Levenshtein DP row by one generated token, in place.

    On entry ``edit_dist[i]`` is the edit distance between the generated
    sequence and ``tokens[:i]``. On return it is the distance after appending
    ``token`` to the generated sequence. Runs in O(len(tokens)).
    """
    prev = edit_dist[0]  # old value of the diagonal neighbour
    edit_dist[0] += 1
    for i in range(len(tokens)):
        cur = edit_dist[i + 1]
        # match: diagonal; otherwise 1 + min(substitute, drop generated token,
        # skip target token)
        edit_dist[i + 1] = (
            prev if token == tokens[i] else 1 + min(prev, cur, edit_dist[i])
        )
        prev = cur


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /sample_edit.txt: -------------------------------------------------------------------------------- 1 | # In the class docstring: 2 | tie_embeddings (`bool`, *optional*, defaults to `False`): 3 | Whether to tie weight embeddings 4 | 5 | # In the __init__ method: 6 | def __init__( 7 | self, 8 | # ... other parameters ... 9 | tie_embeddings=False, 10 | # ... rest of the parameters ... 11 | ): 12 | # ... other initializations ... 13 | super().__init__( 14 | pad_token_id=pad_token_id, 15 | bos_token_id=bos_token_id, 16 | eos_token_id=eos_token_id, 17 | tie_word_embeddings=tie_embeddings, 18 | **kwargs, 19 | ) 20 | -------------------------------------------------------------------------------- /sample_target.txt: -------------------------------------------------------------------------------- 1 | from ...configuration_utils import PretrainedConfig 2 | from ...modeling_rope_utils import rope_config_validation 3 | 4 | 5 | class LlamaConfig(PretrainedConfig): 6 | r""" 7 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA 8 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 9 | defaults will yield a similar configuration to that of the LLaMA-7B. 10 | 11 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 12 | documentation from [`PretrainedConfig`] for more information. 13 | 14 | 15 | Args: 16 | vocab_size (`int`, *optional*, defaults to 32000): 17 | Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the 18 | `inputs_ids` passed when calling [`LlamaModel`] 19 | hidden_size (`int`, *optional*, defaults to 4096): 20 | Dimension of the hidden representations. 
21 | intermediate_size (`int`, *optional*, defaults to 11008): 22 | Dimension of the MLP representations. 23 | num_hidden_layers (`int`, *optional*, defaults to 32): 24 | Number of hidden layers in the Transformer decoder. 25 | num_attention_heads (`int`, *optional*, defaults to 32): 26 | Number of attention heads for each attention layer in the Transformer decoder. 27 | num_key_value_heads (`int`, *optional*): 28 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 29 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 30 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When 31 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 32 | by meanpooling all the original heads within that group. For more details checkout [this 33 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 34 | `num_attention_heads`. 35 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 36 | The non-linear activation function (function or string) in the decoder. 37 | max_position_embeddings (`int`, *optional*, defaults to 2048): 38 | The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, 39 | Llama 2 up to 4096, CodeLlama up to 16384. 40 | initializer_range (`float`, *optional*, defaults to 0.02): 41 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 42 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 43 | The epsilon used by the rms normalization layers. 44 | use_cache (`bool`, *optional*, defaults to `True`): 45 | Whether or not the model should return the last key/values attentions (not used by all models). Only 46 | relevant if `config.is_decoder=True`. 47 | pad_token_id (`int`, *optional*): 48 | Padding token id. 
49 | bos_token_id (`int`, *optional*, defaults to 1): 50 | Beginning of stream token id. 51 | eos_token_id (`int`, *optional*, defaults to 2): 52 | End of stream token id. 53 | pretraining_tp (`int`, *optional*, defaults to 1): 54 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 55 | document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to 56 | understand more about it. This value is necessary to ensure exact reproducibility of the pretraining 57 | results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). 58 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 59 | Whether to tie weight embeddings 60 | rope_theta (`float`, *optional*, defaults to 10000.0): 61 | The base period of the RoPE embeddings. 62 | rope_scaling (`Dict`, *optional*): 63 | Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type 64 | and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value 65 | accordingly. 66 | Expected contents: 67 | `rope_type` (`str`): 68 | The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 69 | 'llama3'], with 'default' being the original RoPE implementation. 70 | `factor` (`float`, *optional*): 71 | Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In 72 | most scaling types, a `factor` of x will enable the model to handle sequences of length x * 73 | original maximum pre-trained length. 74 | `original_max_position_embeddings` (`int`, *optional*): 75 | Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during 76 | pretraining. 77 | `attention_factor` (`float`, *optional*): 78 | Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention 79 | computation. 
If unspecified, it defaults to value recommended by the implementation, using the 80 | `factor` field to infer the suggested value. 81 | `beta_fast` (`float`, *optional*): 82 | Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear 83 | ramp function. If unspecified, it defaults to 32. 84 | `beta_slow` (`float`, *optional*): 85 | Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear 86 | ramp function. If unspecified, it defaults to 1. 87 | `short_factor` (`List[float]`, *optional*): 88 | Only used with 'longrope'. The scaling factor to be applied to short contexts (< 89 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 90 | size divided by the number of attention heads divided by 2 91 | `long_factor` (`List[float]`, *optional*): 92 | Only used with 'longrope'. The scaling factor to be applied to long contexts (< 93 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 94 | size divided by the number of attention heads divided by 2 95 | `low_freq_factor` (`float`, *optional*): 96 | Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE 97 | `high_freq_factor` (`float`, *optional*): 98 | Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE 99 | attention_bias (`bool`, *optional*, defaults to `False`): 100 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 101 | attention_dropout (`float`, *optional*, defaults to 0.0): 102 | The dropout ratio for the attention probabilities. 103 | mlp_bias (`bool`, *optional*, defaults to `False`): 104 | Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. 105 | head_dim (`int`, *optional*): 106 | The attention head dimension. 
If None, it will default to hidden_size // num_heads 107 | 108 | ```python 109 | >>> from transformers import LlamaModel, LlamaConfig 110 | 111 | >>> # Initializing a LLaMA llama-7b style configuration 112 | >>> configuration = LlamaConfig() 113 | 114 | >>> # Initializing a model from the llama-7b style configuration 115 | >>> model = LlamaModel(configuration) 116 | 117 | >>> # Accessing the model configuration 118 | >>> configuration = model.config 119 | ```""" 120 | 121 | model_type = "llama" 122 | keys_to_ignore_at_inference = ["past_key_values"] 123 | 124 | def __init__( 125 | self, 126 | vocab_size=32000, 127 | hidden_size=4096, 128 | intermediate_size=11008, 129 | num_hidden_layers=32, 130 | num_attention_heads=32, 131 | num_key_value_heads=None, 132 | hidden_act="silu", 133 | max_position_embeddings=2048, 134 | initializer_range=0.02, 135 | rms_norm_eps=1e-6, 136 | use_cache=True, 137 | pad_token_id=None, 138 | bos_token_id=1, 139 | eos_token_id=2, 140 | pretraining_tp=1, 141 | tie_word_embeddings=False, 142 | rope_theta=10000.0, 143 | rope_scaling=None, 144 | attention_bias=False, 145 | attention_dropout=0.0, 146 | mlp_bias=False, 147 | head_dim=None, 148 | **kwargs, 149 | ): 150 | self.vocab_size = vocab_size 151 | self.max_position_embeddings = max_position_embeddings 152 | self.hidden_size = hidden_size 153 | self.intermediate_size = intermediate_size 154 | self.num_hidden_layers = num_hidden_layers 155 | self.num_attention_heads = num_attention_heads 156 | 157 | # for backward compatibility 158 | if num_key_value_heads is None: 159 | num_key_value_heads = num_attention_heads 160 | 161 | self.num_key_value_heads = num_key_value_heads 162 | self.hidden_act = hidden_act 163 | self.initializer_range = initializer_range 164 | self.rms_norm_eps = rms_norm_eps 165 | self.pretraining_tp = pretraining_tp 166 | self.use_cache = use_cache 167 | self.rope_theta = rope_theta 168 | self.rope_scaling = rope_scaling 169 | self.attention_bias = attention_bias 170 | 
self.attention_dropout = attention_dropout 171 | self.mlp_bias = mlp_bias 172 | self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads 173 | # Validate the correctness of rotary position embeddings parameters 174 | # BC: if there is a 'type' field, move it to 'rope_type'. 175 | if self.rope_scaling is not None and "type" in self.rope_scaling: 176 | self.rope_scaling["rope_type"] = self.rope_scaling["type"] 177 | rope_config_validation(self) 178 | 179 | super().__init__( 180 | pad_token_id=pad_token_id, 181 | bos_token_id=bos_token_id, 182 | eos_token_id=eos_token_id, 183 | tie_word_embeddings=tie_word_embeddings, 184 | **kwargs, 185 | ) 186 | --------------------------------------------------------------------------------