├── NOTICE ├── .pre-commit-config.yaml ├── pyproject.toml ├── rnd ├── __init__.py ├── generation_config.py ├── configuration_rnd.py ├── generation_utils.py ├── terminal_visualizer.py ├── sampling.py └── modeling_rnd.py ├── README.md ├── .gitignore ├── assets └── rn-logo-desktop-vector.svg ├── LICENSE └── demo_rnd_generation.py /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2025 Radical Numerics Inc. 2 | 3 | This source code is licensed under the Apache License, Version 2.0, found in the 4 | LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Usage 2 | # uv run pre-commit install 3 | # uv run pre-commit run --all-files 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v6.0.0 8 | hooks: 9 | - id: check-added-large-files 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: check-symlinks 13 | - id: mixed-line-ending 14 | - id: trailing-whitespace 15 | 16 | - repo: https://github.com/astral-sh/ruff-pre-commit 17 | rev: v0.13.2 18 | hooks: 19 | - id: ruff-format 20 | - id: ruff-check 21 | args: [--fix] 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "rnd" 7 | version = "0.1.0" 8 | dependencies = [ 9 | "accelerate", 10 | "torch>=2.8", 11 | "transformers", 12 | "rich" 13 | ] 14 | 15 | [project.optional-dependencies] 16 | flashinfer = [ 17 | "flashinfer-python", 18 | ] 19 | sglang = ["sglang[all]"] 20 | vllm = ["vllm"] 21 | linting = ["ruff>=0.13.2", "pre-commit>=4.0.0"] 22 | 23 | [tool.setuptools] 24 | packages = ["rnd"] 25 | 26 | [tool.ruff] 27 | line-length = 120 28 | target-version = "py311" 29 | show-fixes = false 30 | extend-exclude = ["*.ipynb"] 31 | 32 | [tool.ruff.lint] 33 | select = ["F", "E", "W", "I001", "UP"] 34 | task-tags = ["TODO", "FIXME"] 35 | 36 | [tool.ruff.lint.per-file-ignores] 37 | "__init__.py" = ["E402", "F401"] 38 | 39 | [tool.ruff.lint.isort] 40 | known-first-party = [] 41 | known-third-party = [] 42 | section-order = [ 43 | "future", 44 | "standard-library", 45 | "third-party", 46 | "first-party", 47 | "local-folder", 48 | ] 49 | combine-as-imports = true 50 | split-on-trailing-comma = false 51 | lines-between-types = 1 -------------------------------------------------------------------------------- /rnd/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Radical Numerics Inc. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0, found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Radical Numerics Diffusion (RND1) - Diffusion-based Language Model. 
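A quick-start sketch (the model identifier and loading arguments mirror the project README; assumes access to the HuggingFace Hub):

    from transformers import AutoTokenizer
    from rnd import RND1LM

    tokenizer = AutoTokenizer.from_pretrained("radicalnumerics/RND1-Base-0910", trust_remote_code=True)
    model = RND1LM.from_pretrained(
        "radicalnumerics/RND1-Base-0910",
        dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )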
8 | """ 9 | 10 | from .configuration_rnd import RND1Config 11 | from .generation_config import RND1GenerationConfig 12 | from .generation_utils import RND1GenerationMixin 13 | from .modeling_rnd import RND1LM, RND1Attention, RND1DecoderLayer, RND1Model, RND1PreTrainedModel, RND1SparseMoeBlock 14 | from .sampling import apply_top_k_filtering, apply_top_p_filtering, diffusion_sample 15 | from .terminal_visualizer import SimpleProgressBar, TerminalVisualizer 16 | 17 | __version__ = "0.1.0" 18 | 19 | __all__ = [ 20 | "RND1Config", 21 | "RND1GenerationConfig", 22 | "RND1LM", 23 | "RND1Model", 24 | "RND1PreTrainedModel", 25 | "RND1Attention", 26 | "RND1DecoderLayer", 27 | "RND1SparseMoeBlock", 28 | "RND1GenerationMixin", 29 | "TerminalVisualizer", 30 | "SimpleProgressBar", 31 | ] 32 | 33 | # Register with HuggingFace Auto classes for local usage 34 | try: 35 | from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM 36 | 37 | AutoConfig.register("rnd1", RND1Config) 38 | AutoModel.register(RND1Config, RND1Model) 39 | AutoModelForMaskedLM.register(RND1Config, RND1LM) 40 | except ImportError: 41 | # transformers not available or Auto classes not imported 42 | pass 43 | -------------------------------------------------------------------------------- /rnd/generation_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Radical Numerics Inc. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0, found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | RND1 Generation Configuration. 8 | 9 | This module defines the generation configuration for RND1 models, 10 | controlling the diffusion-based generation process. 11 | """ 12 | 13 | from transformers.generation.configuration_utils import GenerationConfig 14 | 15 | 16 | class RND1GenerationConfig(GenerationConfig): 17 | """ 18 | Configuration class for RND1 generation parameters. 19 | 20 | This class extends the base GenerationConfig to include parameters 21 | specific to diffusion-based language generation. 
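A minimal construction sketch (argument values mirror the documented defaults; see Args below):

    from rnd import RND1GenerationConfig

    config = RND1GenerationConfig(
        max_length=256,
        num_diffusion_steps=256,
        temperature=0.1,
        greedy=False,
    )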
22 | 23 | Args: 24 | max_length: Maximum sequence length 25 | num_diffusion_steps: Number of denoising steps in the diffusion process 26 | mask_token_id: Token ID used for masking during diffusion 27 | temperature: Temperature for sampling (higher = more random) 28 | top_k: Optional top-k filtering 29 | top_p: Optional nucleus (top-p) filtering 30 | greedy: Whether to use greedy decoding (True) or stochastic sampling (False) 31 | **kwargs: Additional arguments passed to GenerationConfig 32 | """ 33 | 34 | def __init__( 35 | self, 36 | max_length: int = 256, 37 | num_diffusion_steps: int = 256, 38 | mask_token_id: int = 151669, 39 | temperature: float = 0.1, 40 | top_k: int | None = None, 41 | top_p: float | None = None, 42 | greedy: bool = False, 43 | bos_token_id: int = None, 44 | eos_token_id: int = None, 45 | pad_token_id: int = None, 46 | use_cache: bool = False, 47 | **kwargs, 48 | ): 49 | # Force no caching for RND generation 50 | # kwargs['use_cache'] = False 51 | kwargs.pop("use_cache", None) 52 | super().__init__( 53 | max_length=max_length, 54 | bos_token_id=bos_token_id, 55 | eos_token_id=eos_token_id, 56 | pad_token_id=pad_token_id, 57 | temperature=temperature, 58 | top_k=top_k, 59 | top_p=top_p, 60 | do_sample=not greedy, 61 | use_cache=False, 62 | **kwargs, 63 | ) 64 | 65 | # RND-specific parameters 66 | self.num_diffusion_steps = num_diffusion_steps 67 | self.mask_token_id = mask_token_id 68 | self.greedy = greedy 69 | 70 | def to_dict(self): 71 | """Convert configuration to dictionary.""" 72 | output = super().to_dict() 73 | output["num_diffusion_steps"] = self.num_diffusion_steps 74 | output["mask_token_id"] = self.mask_token_id 75 | output["greedy"] = self.greedy 76 | return output 77 | -------------------------------------------------------------------------------- /rnd/configuration_rnd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Radical Numerics Inc. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0, found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | RND1 Model Configuration. 8 | 9 | This module defines the configuration class for RND1 models. 10 | The default settings are derived from Qwen/Qwen3-30B-A3B and augmented 11 | with RND1-specific parameters. 12 | """ 13 | 14 | from transformers.configuration_utils import PretrainedConfig 15 | 16 | # Qwen3-30B-A3B / checkpoint defaults 17 | CONFIG_DEFAULTS = { 18 | "attention_bias": False, 19 | "attention_dropout": 0.0, 20 | "decoder_sparse_step": 1, 21 | "eos_token_id": 151645, 22 | "head_dim": 128, 23 | "hidden_act": "silu", 24 | "hidden_size": 2048, 25 | "initializer_range": 0.02, 26 | "intermediate_size": 6144, 27 | "max_position_embeddings": 40960, 28 | "max_window_layers": 48, 29 | "mlp_only_layers": [], 30 | "moe_intermediate_size": 768, 31 | "norm_topk_prob": True, 32 | "num_attention_heads": 32, 33 | "num_experts": 128, 34 | "num_experts_per_tok": 8, 35 | "num_hidden_layers": 48, 36 | "num_key_value_heads": 4, 37 | "output_router_logits": False, 38 | "pad_token_id": 151643, 39 | "rms_norm_eps": 1e-06, 40 | "rope_scaling": False, 41 | "rope_theta": 1000000.0, 42 | "router_aux_loss_coef": 0.001, 43 | "sliding_window": False, 44 | "tie_word_embeddings": False, 45 | "dtype": "bfloat16", 46 | "use_cache": False, 47 | "use_sliding_window": False, 48 | "vocab_size": 151936, 49 | } 50 | 51 | 52 | class RND1Config(PretrainedConfig): 53 | """ 54 | Configuration class for RND1 models. 
55 | 56 | This configuration extends Qwen3MoeConfig with additional parameters 57 | specific to the RND1 (Radical Numerics Diffusion v1) architecture. 58 | 59 | Args: 60 | moe_backend: Backend for MoE computation ("hf", "vllm", "sglang" or "flashinfer") 61 | num_diffusion_steps: Default number of diffusion steps for generation 62 | mask_token_id: Token ID used for masking (default: 151669 for Qwen) 63 | **kwargs: Additional arguments passed to Qwen3MoeConfig 64 | """ 65 | 66 | model_type = "rnd1" 67 | 68 | def __init__( 69 | self, 70 | moe_backend: str = "hf", 71 | num_diffusion_steps: int = 256, 72 | mask_token_id: int = 151669, 73 | **kwargs, 74 | ): 75 | # Force non-causal and no caching for RND1 76 | kwargs["use_cache"] = False 77 | kwargs["is_causal"] = False 78 | 79 | super().__init__(**kwargs) 80 | 81 | # Set defaults after pretrained init to prevent overrides 82 | self.set_config_defaults() 83 | 84 | # QoL: set attn impl directly from config 85 | if "attn_implementation" in kwargs: 86 | self._attn_implementation = kwargs["attn_implementation"] 87 | 88 | # RND1-specific parameters 89 | self.moe_backend = moe_backend 90 | self.num_diffusion_steps = num_diffusion_steps 91 | self.mask_token_id = mask_token_id 92 | 93 | # Ensure bidirectional attention and no caching 94 | self.is_causal = False 95 | self.use_cache = False 96 | 97 | def set_config_defaults(self): 98 | """ 99 | Ensure model defaults are set according to final training checkpoint 100 | 101 | Qwen3MoeConfig defaults don't match Qwen/Qwen3-30B-A3B settings from which 102 | RND1 is derived. 103 | """ 104 | for k, v in CONFIG_DEFAULTS.items(): 105 | setattr(self, k, v) 106 | 107 | def to_dict(self): 108 | """ 109 | Serializes configuration to dictionary with auto_map for Hub. 110 | 111 | The auto_map ensures that when users load from HuggingFace Hub, 112 | the correct custom classes are automatically resolved. 113 | """ 114 | data = super().to_dict() 115 | data.setdefault( 116 | "auto_map", 117 | { 118 | "AutoConfig": "configuration_rnd.RND1Config", 119 | "AutoModel": "modeling_rnd.RND1Model", 120 | "AutoModelForMaskedLM": "modeling_rnd.RND1LM", 121 | }, 122 | ) 123 | return data 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# RND1: Scaling Diffusion Language Models
6 | 7 | ![???](https://github.com/user-attachments/assets/c2c54f94-a7f5-4b76-987d-f15de4efaef6) 8 | 9 | 10 | This repository contains an inference harness for Radical Numerics Diffusion 1 (RND1), an experimental diffusion language model. RND1-Base-0910 is a 30B‑parameter sparse Mixture‑of‑Experts model with 3B active parameters per token, converted from an autoregressive base (Qwen3-30B-A3B) via continual pretraining on 500B tokens. 11 | 12 | We release RND1 models to catalyze further research on inference and post-training of DLMs. 13 | 14 | For more details, see: 15 | 16 | **Blog:** https://www.radicalnumerics.ai/blog/rnd1 17 | 18 | **Report:** https://www.radicalnumerics.ai/assets/rnd1_report.pdf 19 | 20 | **🤗:** https://huggingface.co/radicalnumerics/RND1-Base-0910 21 | 22 | **Models:** 23 | * **RND1-Base-0910**: first base model in the RND1 family. It has not been post-trained for specific usage. 24 | 25 | 26 | ## Installation 27 | 28 | ```bash 29 | # tested with Python 3.12 30 | pip install torch transformers accelerate numpy rich 31 | ``` 32 | 33 | ```bash 34 | # backends enable faster inference through optimized MoE kernels: 35 | pip install flashinfer-python 36 | pip install sglang[all] 37 | pip install vllm 38 | ``` 39 | 40 | ## Quick Start 41 | 42 | 43 | 44 | ```bash 45 | # Task mode (default) - for instructions, questions, or requests 46 | python demo_rnd_generation.py --prompt "Write a Python function that finds the longest common subsequence of two strings. Include comments explaining the algorithm." --moe_backend hf 47 | 48 | # Completion mode - for text continuation 49 | python demo_rnd_generation.py --mode completion --prompt "The key to understanding quantum computing lies in" --moe_backend hf 50 | 51 | # Sampling parameters 52 | python demo_rnd_generation.py --top_k 50 --temperature 0.7 --prompt "Explain how neural networks learn in simple terms" --moe_backend hf 53 | ``` 54 | 55 | 56 | > [!WARNING] 57 | > Selecting a non-Huggingface MoE backend is highly encouraged for faster generation. Note however that non-HF backends currently support a single GPU only, so you need to set e.g. `export CUDA_VISIBLE_DEVICES=0` before running the script. If you use `flashinfer-python`, JIT compilation the first time the code is run may take a while unless `flashinfer-jit-cache` is installed. 
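For example, a single-GPU run with the FlashInfer backend could look like this (a sketch; assumes the optional `flashinfer-python` dependency above is installed):

```bash
# non-HF backends currently support a single GPU only
export CUDA_VISIBLE_DEVICES=0
python demo_rnd_generation.py \
    --prompt "Explain how neural networks learn in simple terms" \
    --moe_backend flashinfer \
    --num_steps 256 \
    --max_new_tokens 256
```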
58 | 59 | ### Demo Parameters 60 | 61 | - `--mode`: Generation mode - 'task' or 'completion' (default: task) 62 | - `task`: For instructions, questions, or requests (adds "Question:" prefix) 63 | - `completion`: For text continuation (no prefix added) 64 | - `--max_new_tokens`: Number of new tokens to generate (default: 256) 65 | - `--num_steps`: Diffusion denoising steps (default: 256) 66 | - `--temperature`: Sampling temperature, 0.0 for greedy (default: 0.01) 67 | - `--top_k`: Top-k filtering - keeps only k most likely tokens (works with greedy and sampling) 68 | - `--top_p`: Nucleus filtering - keeps tokens with cumulative probability ≤ p (works with greedy and sampling) 69 | - `--moe_backend`: Choose backend: hf, vllm, sglang, flashinfer (default: hf) 70 | - `--no_viz`: Disable visualization 71 | - `--add_eos_at_end`: Add End of Sequence (EOS) token at the end of the sequence; useful to force the model to come to a coherent end (default: False) 72 | 73 | ## Python API 74 | 75 | ```python 76 | from transformers import AutoTokenizer 77 | from rnd import RND1LM 78 | 79 | # Load tokenizer 80 | tokenizer = AutoTokenizer.from_pretrained("radicalnumerics/RND1-Base-0910", trust_remote_code=True) 81 | 82 | # Load model 83 | model = RND1LM.from_pretrained( 84 | "radicalnumerics/RND1-Base-0910", 85 | dtype="bfloat16", 86 | device_map="auto", 87 | trust_remote_code=True, 88 | moe_backend="hf", # hf (default), sglang, vllm, flashinfer 89 | ) 90 | 91 | # Generate - Task mode (for instructions and questions) 92 | prompt = "Write a Python function that finds the longest common subsequence of two strings. Include comments explaining the algorithm." 93 | inputs = tokenizer(f"Question: {prompt}\nAnswer:", return_tensors="pt") 94 | input_ids = inputs.input_ids.to(model.device) 95 | 96 | # Generate 97 | output = model.generate( 98 | inputs=input_ids, 99 | max_new_tokens=256, 100 | num_diffusion_steps=256, 101 | temperature=0.01, 102 | ) 103 | 104 | # Decode only the generated part 105 | text = tokenizer.decode(output[0], skip_special_tokens=True) 106 | print(text) 107 | ``` 108 | 109 | ## Project Structure 110 | 111 | ``` 112 | RND_dev/ 113 | ├── README.md # This file 114 | ├── demo_rnd_generation.py # Demo script with command-line interface 115 | └── rnd/ # Core RND1 package 116 | ├── __init__.py # Package exports 117 | ├── configuration_rnd.py # RND1 model configuration 118 | ├── modeling_rnd.py # Core model implementation 119 | ├── generation_config.py # Generation configuration 120 | ├── generation_utils.py # Generation mixin and utilities 121 | ├── sampling.py # Diffusion sampling algorithm 122 | └── terminal_visualizer.py # Live visualization (optional) 123 | ``` 124 | 125 | 126 | --- 127 | 128 |
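## Infilling

`generate()` in `rnd/generation_utils.py` also accepts infilling arguments (`prefix_ids`, `suffix_ids`, `infill_length`). A minimal sketch reusing the model and tokenizer from the Python API example above (the prompt strings and `infill_length` are illustrative, and RND1-Base-0910 has not been post-trained, so infilling quality may vary):

```python
prefix = tokenizer("def lcs_length(a, b):\n    ", return_tensors="pt").input_ids
suffix = tokenizer("\n    return table[-1][-1]", return_tensors="pt").input_ids

output = model.generate(
    inputs=prefix,            # used as the preserved prefix
    suffix_ids=suffix,        # preserved after the infilled span
    infill_length=64,         # number of masked positions denoised between prefix and suffix
    num_diffusion_steps=128,
    temperature=0.01,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```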

Radical Numerics Logo

131 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # Temp files 7 | profiler_traces/ 8 | 9 | # Agent files and commonly used patters for agentic coding 10 | RISA.md 11 | .risa/ 12 | AGENTS.md 13 | TASK.md 14 | 15 | # data 16 | data/ 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py.cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | cover/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | .pybuilder/ 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | # For a library or package, you might want to ignore these files since the code is 99 | # intended to run in multiple environments; otherwise, check them in: 100 | # .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | # Pipfile.lock 108 | 109 | # UV 110 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 111 | # This is especially recommended for binary packages to ensure reproducibility, and is more 112 | # commonly ignored for libraries. 113 | uv.lock 114 | 115 | # poetry 116 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 117 | # This is especially recommended for binary packages to ensure reproducibility, and is more 118 | # commonly ignored for libraries. 119 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 120 | # poetry.lock 121 | # poetry.toml 122 | 123 | # pdm 124 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 125 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 126 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 127 | # pdm.lock 128 | # pdm.toml 129 | .pdm-python 130 | .pdm-build/ 131 | 132 | # pixi 133 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 
134 | # pixi.lock 135 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 136 | # in the .venv directory. It is recommended not to include this directory in version control. 137 | .pixi 138 | 139 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 140 | __pypackages__/ 141 | 142 | # Celery stuff 143 | celerybeat-schedule 144 | celerybeat.pid 145 | 146 | # Redis 147 | *.rdb 148 | *.aof 149 | *.pid 150 | 151 | # RabbitMQ 152 | mnesia/ 153 | rabbitmq/ 154 | rabbitmq-data/ 155 | 156 | # ActiveMQ 157 | activemq-data/ 158 | 159 | # SageMath parsed files 160 | *.sage.py 161 | 162 | # Environments 163 | .env 164 | .envrc 165 | .venv 166 | env/ 167 | venv/ 168 | ENV/ 169 | env.bak/ 170 | venv.bak/ 171 | 172 | # Spyder project settings 173 | .spyderproject 174 | .spyproject 175 | 176 | # Rope project PEAR Contributorssettings 177 | .ropeproject 178 | 179 | # mkdocs documentation 180 | /site 181 | 182 | # mypy 183 | .mypy_cache/ 184 | .dmypy.json 185 | dmypy.json 186 | 187 | # Pyre type checker 188 | .pyre/ 189 | 190 | # pytype static type analyzer 191 | .pytype/ 192 | 193 | # Cython debug symbols 194 | cython_debug/ 195 | 196 | # PyCharm 197 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 198 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 199 | # and can be added to the global gitignore or merged into this file. For a more nuclear 200 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 201 | # .idea/ 202 | 203 | # Abstra 204 | # Abstra is an AI-powered process automation framework. 205 | # Ignore directories containing user credentials, local state, and settings. 206 | # Learn more at https://abstra.io/docs 207 | .abstra/ 208 | 209 | # Visual Studio Code 210 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 211 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 212 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 213 | # you could uncomment the following to ignore the entire vscode folder 214 | # .vscode/ 215 | 216 | # Ruff stuff: 217 | .ruff_cache/ 218 | 219 | # PyPI configuration file 220 | .pypirc 221 | 222 | # Marimo 223 | marimo/_static/ 224 | marimo/_lsp/ 225 | __marimo__/ 226 | 227 | # Streamlit 228 | .streamlit/secrets.toml -------------------------------------------------------------------------------- /assets/rn-logo-desktop-vector.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 61 | 62 | 63 | 64 | 65 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /rnd/generation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Radical Numerics Inc. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0, found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | RND1 Generation Utilities. 8 | 9 | This module provides generation utilities and mixins for RND1 models, 10 | including the main GenerationMixin class that integrates with HuggingFace. 11 | """ 12 | 13 | from typing import Any 14 | 15 | import torch 16 | 17 | from transformers import GenerationMixin as HFGenerationMixin 18 | from transformers.generation import GenerationConfig 19 | 20 | from .generation_config import RND1GenerationConfig 21 | from .sampling import diffusion_sample 22 | 23 | 24 | class RND1GenerationMixin(HFGenerationMixin): 25 | """ 26 | Generation mixin for RND1 models. 27 | 28 | This mixin provides generation methods compatible with HuggingFace's 29 | generation API while using RND1's diffusion-based sampling internally. 30 | """ 31 | 32 | def generate( 33 | self, 34 | inputs: torch.LongTensor | None = None, 35 | generation_config: GenerationConfig | None = None, 36 | # RND1-specific parameters 37 | prefix_ids: torch.LongTensor | None = None, 38 | suffix_ids: torch.LongTensor | None = None, 39 | infill_length: int | None = None, 40 | return_dict_in_generate: bool | None = None, 41 | **kwargs, # Accept all kwargs to be compatible with pipelines 42 | ) -> torch.LongTensor | dict[str, Any]: 43 | """ 44 | Generate text using RND1's diffusion-based sampling. 45 | 46 | Follows HuggingFace's standard generate API, using diffusion sampling 47 | internally. Supports both standard generation and infilling. 48 | 49 | Args: 50 | inputs: Input token IDs to use as prefix (standard HF parameter) 51 | generation_config: Generation configuration object. Default is RND1GenerationConfig. 52 | prefix_ids: Alternative to inputs for infilling tasks 53 | suffix_ids: Optional suffix for infilling tasks 54 | infill_length: Length of infill region (for infilling) 55 | return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput 56 | **kwargs: Additional arguments (accepted for compatibility). These will be passed to the config constructor. 
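Example (illustrative sketch; values mirror the README quick-start):

            output = model.generate(
                inputs=input_ids,
                max_new_tokens=256,
                num_diffusion_steps=256,
                temperature=0.01,
            )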
57 | 58 | Returns: 59 | Generated token IDs or GenerateDecoderOnlyOutput 60 | """ 61 | if generation_config is not None: 62 | gen_config = generation_config 63 | model_kwargs = kwargs.copy() 64 | else: 65 | # Only prepare config from kwargs if no config was provided 66 | gen_config, model_kwargs = self._prepare_generation_config(RND1GenerationConfig(), **kwargs) 67 | 68 | device = next(self.parameters()).device 69 | 70 | if inputs is not None: 71 | prefix_ids = inputs.to(device) 72 | elif prefix_ids is not None: 73 | prefix_ids = prefix_ids.to(device) 74 | else: 75 | prefix_ids = None 76 | 77 | if suffix_ids is not None: 78 | suffix_ids = suffix_ids.to(device) 79 | 80 | eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645) 81 | pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", 151643) 82 | bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None) 83 | mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669)) 84 | 85 | if infill_length is not None and prefix_ids is not None: 86 | # Infilling mode: use specified infill_length 87 | prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0 88 | suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0 89 | seq_len = prefix_len + infill_length + suffix_len 90 | else: 91 | # Standard generation mode 92 | if prefix_ids is not None: 93 | prefix_len = prefix_ids.shape[1] 94 | if gen_config.max_new_tokens is not None: 95 | seq_len = prefix_len + gen_config.max_new_tokens 96 | else: 97 | seq_len = gen_config.max_length or self.config.max_position_embeddings 98 | else: 99 | seq_len = gen_config.max_length or self.config.max_position_embeddings 100 | 101 | num_diffusion_steps = getattr( 102 | gen_config, "num_diffusion_steps", getattr(self.config, "num_diffusion_steps", 256) 103 | ) 104 | 105 | temperature = float(getattr(gen_config, "temperature", 1.0)) 106 | top_k = getattr(gen_config, "top_k", None) 107 | top_p = getattr(gen_config, "top_p", None) 108 | 109 | greedy = getattr( 110 | gen_config, "greedy", not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True 111 | ) 112 | 113 | with torch.inference_mode(): 114 | sequences = diffusion_sample( 115 | model=self, 116 | seq_len=seq_len, 117 | num_steps=num_diffusion_steps, 118 | mask_token_id=mask_token_id, 119 | temperature=temperature, 120 | top_k=top_k, 121 | top_p=top_p, 122 | greedy=greedy, 123 | prefix_ids=prefix_ids, 124 | suffix_ids=suffix_ids, 125 | infill_length=infill_length, 126 | eos_token_id=eos_token_id, 127 | pad_token_id=pad_token_id, 128 | bos_token_id=bos_token_id, 129 | device=device, 130 | visualizer=model_kwargs.get("visualizer", None), # Optional visualizer from kwargs, 131 | add_eos_at_end=getattr(gen_config, "add_eos_at_end", False), 132 | ) 133 | 134 | if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False): 135 | from transformers.generation.utils import GenerateDecoderOnlyOutput 136 | 137 | return GenerateDecoderOnlyOutput(sequences=sequences) 138 | 139 | return sequences 140 | 141 | def generate_with_visualization( 142 | self, 143 | tokenizer, 144 | inputs: torch.LongTensor | None = None, 145 | generation_config: GenerationConfig | None = None, 146 | suffix_ids: torch.LongTensor | None = None, 147 | infill_length: int | None = None, 148 | **kwargs, 149 | ) -> torch.LongTensor: 150 | """ 151 | Generate with live visualization (for demos). 
152 | 153 | This method requires a tokenizer to display the generation process. 154 | For production use, prefer `generate()`. 155 | 156 | Args: 157 | tokenizer: Tokenizer for decoding tokens to text 158 | inputs: Input token IDs to use as prefix 159 | generation_config: Generation configuration object 160 | suffix_ids: Optional suffix token IDs 161 | infill_length: Length of infill region 162 | **kwargs: Additional arguments for backward compatibility 163 | 164 | Returns: 165 | Generated token IDs as LongTensor 166 | """ 167 | from .terminal_visualizer import TerminalVisualizer 168 | 169 | visualizer = TerminalVisualizer(tokenizer, show_visualization=True) 170 | 171 | return self.generate( 172 | inputs=inputs, 173 | generation_config=generation_config, 174 | suffix_ids=suffix_ids, 175 | infill_length=infill_length, 176 | visualizer=visualizer, 177 | return_dict_in_generate=False, 178 | **kwargs, 179 | ) 180 | 181 | def prepare_inputs_for_generation( 182 | self, 183 | input_ids: torch.LongTensor, 184 | **kwargs, 185 | ) -> dict[str, Any]: 186 | """ 187 | Prepare inputs for generation (required by HuggingFace). 188 | 189 | For RND1, we don't use the standard autoregressive generation, 190 | so this just returns the input_ids. 191 | """ 192 | return {"input_ids": input_ids} 193 | -------------------------------------------------------------------------------- /rnd/terminal_visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Radical Numerics Inc. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0, found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Terminal visualization for RND1 generation. 8 | 9 | This module provides real-time visualization of the diffusion denoising process, 10 | showing token evolution and generation progress in the terminal using rich 11 | formatting when available. 12 | """ 13 | 14 | import torch 15 | 16 | from tqdm import tqdm 17 | 18 | try: 19 | from rich.console import Console 20 | from rich.layout import Layout 21 | from rich.live import Live 22 | from rich.panel import Panel 23 | from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn, TimeRemainingColumn 24 | from rich.text import Text 25 | 26 | RICH_AVAILABLE = True 27 | except ImportError: 28 | RICH_AVAILABLE = False 29 | 30 | 31 | class TerminalVisualizer: 32 | """ 33 | Rich-based visualization for diffusion process with live updates. 34 | 35 | Provides real-time visualization of the token denoising process during 36 | diffusion-based language generation, with colored highlighting of masked 37 | positions and progress tracking. 38 | """ 39 | 40 | def __init__(self, tokenizer, show_visualization: bool = True): 41 | """ 42 | Initialize the terminal visualizer. 43 | 44 | Args: 45 | tokenizer: The tokenizer for decoding tokens to text 46 | show_visualization: Whether to show visualization (requires rich) 47 | """ 48 | self.tokenizer = tokenizer 49 | self.show_visualization = show_visualization and RICH_AVAILABLE 50 | if not RICH_AVAILABLE and show_visualization: 51 | print("Warning: Install 'rich' for better visualization. 
Falling back to simple progress bar.") 52 | self.show_visualization = False 53 | 54 | if self.show_visualization: 55 | self.console = Console() 56 | self.live = None 57 | self.progress = None 58 | self.layout = None 59 | else: 60 | self.pbar = None 61 | 62 | self.current_tokens = None 63 | self.mask_positions = None 64 | self.total_steps = 0 65 | self.current_step = 0 66 | 67 | def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int): 68 | """ 69 | Start the visualization. 70 | 71 | Args: 72 | initial_tokens: Initial token IDs (possibly masked) 73 | mask_positions: Boolean mask indicating which positions are masked 74 | total_steps: Total number of diffusion steps 75 | """ 76 | if not self.show_visualization: 77 | self.pbar = tqdm(total=total_steps, desc="Diffusion") 78 | return 79 | 80 | self.current_tokens = initial_tokens.clone() 81 | self.mask_positions = mask_positions 82 | self.total_steps = total_steps 83 | self.current_step = 0 84 | 85 | self.layout = Layout() 86 | self.layout.split_column( 87 | Layout(name="header", size=3), Layout(name="text", ratio=1), Layout(name="progress", size=3) 88 | ) 89 | 90 | self.progress = Progress( 91 | TextColumn("[bold blue]Diffusion"), 92 | BarColumn(), 93 | MofNCompleteColumn(), 94 | TextColumn("•"), 95 | TextColumn("[cyan]Masks: {task.fields[masks]}"), 96 | TimeRemainingColumn(), 97 | ) 98 | self.progress_task = self.progress.add_task("Generating", total=total_steps, masks=mask_positions.sum().item()) 99 | 100 | self.live = Live(self.layout, console=self.console, refresh_per_second=4) 101 | self.live.start() 102 | self._update_display() 103 | 104 | def update_step( 105 | self, 106 | tokens: torch.LongTensor, 107 | maskable: torch.BoolTensor | None, 108 | step: int, 109 | entropy: torch.FloatTensor | None = None, 110 | confidence: torch.FloatTensor | None = None, 111 | ): 112 | """ 113 | Update visualization for current step. 114 | 115 | Args: 116 | tokens: Current token IDs 117 | maskable: Boolean mask of remaining masked positions 118 | step: Current step number 119 | entropy: Optional entropy scores for each position 120 | confidence: Optional confidence scores for each position 121 | """ 122 | if not self.show_visualization: 123 | if self.pbar: 124 | self.pbar.update(1) 125 | masks = maskable.sum().item() if maskable is not None else 0 126 | self.pbar.set_postfix({"masks": masks}) 127 | return 128 | 129 | self.current_tokens = tokens.clone() 130 | self.mask_positions = maskable 131 | self.current_step = step 132 | 133 | masks_remaining = maskable.sum().item() if maskable is not None else 0 134 | self.progress.update(self.progress_task, advance=1, masks=masks_remaining) 135 | 136 | self._update_display() 137 | 138 | def _update_display(self): 139 | """Update the live display.""" 140 | if not self.live: 141 | return 142 | 143 | header = Text("RND1-Base Generation", style="bold magenta", justify="center") 144 | self.layout["header"].update(Panel(header, border_style="bright_blue")) 145 | 146 | text_display = self._format_text_with_masks() 147 | self.layout["text"].update( 148 | Panel( 149 | text_display, 150 | title="[bold]Generated Text", 151 | subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]", 152 | border_style="cyan", 153 | ) 154 | ) 155 | 156 | self.layout["progress"].update(Panel(self.progress)) 157 | 158 | def _format_text_with_masks(self) -> Text: 159 | """ 160 | Format text with colored masks. 
161 | 162 | Returns: 163 | Rich Text object with formatted tokens 164 | """ 165 | text = Text() 166 | 167 | if self.current_tokens is None: 168 | return text 169 | 170 | token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens 171 | mask_flags = ( 172 | self.mask_positions[0] 173 | if self.mask_positions is not None and self.mask_positions.dim() > 1 174 | else self.mask_positions 175 | ) 176 | 177 | for i, token_id in enumerate(token_ids): 178 | if mask_flags is not None and i < len(mask_flags) and mask_flags[i]: 179 | # Alternate colors for visual effect 180 | text.append( 181 | "[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red" 182 | ) 183 | else: 184 | try: 185 | token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False) 186 | # Skip special tokens in display 187 | if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "", ""]: 188 | # Color based on position 189 | text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan") 190 | except Exception: 191 | continue 192 | 193 | return text 194 | 195 | def stop_visualization(self): 196 | """Stop the visualization and display final result.""" 197 | if not self.show_visualization: 198 | if self.pbar: 199 | self.pbar.close() 200 | print("\n✨ Generation complete!\n") 201 | return 202 | 203 | if self.live: 204 | self.live.stop() 205 | 206 | self.console.print("\n[bold green]✨ Generation complete![/bold green]\n") 207 | 208 | # Display final text 209 | if self.current_tokens is not None: 210 | try: 211 | token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens 212 | final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True) 213 | 214 | self.console.print( 215 | Panel(final_text, title="[bold]Final Generated Text", border_style="green", padding=(1, 2)) 216 | ) 217 | except Exception: 218 | pass 219 | 220 | 221 | class SimpleProgressBar: 222 | """ 223 | Simple progress bar fallback when rich is not available. 224 | 225 | Provides basic progress tracking using tqdm when the rich library 226 | is not installed. 227 | """ 228 | 229 | def __init__(self, total_steps: int): 230 | """ 231 | Initialize simple progress bar. 232 | 233 | Args: 234 | total_steps: Total number of steps 235 | """ 236 | self.pbar = tqdm(total=total_steps, desc="Diffusion") 237 | 238 | def update(self, masks_remaining: int = 0): 239 | """ 240 | Update progress bar. 241 | 242 | Args: 243 | masks_remaining: Number of masks still remaining 244 | """ 245 | self.pbar.update(1) 246 | self.pbar.set_postfix({"masks": masks_remaining}) 247 | 248 | def close(self): 249 | """Close the progress bar.""" 250 | self.pbar.close() 251 | print("\n✨ Generation complete!\n") 252 | -------------------------------------------------------------------------------- /rnd/sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Radical Numerics Inc. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0, found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | RND1 sampling module for masked diffusion generation. 8 | 9 | This module implements entropy-based token selection for iterative denoising 10 | in diffusion language models. Supports both greedy and stochastic sampling 11 | with optional prefix/suffix constraints and infilling. 
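Example (minimal sketch; assumes `model` is a loaded RND1 model and `input_ids` holds a tokenized prompt):

    from rnd.sampling import diffusion_sample

    tokens = diffusion_sample(
        model=model,
        seq_len=256,
        num_steps=256,
        temperature=0.01,
        greedy=True,
        prefix_ids=input_ids,
    )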
12 | """ 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor: 19 | """ 20 | Apply top-k filtering to logits: with non-top-k values set to -inf 21 | """ 22 | top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1) 23 | filtered_logits = torch.full_like(logits, float("-inf")) 24 | filtered_logits.scatter_(-1, top_k_indices, top_k_values) 25 | return filtered_logits 26 | 27 | 28 | def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor: 29 | """ 30 | Apply top-p (nucleus) filtering to logits: with tokens beyond threshold set to -inf 31 | """ 32 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) 33 | cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) 34 | 35 | # Remove tokens with cumulative probability above threshold 36 | sorted_indices_to_remove = cumulative_probs > p 37 | sorted_indices_to_remove[..., 0] = False # Keep at least one token 38 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 39 | 40 | indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove) 41 | return logits.masked_fill(indices_to_remove, float("-inf")) 42 | 43 | 44 | @torch.no_grad() 45 | def diffusion_sample( 46 | model: nn.Module, 47 | seq_len: int = 256, 48 | num_steps: int = 256, 49 | top_k: int | None = None, 50 | top_p: float | None = None, 51 | temperature: float = 1.0, 52 | greedy: bool = True, 53 | mask_token_id: int = 151669, 54 | prefix_ids: torch.LongTensor | None = None, 55 | suffix_ids: torch.LongTensor | None = None, 56 | infill_length: int | None = None, 57 | eos_token_id: int = 151645, 58 | pad_token_id: int | None = None, 59 | bos_token_id: int | None = None, 60 | device: str | torch.device | None = None, 61 | visualizer: object | None = None, 62 | add_eos_at_end: bool = False, 63 | ) -> torch.LongTensor: 64 | """ 65 | Perform masked diffusion sampling with entropy-based token selection. 
66 | 67 | Args: 68 | model: The RND1 language model 69 | seq_len: Target sequence length 70 | num_steps: Number of denoising steps 71 | top_k: Optional top-k filtering for sampling (None = no filtering) 72 | top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering) 73 | When both top_k and top_p are set, top_k is applied first, then top_p 74 | temperature: Temperature for sampling (higher = more random, lower = more deterministic) 75 | Values close to 0 are clamped to 1e-8 to avoid division by zero 76 | greedy: Whether to use greedy sampling (True) or stochastic (False) 77 | mask_token_id: Token ID for masked positions (default: 151669) 78 | prefix_ids: Optional prefix token IDs to preserve 79 | suffix_ids: Optional suffix token IDs to preserve 80 | infill_length: Length of infill region between prefix/suffix 81 | eos_token_id: End of sequence token ID (default: 151645) 82 | pad_token_id: Padding token ID (default: None, uses 0 if needed) 83 | bos_token_id: Beginning of sequence token ID (default: None) 84 | device: Device for computation (None = infer from model) 85 | visualizer: Optional visualizer for live visualization 86 | add_eos_at_end: Whether to force EOS token at the end of the sequence 87 | 88 | Returns: 89 | Generated token IDs as LongTensor 90 | """ 91 | model.eval() 92 | 93 | if device is None: 94 | device = next(model.parameters()).device 95 | else: 96 | device = torch.device(device) 97 | 98 | if pad_token_id is None: 99 | pad_token_id = 0 100 | 101 | # Build initial masked sequence 102 | # When prefix_ids is provided, we create a sequence of length seq_len where: 103 | # - The prefix occupies the first pre_len positions 104 | # - The remaining (seq_len - pre_len) positions are filled with mask tokens to be generated 105 | if prefix_ids is not None or suffix_ids is not None: 106 | if prefix_ids is not None: 107 | prefix_ids = ( 108 | prefix_ids.to(device) 109 | if isinstance(prefix_ids, torch.Tensor) 110 | else torch.tensor(prefix_ids, device=device) 111 | ) 112 | pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0 113 | else: 114 | pre_len = 0 115 | 116 | if suffix_ids is not None: 117 | suffix_ids = ( 118 | suffix_ids.to(device) 119 | if isinstance(suffix_ids, torch.Tensor) 120 | else torch.tensor(suffix_ids, device=device) 121 | ) 122 | suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0 123 | else: 124 | suf_len = 0 125 | 126 | reserved = 1 if eos_token_id is not None else 0 127 | used = pre_len + suf_len + reserved 128 | 129 | if used > seq_len: 130 | raise ValueError( 131 | f"Combined length of prefix ({pre_len}), suffix ({suf_len}), " 132 | f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). " 133 | f"Please increase seq_len or reduce input lengths." 134 | ) 135 | elif used == seq_len: 136 | raise ValueError( 137 | f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) " 138 | f"+ special tokens ({reserved}) = seq_len ({seq_len}). " 139 | f"Need at least 1 position for generation." 
140 | ) 141 | 142 | infill_length = min(infill_length or (seq_len - used), seq_len - used) 143 | 144 | x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device) 145 | pos = 0 146 | # if bos_token_id is not None: 147 | # x[0, pos] = bos_token_id; pos += 1 148 | if eos_token_id is not None and add_eos_at_end: 149 | x[0, -1] = eos_token_id 150 | if pre_len > 0: 151 | x[0, pos : pos + pre_len] = prefix_ids.flatten()[:pre_len] 152 | pos += pre_len 153 | fill_start, fill_end = pos, pos + infill_length 154 | x[0, fill_start:fill_end] = mask_token_id 155 | # print(fill_start, fill_end, seq_len, used, x[0, -1]) 156 | pos = fill_end 157 | if suf_len > 0: 158 | x[0, pos : pos + suf_len] = suffix_ids.flatten()[:suf_len] 159 | pos += suf_len 160 | 161 | init_maskable = torch.zeros_like(x, dtype=torch.bool) 162 | init_maskable[0, fill_start:fill_end] = True 163 | else: 164 | x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device) 165 | if bos_token_id is not None: 166 | x[0, 0] = bos_token_id 167 | if eos_token_id is not None and add_eos_at_end: 168 | x[0, -1] = eos_token_id 169 | init_maskable = x.eq(mask_token_id) 170 | 171 | if bos_token_id is not None: 172 | init_maskable[:, 0] = False 173 | if eos_token_id is not None: 174 | init_maskable &= x.ne(eos_token_id) 175 | init_maskable &= x.ne(pad_token_id) 176 | 177 | maskable = init_maskable.clone() 178 | xt = x.clone() 179 | 180 | if visualizer: 181 | visualizer.start_visualization(xt, maskable, num_steps) 182 | 183 | def forward_scores(tokens): 184 | """Compute predictions and entropy scores for next tokens.""" 185 | # Try with input_ids parameter first (standard HF models) 186 | try: 187 | model_output = model(input_ids=tokens) 188 | except TypeError: 189 | # Fall back to positional argument 190 | model_output = model(tokens) 191 | 192 | # Apply temperature scaling (with safety for near-zero temperature) 193 | safe_temperature = max(temperature, 1e-8) # Prevent division by zero 194 | logits = model_output.logits / safe_temperature 195 | 196 | # Apply filtering strategies 197 | # Note: When both top_k and top_p are provided, they are applied sequentially: 198 | # First top_k filters to k tokens, then top_p filters from those k tokens 199 | if top_k is not None and top_k > 0: 200 | logits = apply_top_k_filtering(logits, top_k) 201 | 202 | if top_p is not None and 0 < top_p < 1.0: 203 | logits = apply_top_p_filtering(logits, top_p) 204 | 205 | # Convert to log probabilities 206 | logp = torch.log_softmax(logits, dim=-1) 207 | 208 | # Greedy or stochastic sampling 209 | if greedy: 210 | pred_next = logp.argmax(-1) 211 | else: 212 | pred_next = torch.distributions.Categorical(logits=logp).sample() 213 | 214 | conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1) 215 | 216 | p = logp.exp() 217 | ent_next = -(p * logp).sum(-1) 218 | 219 | # Shift predictions: pos i predicts token i+1 220 | pred_i = tokens.clone() 221 | conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min) 222 | ent_i = torch.zeros_like(ent_next) 223 | 224 | pred_i[:, 1:] = pred_next[:, :-1] 225 | conf_i[:, 1:] = conf_next[:, :-1] 226 | ent_i[:, 1:] = ent_next[:, :-1] 227 | 228 | return pred_i, conf_i, ent_i 229 | 230 | pred_i, conf_i, ent_i = forward_scores(xt) 231 | total_masked = init_maskable.sum(1, keepdim=True) 232 | finf = torch.finfo(conf_i.dtype) 233 | 234 | for step in range(num_steps - 1, 0, -1): 235 | rate = step / num_steps 236 | cutoff_len = (total_masked * rate).long().clamp(min=0) 237 | 238 | # Choose 
HIGH-entropy tokens to keep masked 239 | sel_scores = ent_i.masked_fill(~maskable, -finf.max) 240 | B, L = sel_scores.shape 241 | k_max = cutoff_len.max().item() 242 | if k_max > 0: 243 | sss, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True) 244 | keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool) 245 | for b in range(B): 246 | k_b = int(cutoff_len[b].item()) 247 | if k_b > 0: 248 | keep_mask[b, idx[b, :k_b]] = True 249 | else: 250 | keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool) 251 | 252 | to_unmask = maskable & ~keep_mask 253 | if to_unmask.any(): 254 | xt[to_unmask] = pred_i[to_unmask] 255 | maskable[to_unmask] = False 256 | 257 | if visualizer: 258 | visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i) 259 | 260 | if maskable.any(): 261 | pred_i, conf_i, ent_i = forward_scores(xt) 262 | 263 | if maskable.any(): 264 | xt[maskable] = pred_i[maskable] 265 | 266 | if visualizer: 267 | visualizer.stop_visualization() 268 | 269 | return xt 270 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 Radical Numerics Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /demo_rnd_generation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Demo script for RND1 generation. 
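The script loads an RND1 checkpoint, builds an RND1GenerationConfig from the CLI flags, runs
diffusion generation, and prints the decoded completion (optionally with the live rich
visualization).

Example invocations (illustrative; every flag shown is defined in main() below):

    python demo_rnd_generation.py --prompt "Explain the Fibonacci sequence." --num_steps 128
    python demo_rnd_generation.py --mode completion --temperature 0.7 --top_p 0.9 --no_viz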
4 | """ 5 | 6 | import argparse 7 | import os 8 | import random 9 | import sys 10 | 11 | import numpy as np 12 | import torch 13 | 14 | from transformers import AutoTokenizer 15 | 16 | # Add RND1 module to path for local testing 17 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) 18 | 19 | 20 | def set_seed(seed: int): 21 | """Set random seed for reproducibility.""" 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | 27 | 28 | def demo_completion( 29 | model_path: str, 30 | checkpoint_path: str = None, 31 | device: str = "cuda:0", 32 | use_bfloat16: bool = True, 33 | show_visualization: bool = True, 34 | num_steps: int = 64, 35 | max_new_tokens: int = 256, 36 | custom_prompt: str = None, 37 | temperature: float = 1.0, 38 | top_k: int = None, 39 | top_p: float = None, 40 | mask_token_id: int = 151669, 41 | seed: int = None, 42 | moe_backend: str = "hf", 43 | mode: str = "task", 44 | add_eos_at_end: bool = False, 45 | ): 46 | """ 47 | Demonstrate text completion using RND1. 48 | 49 | Args: 50 | model_path: Path to base model or HuggingFace model ID 51 | checkpoint_path: Path to custom checkpoint (if any) 52 | device: Device to run on (e.g., cuda:0, cpu) 53 | use_bfloat16: Whether to use bfloat16 precision 54 | show_visualization: Whether to show live visualization (requires rich) 55 | num_steps: Number of diffusion steps 56 | max_new_tokens: Maximum number of tokens to generate 57 | custom_prompt: Custom prompt to use instead of default examples 58 | temperature: Temperature for sampling (0.0 = greedy) 59 | top_k: Top-k filtering for sampling (None = disabled) 60 | top_p: Top-p (nucleus) filtering for sampling (None = disabled) 61 | mask_token_id: Token ID for mask token 62 | seed: Random seed for reproducibility 63 | moe_backend: MoE backend to use ('hf', 'vllm', 'sglang', 'flashinfer') 64 | mode: Generation mode ('task' for Q&A format, 'completion' for continuation) 65 | add_eos_at_end: Whether to add EOS token at the end of the sequence 66 | """ 67 | # if seed is not None: 68 | if seed is None: 69 | # generate a random seed 70 | seed = random.randint(0, 1000000) 71 | print(f"Seed not provided, using random seed: {seed}") 72 | set_seed(seed) 73 | 74 | from rnd.configuration_rnd import RND1Config 75 | from rnd.modeling_rnd import RND1LM 76 | 77 | print("Loading tokenizer...") 78 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 79 | 80 | dtype = torch.bfloat16 if use_bfloat16 else torch.float32 81 | print(f"Using dtype: {dtype}") 82 | 83 | if moe_backend == "hf": 84 | print( 85 | "\n⚠️ Note: HuggingFace backend is slower. 
" 86 | "Consider using --moe_backend vllm, sglang or flashinfer for better performance.\n" 87 | ) 88 | 89 | # Load from checkpoint if provided, otherwise from model_path 90 | load_path = checkpoint_path if checkpoint_path else model_path 91 | 92 | print(f"Loading model from {load_path}...") 93 | 94 | # Load config and set RND1-specific settings 95 | cfg = RND1Config.from_pretrained(load_path) 96 | cfg.model_type = "rnd1" 97 | cfg.attn_implementation = "sdpa" 98 | cfg.moe_backend = moe_backend 99 | 100 | # Load model with RND1LM 101 | model = RND1LM.from_pretrained( 102 | load_path, 103 | config=cfg, 104 | dtype=dtype, 105 | device_map="auto" if device == "cuda:0" else device, 106 | trust_remote_code=True, 107 | use_safetensors=True, 108 | low_cpu_mem_usage=True, 109 | ) 110 | print("Model loaded") 111 | model = model.eval() 112 | 113 | if custom_prompt: 114 | prompts = [custom_prompt] 115 | else: 116 | # Default prompts based on mode 117 | if mode == "task": 118 | prompts = [ 119 | "Write a Python function that finds the longest common subsequence of two strings." 120 | "Include comments explaining the algorithm." 121 | ] 122 | else: 123 | prompts = ["The key to understanding quantum computing lies in"] 124 | 125 | greedy = temperature == 0.0 126 | 127 | for i, user_prompt in enumerate(prompts): 128 | print(f"\n{'=' * 60}") 129 | print(f"Mode: {mode.upper()}") 130 | print(f"Prompt {i + 1}: {user_prompt[:100]}...") 131 | print(f"{'=' * 60}\n") 132 | 133 | if mode == "task": 134 | # Task mode: Add "Question: " prefix if not already present 135 | if not user_prompt.strip().startswith("Question:"): 136 | prompt = f"Question: {user_prompt}\nAnswer:" 137 | else: 138 | prompt = user_prompt 139 | else: 140 | # Completion mode: Use prompt as-is for continuation 141 | prompt = user_prompt 142 | 143 | inputs = tokenizer(prompt, return_tensors="pt") 144 | input_ids = inputs.input_ids.to(device if device != "auto" else "cuda") 145 | 146 | print("Generation parameters:") 147 | print(f" Prompt length: {input_ids.shape[1]} tokens") 148 | print(f" Max new tokens: {max_new_tokens}") 149 | print(f" Total sequence: {input_ids.shape[1] + max_new_tokens} tokens") 150 | print(f" Diffusion steps: {num_steps}") 151 | print(f" Temperature: {temperature}") 152 | print(f" Greedy: {greedy}") 153 | if top_k: 154 | print(f" Top-k: {top_k}") 155 | if top_p: 156 | print(f" Top-p: {top_p}") 157 | print() 158 | 159 | # Create explicit generation config that takes priority over model defaults 160 | from rnd.generation_config import RND1GenerationConfig 161 | 162 | gen_config = RND1GenerationConfig( 163 | max_new_tokens=max_new_tokens, 164 | num_diffusion_steps=num_steps, 165 | mask_token_id=mask_token_id, 166 | temperature=temperature if not greedy else 0.0, 167 | top_k=top_k, 168 | top_p=top_p, 169 | greedy=greedy, 170 | eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else 151645, 171 | pad_token_id=tokenizer.pad_token_id, 172 | bos_token_id=tokenizer.bos_token_id, 173 | add_eos_at_end=add_eos_at_end, 174 | ) 175 | 176 | with torch.no_grad(): 177 | if show_visualization and hasattr(model, "generate_with_visualization"): 178 | # Use method with visualization support (requires tokenizer) 179 | output = model.generate_with_visualization( 180 | tokenizer=tokenizer, 181 | inputs=input_ids, 182 | generation_config=gen_config, 183 | ) 184 | else: 185 | # Use standard generate method with explicit config 186 | output = model.generate( 187 | inputs=input_ids, 188 | generation_config=gen_config, 189 | ) 190 | 191 | 
generated_tokens = output[0][len(input_ids[0]) :] 192 | generation = tokenizer.decode(generated_tokens.tolist(), skip_special_tokens=True) 193 | 194 | if not show_visualization: # by default the viz shows final response too 195 | print("\nGenerated response:") 196 | print(generation) 197 | 198 | print(f"\n(Generation completed in {num_steps} diffusion steps)") 199 | 200 | 201 | def main(): 202 | parser = argparse.ArgumentParser( 203 | description="RND1 diffusion model demo with live visualization", 204 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 205 | ) 206 | 207 | # Model configuration 208 | model_group = parser.add_argument_group("Model Configuration") 209 | model_group.add_argument( 210 | "--model_path", type=str, default="radicalnumerics/RND1-Base-0910", help="Path to model or HuggingFace model ID" 211 | ) 212 | model_group.add_argument("--checkpoint", type=str, default=None, help="Path to custom checkpoint file or directory") 213 | model_group.add_argument("--device", type=str, default="cuda:0", help="Device to run on (e.g., cuda:0, cpu)") 214 | model_group.add_argument("--fp32", action="store_true", help="Use FP32 precision instead of BF16") 215 | 216 | # Generation configuration 217 | gen_group = parser.add_argument_group("Generation Settings") 218 | gen_group.add_argument("--num_steps", type=int, default=256, help="Number of diffusion steps") 219 | gen_group.add_argument("--max_new_tokens", type=int, default=256, help="Maximum number of tokens to generate") 220 | gen_group.add_argument("--prompt", type=str, default=None, help="Custom prompt to use for generation") 221 | gen_group.add_argument( 222 | "--mode", 223 | type=str, 224 | default="task", 225 | choices=["task", "completion"], 226 | help="Generation mode: 'task' (Q&A format for instructions) or 'completion' (text continuation)", 227 | ) 228 | gen_group.add_argument("--mask_token_id", type=int, default=151669, help="Token ID for mask token") 229 | 230 | # Sampling configuration 231 | sampling_group = parser.add_argument_group("Sampling Parameters") 232 | sampling_group.add_argument( 233 | "--temperature", type=float, default=0.01, help="Temperature for sampling (0.0 = greedy/deterministic)" 234 | ) 235 | sampling_group.add_argument( 236 | "--top_k", type=int, default=None, help="Top-k filtering: keep only k most likely tokens" 237 | ) 238 | sampling_group.add_argument( 239 | "--top_p", 240 | type=float, 241 | default=None, 242 | help="Top-p (nucleus) filtering: keep tokens with cumulative probability <= p", 243 | ) 244 | 245 | # Visualization 246 | viz_group = parser.add_argument_group("Visualization") 247 | viz_group.add_argument( 248 | "--no_viz", action="store_true", help="Disable live visualization during generation (requires rich library)" 249 | ) 250 | 251 | # Other settings 252 | other_group = parser.add_argument_group("Other Settings") 253 | other_group.add_argument("--seed", type=int, default=1234, help="Random seed for reproducibility") 254 | 255 | moe_backend_group = parser.add_argument_group("MoE Backend") 256 | moe_backend_group.add_argument( 257 | "--moe_backend", 258 | type=str, 259 | default="hf", 260 | choices=["hf", "vllm", "sglang", "flashinfer"], 261 | help="MoE backend to use for sparse mixture of experts layers", 262 | ) 263 | add_eos_at_end_group = parser.add_argument_group("EOS Token") 264 | add_eos_at_end_group.add_argument( 265 | "--add_eos_at_end", 266 | action="store_true", 267 | help="Add End of Sequence (EOS) token at the end of the sequence. 
" 268 | "This can be useful to force the model to generate a complete sentence.", 269 | ) 270 | 271 | args = parser.parse_args() 272 | 273 | if args.temperature < 0: 274 | parser.error("Temperature must be non-negative") 275 | if args.top_k is not None and args.top_k <= 0: 276 | parser.error("Top-k must be positive") 277 | if args.top_p is not None and (args.top_p <= 0 or args.top_p > 1): 278 | parser.error("Top-p must be between 0 and 1") 279 | 280 | print("\n" + "=" * 60) 281 | print("RND1 Diffusion Language Model Demo") 282 | print("=" * 60) 283 | print("Configuration:") 284 | print(f" Model: {args.model_path}") 285 | if args.checkpoint: 286 | print(f" Checkpoint: {args.checkpoint}") 287 | print(f" Device: {args.device}") 288 | print(f" Precision: {'FP32' if args.fp32 else 'BF16'}") 289 | print( 290 | f" Mode: {args.mode.upper()} ({'Q&A format for instructions' if args.mode == 'task' else 'Text continuation'})" 291 | ) 292 | print(f" Random seed: {args.seed}") 293 | print(f" Diffusion steps: {args.num_steps}") 294 | print(f" Max new tokens: {args.max_new_tokens}") 295 | print(" Algorithm: Entropy-based selection") 296 | print(f" Temperature: {args.temperature}") 297 | if args.top_k: 298 | print(f" Top-k: {args.top_k}") 299 | if args.top_p: 300 | print(f" Top-p: {args.top_p}") 301 | print(f" MoE Backend: {args.moe_backend}") 302 | print(f" Visualization: {'Enabled' if not args.no_viz else 'Disabled'}") 303 | print("=" * 60 + "\n") 304 | 305 | demo_completion( 306 | model_path=args.model_path, 307 | checkpoint_path=args.checkpoint, 308 | device=args.device, 309 | use_bfloat16=not args.fp32, 310 | show_visualization=not args.no_viz, 311 | num_steps=args.num_steps, 312 | max_new_tokens=args.max_new_tokens, 313 | custom_prompt=args.prompt, 314 | temperature=args.temperature, 315 | top_k=args.top_k, 316 | top_p=args.top_p, 317 | mask_token_id=args.mask_token_id, 318 | seed=args.seed, 319 | moe_backend=args.moe_backend, 320 | mode=args.mode, 321 | add_eos_at_end=args.add_eos_at_end, 322 | ) 323 | 324 | 325 | if __name__ == "__main__": 326 | main() 327 | -------------------------------------------------------------------------------- /rnd/modeling_rnd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Radical Numerics Inc. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0, found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | RND1 model implementation. 8 | 9 | This module implements the RND1 architecture with bidirectional attention for 10 | diffusion-based language modeling. Includes support for Mixture of Experts (MoE) 11 | with multiple backend options (HF, vLLM, SGLang, FlashInfer). 
12 | 13 | Based on the Qwen3Moe architecture: 14 | https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py 15 | """ 16 | 17 | from __future__ import annotations 18 | 19 | import os 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | 24 | from torch import nn 25 | from transformers.cache_utils import Cache 26 | from transformers.configuration_utils import PretrainedConfig 27 | from transformers.generation import GenerationConfig 28 | from transformers.modeling_outputs import MaskedLMOutput, MoeModelOutputWithPast 29 | from transformers.modeling_utils import PreTrainedModel 30 | from transformers.models.qwen3_moe.modeling_qwen3_moe import ( 31 | Qwen3MoeMLP, 32 | Qwen3MoeRMSNorm, 33 | Qwen3MoeRotaryEmbedding, 34 | apply_rotary_pos_emb, 35 | ) 36 | from transformers.utils import logging 37 | 38 | from .configuration_rnd import RND1Config 39 | from .generation_utils import RND1GenerationMixin 40 | 41 | vllm_import_error = None 42 | try: 43 | from vllm.model_executor.layers.fused_moe.fused_moe import ( 44 | fused_experts as fused_experts_vllm, 45 | fused_topk as fused_topk_vllm, 46 | ) 47 | from vllm.model_executor.layers.layernorm import RMSNorm as VLLMRMSNorm 48 | except ImportError as e: 49 | fused_experts_vllm = None 50 | fused_topk_vllm = None 51 | vllm_import_error = e 52 | 53 | sglang_import_error = None 54 | try: 55 | from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe 56 | 57 | # from sglang.srt.layers.layernorm import RMSNorm as SGLangRMSNorm # TODO: buggy atm 58 | from sglang.srt.layers.moe.topk import StandardTopKOutput 59 | except ImportError as e: 60 | sglang_fused_moe = None 61 | StandardTopKOutput = None 62 | sglang_import_error = e 63 | 64 | flashinfer_import_error = None 65 | try: 66 | import flashinfer.fused_moe as fused_moe 67 | ## TODO: below needs flashinfer>=0.4.0, but has some bug atm 68 | # from flashinfer.norm import rmsnorm as flashinfer_rmsnorm 69 | # class FlashInferRMSNorm(Qwen3MoeRMSNorm): 70 | # """Wrapper around FlashInfer RMSNorm to match Qwen3MoeRMSNorm interface""" 71 | # def forward(self, hidden_states): 72 | # return flashinfer_rmsnorm(hidden_states, self.weight, self.variance_epsilon) 73 | except ImportError as e: 74 | fused_moe = None 75 | flashinfer_import_error = e 76 | logger = logging.get_logger(__name__) 77 | 78 | 79 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 80 | """Expand key/value heads to match query heads for grouped-query attention.""" 81 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 82 | if n_rep == 1: 83 | return hidden_states 84 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 85 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 86 | 87 | 88 | class RND1Attention(nn.Module): 89 | """RND1 attention layer with bidirectional attention for diffusion modeling.""" 90 | 91 | def __init__(self, config: RND1Config, layer_idx: int): 92 | super().__init__() 93 | self.config = config 94 | self.layer_idx = layer_idx 95 | 96 | self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) 97 | self.num_heads = config.num_attention_heads 98 | self.num_key_value_heads = config.num_key_value_heads 99 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 100 | 101 | self.scaling = self.head_dim**-0.5 102 | self.attention_dropout = config.attention_dropout 103 | 
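        # Bidirectional (non-causal) attention: the diffusion objective lets every position
        # attend to the full sequence, so no causal mask is applied below.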
self.is_causal = False 104 | 105 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) 106 | self.k_proj = nn.Linear( 107 | config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias 108 | ) 109 | self.v_proj = nn.Linear( 110 | config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias 111 | ) 112 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) 113 | 114 | if config.moe_backend == "vllm": 115 | RMSNormClass = VLLMRMSNorm 116 | else: 117 | RMSNormClass = Qwen3MoeRMSNorm 118 | self.q_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps) 119 | self.k_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps) 120 | 121 | self.sliding_window = getattr(config, "sliding_window", None) 122 | 123 | self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) 124 | 125 | def forward( 126 | self, 127 | hidden_states: torch.Tensor, 128 | attention_mask: torch.Tensor | None = None, 129 | position_ids: torch.LongTensor | None = None, 130 | past_key_values: Cache | tuple[torch.Tensor, torch.Tensor] | None = None, 131 | cache_position: torch.LongTensor | None = None, 132 | position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, 133 | dual_cache: bool | None = False, 134 | replace_position: torch.Tensor | None = None, 135 | **kwargs, 136 | ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | tuple[torch.Tensor, torch.Tensor] | None]: 137 | bsz, q_len, _ = hidden_states.size() 138 | input_shape = hidden_states.shape[:-1] 139 | hidden_shape = (*input_shape, -1, self.head_dim) 140 | 141 | query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) 142 | key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) 143 | value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) 144 | 145 | cos, sin = position_embeddings 146 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 147 | 148 | key_states = repeat_kv(key_states, self.num_key_value_groups) 149 | value_states = repeat_kv(value_states, self.num_key_value_groups) 150 | 151 | use_sdpa = getattr(self.config, "_attn_implementation", "eager") == "sdpa" 152 | 153 | if use_sdpa: 154 | if attention_mask is not None and isinstance(attention_mask, torch.Tensor): 155 | if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]: 156 | attention_mask = attention_mask.to(dtype=query_states.dtype) 157 | 158 | assert not self.is_causal, f"Attention layer {self.layer_idx} is causal" 159 | attn_out = torch.nn.functional.scaled_dot_product_attention( 160 | query_states, 161 | key_states, 162 | value_states, 163 | attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, 164 | dropout_p=self.attention_dropout if self.training else 0.0, 165 | is_causal=self.is_causal, 166 | ) 167 | attn_out = attn_out.transpose(1, 2).contiguous() 168 | attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim) 169 | attn_out = self.o_proj(attn_out) 170 | return attn_out, None 171 | 172 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling 173 | 174 | if attention_mask is not None: 175 | attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]] 176 | 177 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 178 | attn_weights = 
nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 179 | 180 | attn_out = torch.matmul(attn_weights, value_states) 181 | attn_out = attn_out.transpose(1, 2).contiguous().view(hidden_states.size(0), hidden_states.size(1), -1) 182 | attn_out = self.o_proj(attn_out) 183 | 184 | return attn_out, None 185 | 186 | 187 | class RND1DecoderLayer(nn.Module): 188 | """RND1 decoder layer with bidirectional attention for diffusion language modeling.""" 189 | 190 | def __init__(self, config: RND1Config, layer_idx: int): 191 | super().__init__() 192 | self.self_attn = RND1Attention(config, layer_idx) 193 | self.mlp = RND1SparseMoeBlock(config) 194 | if config.moe_backend == "vllm": 195 | RMSNormClass = VLLMRMSNorm 196 | else: 197 | RMSNormClass = Qwen3MoeRMSNorm 198 | self.input_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps) 199 | self.post_attention_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps) 200 | 201 | def forward( 202 | self, 203 | hidden_states: torch.Tensor, 204 | attention_mask: torch.Tensor | None = None, 205 | position_ids: torch.LongTensor | None = None, 206 | position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, 207 | replace_position: torch.Tensor | None = None, 208 | **kwargs, 209 | ) -> tuple[torch.FloatTensor, torch.Tensor | None]: 210 | residual = hidden_states 211 | hidden_states = self.input_layernorm(hidden_states) 212 | 213 | attn_out, attn_weights = self.self_attn( 214 | hidden_states, 215 | attention_mask=attention_mask, 216 | position_ids=position_ids, 217 | position_embeddings=position_embeddings, 218 | replace_position=replace_position, 219 | ) 220 | hidden_states = residual + attn_out 221 | 222 | residual = hidden_states 223 | hidden_states = self.post_attention_layernorm(hidden_states) 224 | ff_out = self.mlp(hidden_states) 225 | if isinstance(ff_out, tuple): 226 | ff_out = ff_out[0] 227 | hidden_states = residual + ff_out 228 | 229 | return hidden_states, attn_weights 230 | 231 | 232 | class RND1SparseMoeBlock(nn.Module): 233 | """RND1 Sparse MoE block with multiple backend support (HF, vLLM, SGLang, FlashInfer).""" 234 | 235 | def __init__(self, config: RND1Config): 236 | super().__init__() 237 | self.config = config 238 | self.backend = getattr(config, "moe_backend", "hf") 239 | self.num_experts = config.num_experts 240 | self.top_k = config.num_experts_per_tok 241 | self.norm_topk_prob = config.norm_topk_prob 242 | self.hidden_size = config.hidden_size 243 | self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size) 244 | 245 | self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False) 246 | self.experts = nn.ModuleList( 247 | [Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)] 248 | ) 249 | 250 | # Cached weight tensors for optimized backends 251 | self._w1 = None 252 | self._w2 = None 253 | 254 | @torch.no_grad() 255 | def _initialize_weights( 256 | self, 257 | free_experts: bool = True, 258 | mode: str = "vllm", 259 | ) -> None: 260 | logger.info(f"Initializing weights for {mode} backend") 261 | # Stack directly on device where weights already reside (loaded by HF) 262 | gate_list: list[torch.Tensor] = [] 263 | up_list: list[torch.Tensor] = [] 264 | down_list: list[torch.Tensor] = [] 265 | 266 | # Collect weight references without any device moves 267 | for expert in self.experts: 268 | gate_list.append(expert.gate_proj.weight.data) 269 | up_list.append(expert.up_proj.weight.data) 270 
| down_list.append(expert.down_proj.weight.data) 271 | 272 | gate_w_stacked = torch.stack(gate_list, dim=0).contiguous() 273 | up_w_stacked = torch.stack(up_list, dim=0).contiguous() 274 | down_w_stacked = torch.stack(down_list, dim=0).contiguous() 275 | 276 | if mode == "flashinfer": 277 | w1 = torch.cat([up_w_stacked, gate_w_stacked], dim=1) # FlashInfer expects [up; gate] ordering 278 | else: 279 | w1 = torch.cat([gate_w_stacked, up_w_stacked], dim=1) 280 | w2 = down_w_stacked 281 | self._w1 = w1 282 | self._w2 = w2 283 | 284 | if free_experts: 285 | # Free per-expert modules to reclaim memory 286 | logger.info(f"Freeing experts for {mode} backend") 287 | del self.experts 288 | self.experts = None 289 | 290 | def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 291 | """Forward pass with expert routing and computation.""" 292 | batch_size, sequence_length, hidden_dim = hidden_states.shape 293 | x = hidden_states.view(-1, hidden_dim) 294 | 295 | # Expert routing 296 | router_logits = self.gate(x) 297 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) 298 | 299 | if self.backend == "vllm": 300 | routing_weights, selected_experts, _ = fused_topk_vllm( 301 | hidden_states=x, 302 | gating_output=router_logits, 303 | topk=self.top_k, 304 | renormalize=self.norm_topk_prob, 305 | ) 306 | else: 307 | routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) 308 | if self.norm_topk_prob: 309 | routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) 310 | 311 | if self.backend == "hf": 312 | final_hidden_states = torch.zeros( 313 | (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device 314 | ) 315 | 316 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) 317 | expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() 318 | 319 | for expert_idx in expert_hit: 320 | expert_layer = self.experts[expert_idx] 321 | idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) 322 | current_state = x[top_x] 323 | current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] 324 | final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) 325 | out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) 326 | return out, router_logits.view(batch_size, sequence_length, -1) 327 | 328 | elif self.backend == "flashinfer": 329 | # if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None: 330 | # self._initialize_flashinfer_weights() 331 | if self._w1 is None or self._w2 is None: 332 | self._initialize_weights(mode="flashinfer") 333 | 334 | result = fused_moe.cutlass_fused_moe( 335 | input=x, 336 | token_selected_experts=selected_experts.to(torch.int), 337 | token_final_scales=routing_weights.to(torch.float32), 338 | fc1_expert_weights=self._w1, 339 | fc2_expert_weights=self._w2, 340 | output_dtype=x.dtype, 341 | quant_scales=None, 342 | ) 343 | if isinstance(result, (list, tuple)): 344 | out_flat = result[0] 345 | else: 346 | out_flat = result 347 | out = out_flat.view(batch_size, sequence_length, hidden_dim) 348 | return out, router_logits.view(batch_size, sequence_length, -1) 349 | 350 | elif self.backend == "sglang": 351 | if self._w1 is None or self._w2 is None: 352 | self._initialize_weights(mode="sglang") 353 | 354 | topk_output = StandardTopKOutput( 355 | topk_weights=routing_weights, 356 | 
topk_ids=selected_experts, 357 | router_logits=router_logits, 358 | ) 359 | 360 | out_flat = sglang_fused_moe( 361 | hidden_states=x, 362 | w1=self._w1, 363 | w2=self._w2, 364 | topk_output=topk_output, 365 | ) 366 | out = out_flat.view(batch_size, sequence_length, hidden_dim) 367 | return out, router_logits.view(batch_size, sequence_length, -1) 368 | 369 | elif self.backend == "vllm": 370 | if self._w1 is None or self._w2 is None: 371 | self._initialize_weights() 372 | 373 | out_flat = fused_experts_vllm( 374 | x, 375 | self._w1, 376 | self._w2, 377 | routing_weights, 378 | selected_experts, 379 | ) 380 | out = out_flat.view(batch_size, sequence_length, hidden_dim) 381 | return out, router_logits.view(batch_size, sequence_length, -1) 382 | 383 | else: 384 | raise ValueError(f"Invalid backend: {self.backend}") 385 | 386 | 387 | class RND1PreTrainedModel(PreTrainedModel): 388 | """Base class for RND1 models with weight initialization and loading support.""" 389 | 390 | config_class = RND1Config 391 | base_model_prefix = "model" 392 | supports_gradient_checkpointing = True 393 | _no_split_modules = ["RND1DecoderLayer"] 394 | _skip_keys_device_placement = "past_key_values" 395 | _supports_flash_attn_2 = True 396 | _supports_sdpa = True 397 | _supports_cache_class = True 398 | _supports_quantized_cache = True 399 | _supports_static_cache = True 400 | 401 | def _init_weights(self, module): 402 | """Initialize weights using normal distribution.""" 403 | std = self.config.initializer_range 404 | if isinstance(module, nn.Linear): 405 | module.weight.data.normal_(mean=0.0, std=std) 406 | if module.bias is not None: 407 | module.bias.data.zero_() 408 | elif isinstance(module, nn.Embedding): 409 | module.weight.data.normal_(mean=0.0, std=std) 410 | if module.padding_idx is not None: 411 | module.weight.data[module.padding_idx].zero_() 412 | 413 | @classmethod 414 | def from_pretrained( 415 | cls, 416 | pretrained_model_name_or_path: str | os.PathLike | None, 417 | *model_args, 418 | config: PretrainedConfig | str | os.PathLike | None = None, 419 | cache_dir: str | os.PathLike | None = None, 420 | ignore_mismatched_sizes: bool = False, 421 | force_download: bool = False, 422 | local_files_only: bool = False, 423 | token: str | bool | None = None, 424 | revision: str = "main", 425 | use_safetensors: bool | None = None, 426 | weights_only: bool = True, 427 | **kwargs, 428 | ): 429 | """Load pretrained model with generation config.""" 430 | 431 | # Catch backend errors early 432 | backend = getattr(config, "moe_backend", "hf") 433 | if backend == "sglang" and sglang_import_error is not None: 434 | raise RuntimeError(f"sglang is not available. Import error: {sglang_import_error}") 435 | elif backend == "flashinfer" and flashinfer_import_error is not None: 436 | raise RuntimeError(f"flashinfer is not available. Import error: {flashinfer_import_error}") 437 | elif backend == "vllm" and vllm_import_error is not None: 438 | raise RuntimeError(f"vllm is not available. 
Import error: {vllm_import_error}") 439 | 440 | _model = super().from_pretrained( 441 | pretrained_model_name_or_path, 442 | *model_args, 443 | config=config, 444 | cache_dir=cache_dir, 445 | ignore_mismatched_sizes=ignore_mismatched_sizes, 446 | force_download=force_download, 447 | local_files_only=local_files_only, 448 | token=token, 449 | revision=revision, 450 | use_safetensors=use_safetensors, 451 | weights_only=weights_only, 452 | **kwargs, 453 | ) 454 | 455 | resume_download = kwargs.get("resume_download", None) 456 | proxies = kwargs.get("proxies", None) 457 | subfolder = kwargs.get("subfolder", "") 458 | from_auto_class = kwargs.get("_from_auto", False) 459 | from_pipeline = kwargs.get("_from_pipeline", None) 460 | 461 | _model.generation_config = GenerationConfig.from_pretrained( 462 | pretrained_model_name_or_path, 463 | cache_dir=cache_dir, 464 | force_download=force_download, 465 | resume_download=resume_download, 466 | proxies=proxies, 467 | local_files_only=local_files_only, 468 | token=token, 469 | revision=revision, 470 | subfolder=subfolder, 471 | _from_auto=from_auto_class, 472 | _from_pipeline=from_pipeline, 473 | ) 474 | 475 | # If configured to use a fused backend, pack fused tensors once after load 476 | try: 477 | if backend in ("sglang", "vllm"): 478 | # Walk decoder layers and initialize fused weights 479 | model_core = getattr(_model, "model", _model) 480 | layers = getattr(model_core, "layers", None) 481 | if isinstance(layers, nn.ModuleList): 482 | for layer in layers: 483 | mlp = getattr(layer, "mlp", None) 484 | if hasattr(mlp, "_initialize_weights"): 485 | mlp._initialize_weights( 486 | free_experts=True, 487 | mode=backend, 488 | ) 489 | except Exception as _e: 490 | logger.warning(f"Backend {backend} weight processing skipped: {_e}") 491 | 492 | return _model 493 | 494 | 495 | class RND1Model(RND1PreTrainedModel): 496 | """RND1 transformer model with bidirectional attention for diffusion language modeling.""" 497 | 498 | def __init__(self, config: RND1Config): 499 | super().__init__(config) 500 | 501 | self.padding_idx = config.pad_token_id 502 | self.vocab_size = config.vocab_size 503 | 504 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 505 | self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)]) 506 | if config.moe_backend == "vllm": 507 | RMSNormClass = VLLMRMSNorm 508 | else: 509 | RMSNormClass = Qwen3MoeRMSNorm 510 | self.norm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps) 511 | 512 | self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) 513 | 514 | self.post_init() 515 | 516 | def forward( 517 | self, 518 | input_ids: torch.LongTensor | None = None, 519 | attention_mask: torch.Tensor | None = None, 520 | position_ids: torch.LongTensor | None = None, 521 | inputs_embeds: torch.FloatTensor | None = None, 522 | **kwargs, 523 | ) -> MoeModelOutputWithPast: 524 | """Forward pass through the RND1 model.""" 525 | 526 | if (input_ids is None) == (inputs_embeds is None): 527 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds") 528 | 529 | if inputs_embeds is None: 530 | inputs_embeds = self.embed_tokens(input_ids) 531 | 532 | if position_ids is None: 533 | position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) 534 | 535 | position_embeddings = self.rotary_emb(inputs_embeds, position_ids) 536 | 537 | hidden_states = inputs_embeds 538 | 539 | for layer in self.layers: 540 | hidden_states, _ 
= layer( 541 | hidden_states, 542 | attention_mask=attention_mask, 543 | position_ids=position_ids, 544 | position_embeddings=position_embeddings, 545 | ) 546 | 547 | hidden_states = self.norm(hidden_states) 548 | 549 | return MoeModelOutputWithPast( 550 | last_hidden_state=hidden_states, 551 | router_logits=None, 552 | ) 553 | 554 | 555 | class RND1LM(RND1PreTrainedModel, RND1GenerationMixin): 556 | """Radical Numerics Diffusion Language Model with bidirectional attention.""" 557 | 558 | def __init__(self, config: RND1Config): 559 | super().__init__(config) 560 | self.model = RND1Model(config) 561 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 562 | self.post_init() 563 | 564 | def get_input_embeddings(self): 565 | """Get the input embeddings layer.""" 566 | return self.model.embed_tokens 567 | 568 | def set_input_embeddings(self, value): 569 | """Set the input embeddings layer.""" 570 | self.model.embed_tokens = value 571 | 572 | def get_output_embeddings(self): 573 | """Get the output embeddings layer (lm_head).""" 574 | return self.lm_head 575 | 576 | def set_output_embeddings(self, new_embeddings): 577 | """Set the output embeddings layer (lm_head).""" 578 | self.lm_head = new_embeddings 579 | 580 | @classmethod 581 | def can_generate(cls) -> bool: 582 | """Indicates this model can generate text.""" 583 | return True 584 | 585 | def forward( 586 | self, 587 | input_ids: torch.LongTensor | None = None, 588 | attention_mask: torch.Tensor | None = None, 589 | position_ids: torch.LongTensor | None = None, 590 | inputs_embeds: torch.FloatTensor | None = None, 591 | labels: torch.LongTensor | None = None, 592 | **kwargs, 593 | ) -> MaskedLMOutput: 594 | """Forward pass with optional loss computation.""" 595 | outputs = self.model( 596 | input_ids=input_ids, 597 | attention_mask=attention_mask, 598 | position_ids=position_ids, 599 | inputs_embeds=inputs_embeds, 600 | **kwargs, 601 | ) 602 | logits = self.lm_head(outputs.last_hidden_state) 603 | 604 | loss = None 605 | if labels is not None: 606 | loss_fct = nn.CrossEntropyLoss() 607 | loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) 608 | 609 | return MaskedLMOutput( 610 | loss=loss, 611 | logits=logits, 612 | ) 613 | --------------------------------------------------------------------------------