├── NOTICE
├── .pre-commit-config.yaml
├── pyproject.toml
├── rnd
│   ├── __init__.py
│   ├── generation_config.py
│   ├── configuration_rnd.py
│   ├── generation_utils.py
│   ├── terminal_visualizer.py
│   ├── sampling.py
│   └── modeling_rnd.py
├── README.md
├── .gitignore
├── assets
│   └── rn-logo-desktop-vector.svg
├── LICENSE
└── demo_rnd_generation.py
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2025 Radical Numerics Inc.
2 |
3 | This source code is licensed under the Apache License, Version 2.0, found in the
4 | LICENSE file in the root directory of this source tree.
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Usage
2 | # uv run pre-commit install
3 | # uv run pre-commit run --all-files
4 |
5 | repos:
6 | - repo: https://github.com/pre-commit/pre-commit-hooks
7 | rev: v6.0.0
8 | hooks:
9 | - id: check-added-large-files
10 | - id: check-case-conflict
11 | - id: check-merge-conflict
12 | - id: check-symlinks
13 | - id: mixed-line-ending
14 | - id: trailing-whitespace
15 |
16 | - repo: https://github.com/astral-sh/ruff-pre-commit
17 | rev: v0.13.2
18 | hooks:
19 | - id: ruff-format
20 | - id: ruff-check
21 | args: [--fix]
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "rnd"
7 | version = "0.1.0"
8 | dependencies = [
9 | "accelerate",
10 | "torch>=2.8",
11 | "transformers",
12 | "rich"
13 | ]
14 |
15 | [project.optional-dependencies]
16 | flashinfer = [
17 | "flashinfer-python",
18 | ]
19 | sglang = ["sglang[all]"]
20 | vllm = ["vllm"]
21 | linting = ["ruff>=0.13.2", "pre-commit>=4.0.0"]
22 |
23 | [tool.setuptools]
24 | packages = ["rnd"]
25 |
26 | [tool.ruff]
27 | line-length = 120
28 | target-version = "py311"
29 | show-fixes = false
30 | extend-exclude = ["*.ipynb"]
31 |
32 | [tool.ruff.lint]
33 | select = ["F", "E", "W", "I001", "UP"]
34 | task-tags = ["TODO", "FIXME"]
35 |
36 | [tool.ruff.lint.per-file-ignores]
37 | "__init__.py" = ["E402", "F401"]
38 |
39 | [tool.ruff.lint.isort]
40 | known-first-party = []
41 | known-third-party = []
42 | section-order = [
43 | "future",
44 | "standard-library",
45 | "third-party",
46 | "first-party",
47 | "local-folder",
48 | ]
49 | combine-as-imports = true
50 | split-on-trailing-comma = false
51 | lines-between-types = 1
--------------------------------------------------------------------------------
/rnd/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Radical Numerics Inc.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0, found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | Radical Numerics Diffusion (RND1) - Diffusion-based Language Model.
8 | """
9 |
10 | from .configuration_rnd import RND1Config
11 | from .generation_config import RND1GenerationConfig
12 | from .generation_utils import RND1GenerationMixin
13 | from .modeling_rnd import RND1LM, RND1Attention, RND1DecoderLayer, RND1Model, RND1PreTrainedModel, RND1SparseMoeBlock
14 | from .sampling import apply_top_k_filtering, apply_top_p_filtering, diffusion_sample
15 | from .terminal_visualizer import SimpleProgressBar, TerminalVisualizer
16 |
17 | __version__ = "0.1.0"
18 |
19 | __all__ = [
20 | "RND1Config",
21 | "RND1GenerationConfig",
22 | "RND1LM",
23 | "RND1Model",
24 | "RND1PreTrainedModel",
25 | "RND1Attention",
26 | "RND1DecoderLayer",
27 | "RND1SparseMoeBlock",
28 | "RND1GenerationMixin",
29 | "TerminalVisualizer",
30 | "SimpleProgressBar",
31 | ]
32 |
33 | # Register with HuggingFace Auto classes for local usage
34 | try:
35 | from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
36 |
37 | AutoConfig.register("rnd1", RND1Config)
38 | AutoModel.register(RND1Config, RND1Model)
39 | AutoModelForMaskedLM.register(RND1Config, RND1LM)
40 | except ImportError:
41 | # transformers not available or Auto classes not imported
42 | pass
43 |
--------------------------------------------------------------------------------
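A minimal sketch of what the Auto-class registration above enables (assuming the `rnd` package is importable, e.g. from the repository root):

```python
# Sketch: importing the package registers RND1 with the HuggingFace Auto classes.
import rnd  # noqa: F401  (runs the AutoConfig/AutoModel registration above)

from transformers import AutoConfig

config = AutoConfig.for_model("rnd1")   # resolves to rnd.RND1Config
print(type(config).__name__)            # -> "RND1Config"
# AutoModel / AutoModelForMaskedLM now map RND1Config to RND1Model / RND1LM.
```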
/rnd/generation_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Radical Numerics Inc.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0, found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | RND1 Generation Configuration.
8 |
9 | This module defines the generation configuration for RND1 models,
10 | controlling the diffusion-based generation process.
11 | """
12 |
13 | from transformers.generation.configuration_utils import GenerationConfig
14 |
15 |
16 | class RND1GenerationConfig(GenerationConfig):
17 | """
18 | Configuration class for RND1 generation parameters.
19 |
20 | This class extends the base GenerationConfig to include parameters
21 | specific to diffusion-based language generation.
22 |
23 | Args:
24 | max_length: Maximum sequence length
25 | num_diffusion_steps: Number of denoising steps in the diffusion process
26 | mask_token_id: Token ID used for masking during diffusion
27 | temperature: Temperature for sampling (higher = more random)
28 | top_k: Optional top-k filtering
29 | top_p: Optional nucleus (top-p) filtering
30 | greedy: Whether to use greedy decoding (True) or stochastic sampling (False)
31 | **kwargs: Additional arguments passed to GenerationConfig
32 | """
33 |
34 | def __init__(
35 | self,
36 | max_length: int = 256,
37 | num_diffusion_steps: int = 256,
38 | mask_token_id: int = 151669,
39 | temperature: float = 0.1,
40 | top_k: int | None = None,
41 | top_p: float | None = None,
42 | greedy: bool = False,
43 | bos_token_id: int | None = None,
44 | eos_token_id: int | None = None,
45 | pad_token_id: int | None = None,
46 | use_cache: bool = False,
47 | **kwargs,
48 | ):
49 | # Force no caching for RND generation
50 | # kwargs['use_cache'] = False
51 | kwargs.pop("use_cache", None)
52 | super().__init__(
53 | max_length=max_length,
54 | bos_token_id=bos_token_id,
55 | eos_token_id=eos_token_id,
56 | pad_token_id=pad_token_id,
57 | temperature=temperature,
58 | top_k=top_k,
59 | top_p=top_p,
60 | do_sample=not greedy,
61 | use_cache=False,
62 | **kwargs,
63 | )
64 |
65 | # RND-specific parameters
66 | self.num_diffusion_steps = num_diffusion_steps
67 | self.mask_token_id = mask_token_id
68 | self.greedy = greedy
69 |
70 | def to_dict(self):
71 | """Convert configuration to dictionary."""
72 | output = super().to_dict()
73 | output["num_diffusion_steps"] = self.num_diffusion_steps
74 | output["mask_token_id"] = self.mask_token_id
75 | output["greedy"] = self.greedy
76 | return output
77 |
--------------------------------------------------------------------------------
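A minimal usage sketch for `RND1GenerationConfig`; the parameter values are illustrative, and the config would be passed to `RND1LM.generate` as in the README's Python API example:

```python
# Sketch: building a generation config for stochastic sampling.
from rnd import RND1GenerationConfig

gen_config = RND1GenerationConfig(
    max_new_tokens=256,      # forwarded to the base GenerationConfig via kwargs
    num_diffusion_steps=128,
    temperature=0.7,
    top_k=50,
    greedy=False,            # do_sample is derived as `not greedy`
)
assert gen_config.do_sample and not gen_config.use_cache  # caching is always disabled

# output = model.generate(inputs=input_ids, generation_config=gen_config)
```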
/rnd/configuration_rnd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Radical Numerics Inc.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0, found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | RND1 Model Configuration.
8 |
9 | This module defines the configuration class for RND1 models.
10 | The default settings are derived from Qwen/Qwen3-30B-A3B and augmented
11 | with RND1-specific parameters.
12 | """
13 |
14 | from transformers.configuration_utils import PretrainedConfig
15 |
16 | # Qwen3-30B-A3B / checkpoint defaults
17 | CONFIG_DEFAULTS = {
18 | "attention_bias": False,
19 | "attention_dropout": 0.0,
20 | "decoder_sparse_step": 1,
21 | "eos_token_id": 151645,
22 | "head_dim": 128,
23 | "hidden_act": "silu",
24 | "hidden_size": 2048,
25 | "initializer_range": 0.02,
26 | "intermediate_size": 6144,
27 | "max_position_embeddings": 40960,
28 | "max_window_layers": 48,
29 | "mlp_only_layers": [],
30 | "moe_intermediate_size": 768,
31 | "norm_topk_prob": True,
32 | "num_attention_heads": 32,
33 | "num_experts": 128,
34 | "num_experts_per_tok": 8,
35 | "num_hidden_layers": 48,
36 | "num_key_value_heads": 4,
37 | "output_router_logits": False,
38 | "pad_token_id": 151643,
39 | "rms_norm_eps": 1e-06,
40 | "rope_scaling": False,
41 | "rope_theta": 1000000.0,
42 | "router_aux_loss_coef": 0.001,
43 | "sliding_window": False,
44 | "tie_word_embeddings": False,
45 | "dtype": "bfloat16",
46 | "use_cache": False,
47 | "use_sliding_window": False,
48 | "vocab_size": 151936,
49 | }
50 |
51 |
52 | class RND1Config(PretrainedConfig):
53 | """
54 | Configuration class for RND1 models.
55 |
56 | This configuration extends PretrainedConfig with Qwen3-30B-A3B defaults and adds
57 | parameters specific to the RND1 (Radical Numerics Diffusion v1) architecture.
58 |
59 | Args:
60 | moe_backend: Backend for MoE computation ("hf", "vllm", "sglang" or "flashinfer")
61 | num_diffusion_steps: Default number of diffusion steps for generation
62 | mask_token_id: Token ID used for masking (default: 151669 for Qwen)
63 | **kwargs: Additional arguments passed to PretrainedConfig
64 | """
65 |
66 | model_type = "rnd1"
67 |
68 | def __init__(
69 | self,
70 | moe_backend: str = "hf",
71 | num_diffusion_steps: int = 256,
72 | mask_token_id: int = 151669,
73 | **kwargs,
74 | ):
75 | # Force non-causal and no caching for RND1
76 | kwargs["use_cache"] = False
77 | kwargs["is_causal"] = False
78 |
79 | super().__init__(**kwargs)
80 |
81 | # Set defaults after pretrained init to prevent overrides
82 | self.set_config_defaults()
83 |
84 | # QoL: set attn impl directly from config
85 | if "attn_implementation" in kwargs:
86 | self._attn_implementation = kwargs["attn_implementation"]
87 |
88 | # RND1-specific parameters
89 | self.moe_backend = moe_backend
90 | self.num_diffusion_steps = num_diffusion_steps
91 | self.mask_token_id = mask_token_id
92 |
93 | # Ensure bidirectional attention and no caching
94 | self.is_causal = False
95 | self.use_cache = False
96 |
97 | def set_config_defaults(self):
98 | """
99 | Ensure model defaults are set according to the final training checkpoint.
100 |
101 | Qwen3MoeConfig defaults don't match Qwen/Qwen3-30B-A3B settings from which
102 | RND1 is derived.
103 | """
104 | for k, v in CONFIG_DEFAULTS.items():
105 | setattr(self, k, v)
106 |
107 | def to_dict(self):
108 | """
109 | Serializes configuration to dictionary with auto_map for Hub.
110 |
111 | The auto_map ensures that when users load from HuggingFace Hub,
112 | the correct custom classes are automatically resolved.
113 | """
114 | data = super().to_dict()
115 | data.setdefault(
116 | "auto_map",
117 | {
118 | "AutoConfig": "configuration_rnd.RND1Config",
119 | "AutoModel": "modeling_rnd.RND1Model",
120 | "AutoModelForMaskedLM": "modeling_rnd.RND1LM",
121 | },
122 | )
123 | return data
124 |
--------------------------------------------------------------------------------
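A short sketch of constructing the config directly, for example to select a different MoE backend (the values shown are illustrative):

```python
# Sketch: RND1Config always enforces bidirectional attention and disables KV caching.
from rnd import RND1Config

config = RND1Config(moe_backend="flashinfer", num_diffusion_steps=128)
assert config.is_causal is False and config.use_cache is False
print(config.num_experts, config.num_experts_per_tok)  # 128 experts, 8 active per token
```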
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # RND1: Scaling Diffusion Language Models
4 |
5 |
6 |
7 | 
8 |
9 |
10 | This repository contains an inference harness for Radical Numerics Diffusion 1 (RND1), an experimental diffusion language model. RND1-Base-0910 is a 30B‑parameter sparse Mixture‑of‑Experts model with 3B active parameters per token, converted from an autoregressive base (Qwen3-30B-A3B) via continual pretraining on 500B tokens.
11 |
12 | We release RND1 models to catalyze further research on inference and post-training of diffusion language models (DLMs).
13 |
14 | For more details, see:
15 |
16 | **Blog:** https://www.radicalnumerics.ai/blog/rnd1
17 |
18 | **Report:** https://www.radicalnumerics.ai/assets/rnd1_report.pdf
19 |
20 | **🤗:** https://huggingface.co/radicalnumerics/RND1-Base-0910
21 |
22 | **Models:**
23 | * **RND1-Base-0910**: the first base model in the RND1 family. It has not been post-trained for any specific use case.
24 |
25 |
26 | ## Installation
27 |
28 | ```bash
29 | # tested with Python 3.12
30 | pip install torch transformers accelerate numpy rich
31 | ```
32 |
33 | ```bash
34 | # optional backends enable faster inference through optimized MoE kernels:
35 | pip install flashinfer-python
36 | pip install sglang[all]
37 | pip install vllm
38 | ```
39 |
40 | ## Quick Start
41 |
42 |
43 |
44 | ```bash
45 | # Task mode (default) - for instructions, questions, or requests
46 | python demo_rnd_generation.py --prompt "Write a Python function that finds the longest common subsequence of two strings. Include comments explaining the algorithm." --moe_backend hf
47 |
48 | # Completion mode - for text continuation
49 | python demo_rnd_generation.py --mode completion --prompt "The key to understanding quantum computing lies in" --moe_backend hf
50 |
51 | # Sampling parameters
52 | python demo_rnd_generation.py --top_k 50 --temperature 0.7 --prompt "Explain how neural networks learn in simple terms" --moe_backend hf
53 | ```
54 |
55 |
56 | > [!WARNING]
57 | > Using a non-HuggingFace MoE backend is highly encouraged for faster generation. Note, however, that non-HF backends currently support only a single GPU, so you need to set e.g. `export CUDA_VISIBLE_DEVICES=0` before running the script. If you use `flashinfer-python`, JIT compilation may take a while the first time the code runs unless `flashinfer-jit-cache` is installed.
58 |
59 | ### Demo Parameters
60 |
61 | - `--mode`: Generation mode - 'task' or 'completion' (default: task)
62 | - `task`: For instructions, questions, or requests (adds "Question:" prefix)
63 | - `completion`: For text continuation (no prefix added)
64 | - `--max_new_tokens`: Number of new tokens to generate (default: 256)
65 | - `--num_steps`: Diffusion denoising steps (default: 256)
66 | - `--temperature`: Sampling temperature, 0.0 for greedy (default: 0.01)
67 | - `--top_k`: Top-k filtering - keeps only k most likely tokens (works with greedy and sampling)
68 | - `--top_p`: Nucleus filtering - keeps tokens with cumulative probability ≤ p (works with greedy and sampling)
69 | - `--moe_backend`: Choose backend: hf, vllm, sglang, flashinfer (default: hf)
70 | - `--no_viz`: Disable visualization
71 | - `--add_eos_at_end`: Add an End of Sequence (EOS) token at the end of the sequence; useful for forcing the generation to reach a coherent ending (default: False)
72 |
73 | ## Python API
74 |
75 | ```python
76 | from transformers import AutoTokenizer
77 | from rnd import RND1LM
78 |
79 | # Load tokenizer
80 | tokenizer = AutoTokenizer.from_pretrained("radicalnumerics/RND1-Base-0910", trust_remote_code=True)
81 |
82 | # Load model
83 | model = RND1LM.from_pretrained(
84 | "radicalnumerics/RND1-Base-0910",
85 | dtype="bfloat16",
86 | device_map="auto",
87 | trust_remote_code=True,
88 | moe_backend="hf", # hf (default), sglang, vllm, flashinfer
89 | )
90 |
91 | # Generate - Task mode (for instructions and questions)
92 | prompt = "Write a Python function that finds the longest common subsequence of two strings. Include comments explaining the algorithm."
93 | inputs = tokenizer(f"Question: {prompt}\nAnswer:", return_tensors="pt")
94 | input_ids = inputs.input_ids.to(model.device)
95 |
96 | # Generate
97 | output = model.generate(
98 | inputs=input_ids,
99 | max_new_tokens=256,
100 | num_diffusion_steps=256,
101 | temperature=0.01,
102 | )
103 |
104 | # Decode only the generated part
105 | text = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
106 | print(text)
107 | ```
108 |
109 | ## Project Structure
110 |
111 | ```
112 | RND_dev/
113 | ├── README.md # This file
114 | ├── demo_rnd_generation.py # Demo script with command-line interface
115 | └── rnd/ # Core RND1 package
116 | ├── __init__.py # Package exports
117 | ├── configuration_rnd.py # RND1 model configuration
118 | ├── modeling_rnd.py # Core model implementation
119 | ├── generation_config.py # Generation configuration
120 | ├── generation_utils.py # Generation mixin and utilities
121 | ├── sampling.py # Diffusion sampling algorithm
122 | └── terminal_visualizer.py # Live visualization (optional)
123 | ```
124 |
125 |
126 | ---
127 |
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
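The `generate()` API in the README also exposes the infilling parameters defined in `rnd/generation_utils.py` (`prefix_ids`, `suffix_ids`, `infill_length`). A minimal sketch, reusing the `model` and `tokenizer` from the Python API example above; the prompt strings are illustrative:

```python
# Infilling sketch: generate tokens between a fixed prefix and suffix.
prefix_ids = tokenizer("def fibonacci(n):\n", return_tensors="pt").input_ids.to(model.device)
suffix_ids = tokenizer("\n    return result\n", return_tensors="pt").input_ids.to(model.device)

output = model.generate(
    prefix_ids=prefix_ids,
    suffix_ids=suffix_ids,
    infill_length=64,         # number of masked positions between prefix and suffix
    num_diffusion_steps=128,
    temperature=0.01,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```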
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[codz]
4 | *$py.class
5 |
6 | # Temp files
7 | profiler_traces/
8 |
9 | # Agent files and commonly used patterns for agentic coding
10 | RISA.md
11 | .risa/
12 | AGENTS.md
13 | TASK.md
14 |
15 | # data
16 | data/
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py.cover
62 | .hypothesis/
63 | .pytest_cache/
64 | cover/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 | db.sqlite3-journal
75 |
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 |
80 | # Scrapy stuff:
81 | .scrapy
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | .pybuilder/
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # IPython
94 | profile_default/
95 | ipython_config.py
96 |
97 | # pyenv
98 | # For a library or package, you might want to ignore these files since the code is
99 | # intended to run in multiple environments; otherwise, check them in:
100 | # .python-version
101 |
102 | # pipenv
103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
106 | # install all needed dependencies.
107 | # Pipfile.lock
108 |
109 | # UV
110 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
111 | # This is especially recommended for binary packages to ensure reproducibility, and is more
112 | # commonly ignored for libraries.
113 | uv.lock
114 |
115 | # poetry
116 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
117 | # This is especially recommended for binary packages to ensure reproducibility, and is more
118 | # commonly ignored for libraries.
119 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
120 | # poetry.lock
121 | # poetry.toml
122 |
123 | # pdm
124 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
125 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
126 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
127 | # pdm.lock
128 | # pdm.toml
129 | .pdm-python
130 | .pdm-build/
131 |
132 | # pixi
133 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
134 | # pixi.lock
135 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
136 | # in the .venv directory. It is recommended not to include this directory in version control.
137 | .pixi
138 |
139 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
140 | __pypackages__/
141 |
142 | # Celery stuff
143 | celerybeat-schedule
144 | celerybeat.pid
145 |
146 | # Redis
147 | *.rdb
148 | *.aof
149 | *.pid
150 |
151 | # RabbitMQ
152 | mnesia/
153 | rabbitmq/
154 | rabbitmq-data/
155 |
156 | # ActiveMQ
157 | activemq-data/
158 |
159 | # SageMath parsed files
160 | *.sage.py
161 |
162 | # Environments
163 | .env
164 | .envrc
165 | .venv
166 | env/
167 | venv/
168 | ENV/
169 | env.bak/
170 | venv.bak/
171 |
172 | # Spyder project settings
173 | .spyderproject
174 | .spyproject
175 |
176 | # Rope project settings
177 | .ropeproject
178 |
179 | # mkdocs documentation
180 | /site
181 |
182 | # mypy
183 | .mypy_cache/
184 | .dmypy.json
185 | dmypy.json
186 |
187 | # Pyre type checker
188 | .pyre/
189 |
190 | # pytype static type analyzer
191 | .pytype/
192 |
193 | # Cython debug symbols
194 | cython_debug/
195 |
196 | # PyCharm
197 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
198 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
199 | # and can be added to the global gitignore or merged into this file. For a more nuclear
200 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
201 | # .idea/
202 |
203 | # Abstra
204 | # Abstra is an AI-powered process automation framework.
205 | # Ignore directories containing user credentials, local state, and settings.
206 | # Learn more at https://abstra.io/docs
207 | .abstra/
208 |
209 | # Visual Studio Code
210 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
211 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
212 | # and can be added to the global gitignore or merged into this file. However, if you prefer,
213 | # you could uncomment the following to ignore the entire vscode folder
214 | # .vscode/
215 |
216 | # Ruff stuff:
217 | .ruff_cache/
218 |
219 | # PyPI configuration file
220 | .pypirc
221 |
222 | # Marimo
223 | marimo/_static/
224 | marimo/_lsp/
225 | __marimo__/
226 |
227 | # Streamlit
228 | .streamlit/secrets.toml
--------------------------------------------------------------------------------
/assets/rn-logo-desktop-vector.svg:
--------------------------------------------------------------------------------
1 |
2 |
134 |
--------------------------------------------------------------------------------
/rnd/generation_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Radical Numerics Inc.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0, found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | RND1 Generation Utilities.
8 |
9 | This module provides generation utilities and mixins for RND1 models,
10 | including the main GenerationMixin class that integrates with HuggingFace.
11 | """
12 |
13 | from typing import Any
14 |
15 | import torch
16 |
17 | from transformers import GenerationMixin as HFGenerationMixin
18 | from transformers.generation import GenerationConfig
19 |
20 | from .generation_config import RND1GenerationConfig
21 | from .sampling import diffusion_sample
22 |
23 |
24 | class RND1GenerationMixin(HFGenerationMixin):
25 | """
26 | Generation mixin for RND1 models.
27 |
28 | This mixin provides generation methods compatible with HuggingFace's
29 | generation API while using RND1's diffusion-based sampling internally.
30 | """
31 |
32 | def generate(
33 | self,
34 | inputs: torch.LongTensor | None = None,
35 | generation_config: GenerationConfig | None = None,
36 | # RND1-specific parameters
37 | prefix_ids: torch.LongTensor | None = None,
38 | suffix_ids: torch.LongTensor | None = None,
39 | infill_length: int | None = None,
40 | return_dict_in_generate: bool | None = None,
41 | **kwargs, # Accept all kwargs to be compatible with pipelines
42 | ) -> torch.LongTensor | dict[str, Any]:
43 | """
44 | Generate text using RND1's diffusion-based sampling.
45 |
46 | Follows HuggingFace's standard generate API, using diffusion sampling
47 | internally. Supports both standard generation and infilling.
48 |
49 | Args:
50 | inputs: Input token IDs to use as prefix (standard HF parameter)
51 | generation_config: Generation configuration object. Default is RND1GenerationConfig.
52 | prefix_ids: Alternative to inputs for infilling tasks
53 | suffix_ids: Optional suffix for infilling tasks
54 | infill_length: Length of infill region (for infilling)
55 | return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
56 | **kwargs: Additional arguments (accepted for compatibility). These will be passed to the config constructor.
57 |
58 | Returns:
59 | Generated token IDs or GenerateDecoderOnlyOutput
60 | """
61 | if generation_config is not None:
62 | gen_config = generation_config
63 | model_kwargs = kwargs.copy()
64 | else:
65 | # Only prepare config from kwargs if no config was provided
66 | gen_config, model_kwargs = self._prepare_generation_config(RND1GenerationConfig(), **kwargs)
67 |
68 | device = next(self.parameters()).device
69 |
70 | if inputs is not None:
71 | prefix_ids = inputs.to(device)
72 | elif prefix_ids is not None:
73 | prefix_ids = prefix_ids.to(device)
74 | else:
75 | prefix_ids = None
76 |
77 | if suffix_ids is not None:
78 | suffix_ids = suffix_ids.to(device)
79 |
80 | eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
81 | pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", 151643)
82 | bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
83 | mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))
84 |
85 | if infill_length is not None and prefix_ids is not None:
86 | # Infilling mode: use specified infill_length
87 | prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0
88 | suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
89 | seq_len = prefix_len + infill_length + suffix_len
90 | else:
91 | # Standard generation mode
92 | if prefix_ids is not None:
93 | prefix_len = prefix_ids.shape[1]
94 | if gen_config.max_new_tokens is not None:
95 | seq_len = prefix_len + gen_config.max_new_tokens
96 | else:
97 | seq_len = gen_config.max_length or self.config.max_position_embeddings
98 | else:
99 | seq_len = gen_config.max_length or self.config.max_position_embeddings
100 |
101 | num_diffusion_steps = getattr(
102 | gen_config, "num_diffusion_steps", getattr(self.config, "num_diffusion_steps", 256)
103 | )
104 |
105 | temperature = float(getattr(gen_config, "temperature", 1.0))
106 | top_k = getattr(gen_config, "top_k", None)
107 | top_p = getattr(gen_config, "top_p", None)
108 |
109 | greedy = getattr(
110 | gen_config, "greedy", not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True
111 | )
112 |
113 | with torch.inference_mode():
114 | sequences = diffusion_sample(
115 | model=self,
116 | seq_len=seq_len,
117 | num_steps=num_diffusion_steps,
118 | mask_token_id=mask_token_id,
119 | temperature=temperature,
120 | top_k=top_k,
121 | top_p=top_p,
122 | greedy=greedy,
123 | prefix_ids=prefix_ids,
124 | suffix_ids=suffix_ids,
125 | infill_length=infill_length,
126 | eos_token_id=eos_token_id,
127 | pad_token_id=pad_token_id,
128 | bos_token_id=bos_token_id,
129 | device=device,
130 | visualizer=model_kwargs.get("visualizer", None),  # Optional visualizer from kwargs
131 | add_eos_at_end=getattr(gen_config, "add_eos_at_end", False),
132 | )
133 |
134 | if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
135 | from transformers.generation.utils import GenerateDecoderOnlyOutput
136 |
137 | return GenerateDecoderOnlyOutput(sequences=sequences)
138 |
139 | return sequences
140 |
141 | def generate_with_visualization(
142 | self,
143 | tokenizer,
144 | inputs: torch.LongTensor | None = None,
145 | generation_config: GenerationConfig | None = None,
146 | suffix_ids: torch.LongTensor | None = None,
147 | infill_length: int | None = None,
148 | **kwargs,
149 | ) -> torch.LongTensor:
150 | """
151 | Generate with live visualization (for demos).
152 |
153 | This method requires a tokenizer to display the generation process.
154 | For production use, prefer `generate()`.
155 |
156 | Args:
157 | tokenizer: Tokenizer for decoding tokens to text
158 | inputs: Input token IDs to use as prefix
159 | generation_config: Generation configuration object
160 | suffix_ids: Optional suffix token IDs
161 | infill_length: Length of infill region
162 | **kwargs: Additional arguments for backward compatibility
163 |
164 | Returns:
165 | Generated token IDs as LongTensor
166 | """
167 | from .terminal_visualizer import TerminalVisualizer
168 |
169 | visualizer = TerminalVisualizer(tokenizer, show_visualization=True)
170 |
171 | return self.generate(
172 | inputs=inputs,
173 | generation_config=generation_config,
174 | suffix_ids=suffix_ids,
175 | infill_length=infill_length,
176 | visualizer=visualizer,
177 | return_dict_in_generate=False,
178 | **kwargs,
179 | )
180 |
181 | def prepare_inputs_for_generation(
182 | self,
183 | input_ids: torch.LongTensor,
184 | **kwargs,
185 | ) -> dict[str, Any]:
186 | """
187 | Prepare inputs for generation (required by HuggingFace).
188 |
189 | For RND1, we don't use the standard autoregressive generation,
190 | so this just returns the input_ids.
191 | """
192 | return {"input_ids": input_ids}
193 |
--------------------------------------------------------------------------------
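For demos, `generate_with_visualization` wraps `generate()` with a live terminal view. A minimal sketch, assuming `model`, `tokenizer`, and `input_ids` are set up as in the README example:

```python
# Sketch: live-visualized generation. Internally this builds a TerminalVisualizer
# and forwards it to diffusion_sample via the `visualizer` kwarg.
output = model.generate_with_visualization(
    tokenizer,
    inputs=input_ids,
    max_new_tokens=128,
    num_diffusion_steps=128,
    temperature=0.01,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```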
/rnd/terminal_visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Radical Numerics Inc.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0, found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | Terminal visualization for RND1 generation.
8 |
9 | This module provides real-time visualization of the diffusion denoising process,
10 | showing token evolution and generation progress in the terminal using rich
11 | formatting when available.
12 | """
13 |
14 | import torch
15 |
16 | from tqdm import tqdm
17 |
18 | try:
19 | from rich.console import Console
20 | from rich.layout import Layout
21 | from rich.live import Live
22 | from rich.panel import Panel
23 | from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn, TimeRemainingColumn
24 | from rich.text import Text
25 |
26 | RICH_AVAILABLE = True
27 | except ImportError:
28 | RICH_AVAILABLE = False
29 |
30 |
31 | class TerminalVisualizer:
32 | """
33 | Rich-based visualization for diffusion process with live updates.
34 |
35 | Provides real-time visualization of the token denoising process during
36 | diffusion-based language generation, with colored highlighting of masked
37 | positions and progress tracking.
38 | """
39 |
40 | def __init__(self, tokenizer, show_visualization: bool = True):
41 | """
42 | Initialize the terminal visualizer.
43 |
44 | Args:
45 | tokenizer: The tokenizer for decoding tokens to text
46 | show_visualization: Whether to show visualization (requires rich)
47 | """
48 | self.tokenizer = tokenizer
49 | self.show_visualization = show_visualization and RICH_AVAILABLE
50 | if not RICH_AVAILABLE and show_visualization:
51 | print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
52 | self.show_visualization = False
53 |
54 | if self.show_visualization:
55 | self.console = Console()
56 | self.live = None
57 | self.progress = None
58 | self.layout = None
59 | else:
60 | self.pbar = None
61 |
62 | self.current_tokens = None
63 | self.mask_positions = None
64 | self.total_steps = 0
65 | self.current_step = 0
66 |
67 | def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
68 | """
69 | Start the visualization.
70 |
71 | Args:
72 | initial_tokens: Initial token IDs (possibly masked)
73 | mask_positions: Boolean mask indicating which positions are masked
74 | total_steps: Total number of diffusion steps
75 | """
76 | if not self.show_visualization:
77 | self.pbar = tqdm(total=total_steps, desc="Diffusion")
78 | return
79 |
80 | self.current_tokens = initial_tokens.clone()
81 | self.mask_positions = mask_positions
82 | self.total_steps = total_steps
83 | self.current_step = 0
84 |
85 | self.layout = Layout()
86 | self.layout.split_column(
87 | Layout(name="header", size=3), Layout(name="text", ratio=1), Layout(name="progress", size=3)
88 | )
89 |
90 | self.progress = Progress(
91 | TextColumn("[bold blue]Diffusion"),
92 | BarColumn(),
93 | MofNCompleteColumn(),
94 | TextColumn("•"),
95 | TextColumn("[cyan]Masks: {task.fields[masks]}"),
96 | TimeRemainingColumn(),
97 | )
98 | self.progress_task = self.progress.add_task("Generating", total=total_steps, masks=mask_positions.sum().item())
99 |
100 | self.live = Live(self.layout, console=self.console, refresh_per_second=4)
101 | self.live.start()
102 | self._update_display()
103 |
104 | def update_step(
105 | self,
106 | tokens: torch.LongTensor,
107 | maskable: torch.BoolTensor | None,
108 | step: int,
109 | entropy: torch.FloatTensor | None = None,
110 | confidence: torch.FloatTensor | None = None,
111 | ):
112 | """
113 | Update visualization for current step.
114 |
115 | Args:
116 | tokens: Current token IDs
117 | maskable: Boolean mask of remaining masked positions
118 | step: Current step number
119 | entropy: Optional entropy scores for each position
120 | confidence: Optional confidence scores for each position
121 | """
122 | if not self.show_visualization:
123 | if self.pbar:
124 | self.pbar.update(1)
125 | masks = maskable.sum().item() if maskable is not None else 0
126 | self.pbar.set_postfix({"masks": masks})
127 | return
128 |
129 | self.current_tokens = tokens.clone()
130 | self.mask_positions = maskable
131 | self.current_step = step
132 |
133 | masks_remaining = maskable.sum().item() if maskable is not None else 0
134 | self.progress.update(self.progress_task, advance=1, masks=masks_remaining)
135 |
136 | self._update_display()
137 |
138 | def _update_display(self):
139 | """Update the live display."""
140 | if not self.live:
141 | return
142 |
143 | header = Text("RND1-Base Generation", style="bold magenta", justify="center")
144 | self.layout["header"].update(Panel(header, border_style="bright_blue"))
145 |
146 | text_display = self._format_text_with_masks()
147 | self.layout["text"].update(
148 | Panel(
149 | text_display,
150 | title="[bold]Generated Text",
151 | subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
152 | border_style="cyan",
153 | )
154 | )
155 |
156 | self.layout["progress"].update(Panel(self.progress))
157 |
158 | def _format_text_with_masks(self) -> Text:
159 | """
160 | Format text with colored masks.
161 |
162 | Returns:
163 | Rich Text object with formatted tokens
164 | """
165 | text = Text()
166 |
167 | if self.current_tokens is None:
168 | return text
169 |
170 | token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
171 | mask_flags = (
172 | self.mask_positions[0]
173 | if self.mask_positions is not None and self.mask_positions.dim() > 1
174 | else self.mask_positions
175 | )
176 |
177 | for i, token_id in enumerate(token_ids):
178 | if mask_flags is not None and i < len(mask_flags) and mask_flags[i]:
179 | # Alternate colors for visual effect
180 | text.append(
181 | "[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red"
182 | )
183 | else:
184 | try:
185 | token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
186 | # Skip special tokens in display
187 | if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "", ""]:
188 | # Color based on position
189 | text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan")
190 | except Exception:
191 | continue
192 |
193 | return text
194 |
195 | def stop_visualization(self):
196 | """Stop the visualization and display final result."""
197 | if not self.show_visualization:
198 | if self.pbar:
199 | self.pbar.close()
200 | print("\n✨ Generation complete!\n")
201 | return
202 |
203 | if self.live:
204 | self.live.stop()
205 |
206 | self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")
207 |
208 | # Display final text
209 | if self.current_tokens is not None:
210 | try:
211 | token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
212 | final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
213 |
214 | self.console.print(
215 | Panel(final_text, title="[bold]Final Generated Text", border_style="green", padding=(1, 2))
216 | )
217 | except Exception:
218 | pass
219 |
220 |
221 | class SimpleProgressBar:
222 | """
223 | Simple progress bar fallback when rich is not available.
224 |
225 | Provides basic progress tracking using tqdm when the rich library
226 | is not installed.
227 | """
228 |
229 | def __init__(self, total_steps: int):
230 | """
231 | Initialize simple progress bar.
232 |
233 | Args:
234 | total_steps: Total number of steps
235 | """
236 | self.pbar = tqdm(total=total_steps, desc="Diffusion")
237 |
238 | def update(self, masks_remaining: int = 0):
239 | """
240 | Update progress bar.
241 |
242 | Args:
243 | masks_remaining: Number of masks still remaining
244 | """
245 | self.pbar.update(1)
246 | self.pbar.set_postfix({"masks": masks_remaining})
247 |
248 | def close(self):
249 | """Close the progress bar."""
250 | self.pbar.close()
251 | print("\n✨ Generation complete!\n")
252 |
--------------------------------------------------------------------------------
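The visualizer is normally driven by `diffusion_sample`, which calls `start_visualization`, `update_step`, and `stop_visualization`. A toy sketch driving those hooks by hand (the mask token ID 151669 matches the config defaults; the tokenizer checkpoint is the one used in the README):

```python
# Sketch: driving the visualizer hooks manually with a toy 8-token sequence.
import torch

from transformers import AutoTokenizer

from rnd import TerminalVisualizer

tokenizer = AutoTokenizer.from_pretrained("radicalnumerics/RND1-Base-0910", trust_remote_code=True)

tokens = torch.full((1, 8), 151669, dtype=torch.long)    # all positions start as [MASK]
maskable = torch.ones_like(tokens, dtype=torch.bool)

viz = TerminalVisualizer(tokenizer, show_visualization=True)
viz.start_visualization(tokens, maskable, total_steps=4)
viz.update_step(tokens, maskable, step=1)   # called once per diffusion step in practice
viz.stop_visualization()
```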
/rnd/sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Radical Numerics Inc.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0, found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | RND1 sampling module for masked diffusion generation.
8 |
9 | This module implements entropy-based token selection for iterative denoising
10 | in diffusion language models. Supports both greedy and stochastic sampling
11 | with optional prefix/suffix constraints and infilling.
12 | """
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 |
18 | def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
19 | """
20 | Apply top-k filtering to logits, setting all non-top-k values to -inf.
21 | """
22 | top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
23 | filtered_logits = torch.full_like(logits, float("-inf"))
24 | filtered_logits.scatter_(-1, top_k_indices, top_k_values)
25 | return filtered_logits
26 |
27 |
28 | def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
29 | """
30 | Apply top-p (nucleus) filtering to logits, setting tokens beyond the cumulative-probability threshold to -inf.
31 | """
32 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
33 | cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
34 |
35 | # Remove tokens with cumulative probability above threshold
36 | sorted_indices_to_remove = cumulative_probs > p
37 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()  # Shift right so the first token above the threshold is kept
38 | sorted_indices_to_remove[..., 0] = False  # Always keep at least one token
39 |
40 | indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
41 | return logits.masked_fill(indices_to_remove, float("-inf"))
42 |
43 |
44 | @torch.no_grad()
45 | def diffusion_sample(
46 | model: nn.Module,
47 | seq_len: int = 256,
48 | num_steps: int = 256,
49 | top_k: int | None = None,
50 | top_p: float | None = None,
51 | temperature: float = 1.0,
52 | greedy: bool = True,
53 | mask_token_id: int = 151669,
54 | prefix_ids: torch.LongTensor | None = None,
55 | suffix_ids: torch.LongTensor | None = None,
56 | infill_length: int | None = None,
57 | eos_token_id: int = 151645,
58 | pad_token_id: int | None = None,
59 | bos_token_id: int | None = None,
60 | device: str | torch.device | None = None,
61 | visualizer: object | None = None,
62 | add_eos_at_end: bool = False,
63 | ) -> torch.LongTensor:
64 | """
65 | Perform masked diffusion sampling with entropy-based token selection.
66 |
67 | Args:
68 | model: The RND1 language model
69 | seq_len: Target sequence length
70 | num_steps: Number of denoising steps
71 | top_k: Optional top-k filtering for sampling (None = no filtering)
72 | top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering)
73 | When both top_k and top_p are set, top_k is applied first, then top_p
74 | temperature: Temperature for sampling (higher = more random, lower = more deterministic)
75 | Values close to 0 are clamped to 1e-8 to avoid division by zero
76 | greedy: Whether to use greedy sampling (True) or stochastic (False)
77 | mask_token_id: Token ID for masked positions (default: 151669)
78 | prefix_ids: Optional prefix token IDs to preserve
79 | suffix_ids: Optional suffix token IDs to preserve
80 | infill_length: Length of infill region between prefix/suffix
81 | eos_token_id: End of sequence token ID (default: 151645)
82 | pad_token_id: Padding token ID (default: None, uses 0 if needed)
83 | bos_token_id: Beginning of sequence token ID (default: None)
84 | device: Device for computation (None = infer from model)
85 | visualizer: Optional visualizer for live visualization
86 | add_eos_at_end: Whether to force EOS token at the end of the sequence
87 |
88 | Returns:
89 | Generated token IDs as LongTensor
90 | """
91 | model.eval()
92 |
93 | if device is None:
94 | device = next(model.parameters()).device
95 | else:
96 | device = torch.device(device)
97 |
98 | if pad_token_id is None:
99 | pad_token_id = 0
100 |
101 | # Build initial masked sequence
102 | # When prefix_ids is provided, we create a sequence of length seq_len where:
103 | # - The prefix occupies the first pre_len positions
104 | # - The remaining (seq_len - pre_len) positions are filled with mask tokens to be generated
105 | if prefix_ids is not None or suffix_ids is not None:
106 | if prefix_ids is not None:
107 | prefix_ids = (
108 | prefix_ids.to(device)
109 | if isinstance(prefix_ids, torch.Tensor)
110 | else torch.tensor(prefix_ids, device=device)
111 | )
112 | pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0
113 | else:
114 | pre_len = 0
115 |
116 | if suffix_ids is not None:
117 | suffix_ids = (
118 | suffix_ids.to(device)
119 | if isinstance(suffix_ids, torch.Tensor)
120 | else torch.tensor(suffix_ids, device=device)
121 | )
122 | suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0
123 | else:
124 | suf_len = 0
125 |
126 | reserved = 1 if eos_token_id is not None else 0
127 | used = pre_len + suf_len + reserved
128 |
129 | if used > seq_len:
130 | raise ValueError(
131 | f"Combined length of prefix ({pre_len}), suffix ({suf_len}), "
132 | f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). "
133 | f"Please increase seq_len or reduce input lengths."
134 | )
135 | elif used == seq_len:
136 | raise ValueError(
137 | f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) "
138 | f"+ special tokens ({reserved}) = seq_len ({seq_len}). "
139 | f"Need at least 1 position for generation."
140 | )
141 |
142 | infill_length = min(infill_length or (seq_len - used), seq_len - used)
143 |
144 | x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device)
145 | pos = 0
146 | # if bos_token_id is not None:
147 | # x[0, pos] = bos_token_id; pos += 1
148 | if eos_token_id is not None and add_eos_at_end:
149 | x[0, -1] = eos_token_id
150 | if pre_len > 0:
151 | x[0, pos : pos + pre_len] = prefix_ids.flatten()[:pre_len]
152 | pos += pre_len
153 | fill_start, fill_end = pos, pos + infill_length
154 | x[0, fill_start:fill_end] = mask_token_id
155 | # print(fill_start, fill_end, seq_len, used, x[0, -1])
156 | pos = fill_end
157 | if suf_len > 0:
158 | x[0, pos : pos + suf_len] = suffix_ids.flatten()[:suf_len]
159 | pos += suf_len
160 |
161 | init_maskable = torch.zeros_like(x, dtype=torch.bool)
162 | init_maskable[0, fill_start:fill_end] = True
163 | else:
164 | x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
165 | if bos_token_id is not None:
166 | x[0, 0] = bos_token_id
167 | if eos_token_id is not None and add_eos_at_end:
168 | x[0, -1] = eos_token_id
169 | init_maskable = x.eq(mask_token_id)
170 |
171 | if bos_token_id is not None:
172 | init_maskable[:, 0] = False
173 | if eos_token_id is not None:
174 | init_maskable &= x.ne(eos_token_id)
175 | init_maskable &= x.ne(pad_token_id)
176 |
177 | maskable = init_maskable.clone()
178 | xt = x.clone()
179 |
180 | if visualizer:
181 | visualizer.start_visualization(xt, maskable, num_steps)
182 |
183 | def forward_scores(tokens):
184 | """Compute predictions and entropy scores for next tokens."""
185 | # Try with input_ids parameter first (standard HF models)
186 | try:
187 | model_output = model(input_ids=tokens)
188 | except TypeError:
189 | # Fall back to positional argument
190 | model_output = model(tokens)
191 |
192 | # Apply temperature scaling (with safety for near-zero temperature)
193 | safe_temperature = max(temperature, 1e-8) # Prevent division by zero
194 | logits = model_output.logits / safe_temperature
195 |
196 | # Apply filtering strategies
197 | # Note: When both top_k and top_p are provided, they are applied sequentially:
198 | # First top_k filters to k tokens, then top_p filters from those k tokens
199 | if top_k is not None and top_k > 0:
200 | logits = apply_top_k_filtering(logits, top_k)
201 |
202 | if top_p is not None and 0 < top_p < 1.0:
203 | logits = apply_top_p_filtering(logits, top_p)
204 |
205 | # Convert to log probabilities
206 | logp = torch.log_softmax(logits, dim=-1)
207 |
208 | # Greedy or stochastic sampling
209 | if greedy:
210 | pred_next = logp.argmax(-1)
211 | else:
212 | pred_next = torch.distributions.Categorical(logits=logp).sample()
213 |
214 | conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
215 |
216 | p = logp.exp()
217 | ent_next = -(p * logp).sum(-1)
218 |
219 | # Shift predictions: pos i predicts token i+1
220 | pred_i = tokens.clone()
221 | conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min)
222 | ent_i = torch.zeros_like(ent_next)
223 |
224 | pred_i[:, 1:] = pred_next[:, :-1]
225 | conf_i[:, 1:] = conf_next[:, :-1]
226 | ent_i[:, 1:] = ent_next[:, :-1]
227 |
228 | return pred_i, conf_i, ent_i
229 |
230 | pred_i, conf_i, ent_i = forward_scores(xt)
231 | total_masked = init_maskable.sum(1, keepdim=True)
232 | finf = torch.finfo(conf_i.dtype)
233 |
234 | for step in range(num_steps - 1, 0, -1):
235 | rate = step / num_steps
236 | cutoff_len = (total_masked * rate).long().clamp(min=0)
237 |
238 | # Choose HIGH-entropy tokens to keep masked
239 | sel_scores = ent_i.masked_fill(~maskable, -finf.max)
240 | B, L = sel_scores.shape
241 | k_max = cutoff_len.max().item()
242 | if k_max > 0:
243 | sss, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)
244 | keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
245 | for b in range(B):
246 | k_b = int(cutoff_len[b].item())
247 | if k_b > 0:
248 | keep_mask[b, idx[b, :k_b]] = True
249 | else:
250 | keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
251 |
252 | to_unmask = maskable & ~keep_mask
253 | if to_unmask.any():
254 | xt[to_unmask] = pred_i[to_unmask]
255 | maskable[to_unmask] = False
256 |
257 | if visualizer:
258 | visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i)
259 |
260 | if maskable.any():
261 | pred_i, conf_i, ent_i = forward_scores(xt)
262 |
263 | if maskable.any():
264 | xt[maskable] = pred_i[maskable]
265 |
266 | if visualizer:
267 | visualizer.stop_visualization()
268 |
269 | return xt
270 |
--------------------------------------------------------------------------------
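The core of `diffusion_sample` is its entropy-based unmasking schedule: at step `s` (counting down from `num_steps - 1`), a fraction `s / num_steps` of the originally masked positions, those with the highest predictive entropy, stays masked, and every other masked position is committed to its predicted token; any positions still masked after the loop are filled in a final pass. A sketch of just the schedule, with illustrative numbers:

```python
# Sketch of the unmasking schedule used by diffusion_sample (illustrative values).
num_steps = 8
total_masked = 16  # number of [MASK] positions in the initial sequence

for step in range(num_steps - 1, 0, -1):
    rate = step / num_steps
    keep_masked = int(total_masked * rate)        # highest-entropy positions stay masked
    unmasked_so_far = total_masked - keep_masked  # everything else has been committed
    print(f"step {num_steps - step}: {keep_masked} still masked, {unmasked_so_far} committed")
# After the loop, the remaining masked positions are filled in one final pass.
```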
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2025 Radical Numerics Inc.
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/demo_rnd_generation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Demo script for RND1 generation.
4 | """
5 |
6 | import argparse
7 | import os
8 | import random
9 | import sys
10 |
11 | import numpy as np
12 | import torch
13 |
14 | from transformers import AutoTokenizer
15 |
16 | # Add RND1 module to path for local testing
17 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
18 |
19 |
20 | def set_seed(seed: int):
21 | """Set random seed for reproducibility."""
22 | random.seed(seed)
23 | np.random.seed(seed)
24 | torch.manual_seed(seed)
25 | torch.cuda.manual_seed_all(seed)
26 |
27 |
28 | def demo_completion(
29 | model_path: str,
30 |     checkpoint_path: str | None = None,
31 |     device: str = "cuda:0",
32 |     use_bfloat16: bool = True,
33 |     show_visualization: bool = True,
34 |     num_steps: int = 64,
35 |     max_new_tokens: int = 256,
36 |     custom_prompt: str | None = None,
37 |     temperature: float = 1.0,
38 |     top_k: int | None = None,
39 |     top_p: float | None = None,
40 |     mask_token_id: int = 151669,
41 |     seed: int | None = None,
42 | moe_backend: str = "hf",
43 | mode: str = "task",
44 | add_eos_at_end: bool = False,
45 | ):
46 | """
47 | Demonstrate text completion using RND1.
48 |
49 | Args:
50 | model_path: Path to base model or HuggingFace model ID
51 | checkpoint_path: Path to custom checkpoint (if any)
52 | device: Device to run on (e.g., cuda:0, cpu)
53 | use_bfloat16: Whether to use bfloat16 precision
54 | show_visualization: Whether to show live visualization (requires rich)
55 | num_steps: Number of diffusion steps
56 | max_new_tokens: Maximum number of tokens to generate
57 | custom_prompt: Custom prompt to use instead of default examples
58 | temperature: Temperature for sampling (0.0 = greedy)
59 | top_k: Top-k filtering for sampling (None = disabled)
60 | top_p: Top-p (nucleus) filtering for sampling (None = disabled)
61 | mask_token_id: Token ID for mask token
62 | seed: Random seed for reproducibility
63 | moe_backend: MoE backend to use ('hf', 'vllm', 'sglang', 'flashinfer')
64 | mode: Generation mode ('task' for Q&A format, 'completion' for continuation)
65 | add_eos_at_end: Whether to add EOS token at the end of the sequence
66 | """
67 |     # Reproducibility: fall back to a random seed when none is provided
68 | if seed is None:
69 | # generate a random seed
70 | seed = random.randint(0, 1000000)
71 | print(f"Seed not provided, using random seed: {seed}")
72 | set_seed(seed)
73 |
74 | from rnd.configuration_rnd import RND1Config
75 | from rnd.modeling_rnd import RND1LM
76 |
77 | print("Loading tokenizer...")
78 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
79 |
80 | dtype = torch.bfloat16 if use_bfloat16 else torch.float32
81 | print(f"Using dtype: {dtype}")
82 |
83 | if moe_backend == "hf":
84 | print(
85 | "\n⚠️ Note: HuggingFace backend is slower. "
86 | "Consider using --moe_backend vllm, sglang or flashinfer for better performance.\n"
87 | )
88 |
89 | # Load from checkpoint if provided, otherwise from model_path
90 | load_path = checkpoint_path if checkpoint_path else model_path
91 |
92 | print(f"Loading model from {load_path}...")
93 |
94 | # Load config and set RND1-specific settings
95 | cfg = RND1Config.from_pretrained(load_path)
96 | cfg.model_type = "rnd1"
97 | cfg.attn_implementation = "sdpa"
98 | cfg.moe_backend = moe_backend
99 |
100 | # Load model with RND1LM
101 | model = RND1LM.from_pretrained(
102 | load_path,
103 | config=cfg,
104 | dtype=dtype,
105 | device_map="auto" if device == "cuda:0" else device,
106 | trust_remote_code=True,
107 | use_safetensors=True,
108 | low_cpu_mem_usage=True,
109 | )
110 | print("Model loaded")
111 | model = model.eval()
112 |
113 | if custom_prompt:
114 | prompts = [custom_prompt]
115 | else:
116 | # Default prompts based on mode
117 | if mode == "task":
118 | prompts = [
119 | "Write a Python function that finds the longest common subsequence of two strings."
120 | "Include comments explaining the algorithm."
121 | ]
122 | else:
123 | prompts = ["The key to understanding quantum computing lies in"]
124 |
125 |     greedy = temperature == 0.0  # a temperature of 0.0 selects greedy (argmax) decoding
126 |
127 | for i, user_prompt in enumerate(prompts):
128 | print(f"\n{'=' * 60}")
129 | print(f"Mode: {mode.upper()}")
130 | print(f"Prompt {i + 1}: {user_prompt[:100]}...")
131 | print(f"{'=' * 60}\n")
132 |
133 | if mode == "task":
134 | # Task mode: Add "Question: " prefix if not already present
135 | if not user_prompt.strip().startswith("Question:"):
136 | prompt = f"Question: {user_prompt}\nAnswer:"
137 | else:
138 | prompt = user_prompt
139 | else:
140 | # Completion mode: Use prompt as-is for continuation
141 | prompt = user_prompt
142 |
143 | inputs = tokenizer(prompt, return_tensors="pt")
144 | input_ids = inputs.input_ids.to(device if device != "auto" else "cuda")
145 |
146 | print("Generation parameters:")
147 | print(f" Prompt length: {input_ids.shape[1]} tokens")
148 | print(f" Max new tokens: {max_new_tokens}")
149 | print(f" Total sequence: {input_ids.shape[1] + max_new_tokens} tokens")
150 | print(f" Diffusion steps: {num_steps}")
151 | print(f" Temperature: {temperature}")
152 | print(f" Greedy: {greedy}")
153 | if top_k:
154 | print(f" Top-k: {top_k}")
155 | if top_p:
156 | print(f" Top-p: {top_p}")
157 | print()
158 |
159 | # Create explicit generation config that takes priority over model defaults
160 | from rnd.generation_config import RND1GenerationConfig
161 |
162 | gen_config = RND1GenerationConfig(
163 | max_new_tokens=max_new_tokens,
164 | num_diffusion_steps=num_steps,
165 | mask_token_id=mask_token_id,
166 | temperature=temperature if not greedy else 0.0,
167 | top_k=top_k,
168 | top_p=top_p,
169 | greedy=greedy,
170 | eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else 151645,
171 | pad_token_id=tokenizer.pad_token_id,
172 | bos_token_id=tokenizer.bos_token_id,
173 | add_eos_at_end=add_eos_at_end,
174 | )
175 |
176 | with torch.no_grad():
177 | if show_visualization and hasattr(model, "generate_with_visualization"):
178 | # Use method with visualization support (requires tokenizer)
179 | output = model.generate_with_visualization(
180 | tokenizer=tokenizer,
181 | inputs=input_ids,
182 | generation_config=gen_config,
183 | )
184 | else:
185 | # Use standard generate method with explicit config
186 | output = model.generate(
187 | inputs=input_ids,
188 | generation_config=gen_config,
189 | )
190 |
191 | generated_tokens = output[0][len(input_ids[0]) :]
192 | generation = tokenizer.decode(generated_tokens.tolist(), skip_special_tokens=True)
193 |
194 | if not show_visualization: # by default the viz shows final response too
195 | print("\nGenerated response:")
196 | print(generation)
197 |
198 | print(f"\n(Generation completed in {num_steps} diffusion steps)")
199 |
200 |
201 | def main():
202 | parser = argparse.ArgumentParser(
203 | description="RND1 diffusion model demo with live visualization",
204 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
205 | )
206 |
207 | # Model configuration
208 | model_group = parser.add_argument_group("Model Configuration")
209 | model_group.add_argument(
210 | "--model_path", type=str, default="radicalnumerics/RND1-Base-0910", help="Path to model or HuggingFace model ID"
211 | )
212 | model_group.add_argument("--checkpoint", type=str, default=None, help="Path to custom checkpoint file or directory")
213 | model_group.add_argument("--device", type=str, default="cuda:0", help="Device to run on (e.g., cuda:0, cpu)")
214 | model_group.add_argument("--fp32", action="store_true", help="Use FP32 precision instead of BF16")
215 |
216 | # Generation configuration
217 | gen_group = parser.add_argument_group("Generation Settings")
218 | gen_group.add_argument("--num_steps", type=int, default=256, help="Number of diffusion steps")
219 | gen_group.add_argument("--max_new_tokens", type=int, default=256, help="Maximum number of tokens to generate")
220 | gen_group.add_argument("--prompt", type=str, default=None, help="Custom prompt to use for generation")
221 | gen_group.add_argument(
222 | "--mode",
223 | type=str,
224 | default="task",
225 | choices=["task", "completion"],
226 | help="Generation mode: 'task' (Q&A format for instructions) or 'completion' (text continuation)",
227 | )
228 | gen_group.add_argument("--mask_token_id", type=int, default=151669, help="Token ID for mask token")
229 |
230 | # Sampling configuration
231 | sampling_group = parser.add_argument_group("Sampling Parameters")
232 | sampling_group.add_argument(
233 | "--temperature", type=float, default=0.01, help="Temperature for sampling (0.0 = greedy/deterministic)"
234 | )
235 | sampling_group.add_argument(
236 | "--top_k", type=int, default=None, help="Top-k filtering: keep only k most likely tokens"
237 | )
238 | sampling_group.add_argument(
239 | "--top_p",
240 | type=float,
241 | default=None,
242 | help="Top-p (nucleus) filtering: keep tokens with cumulative probability <= p",
243 | )
244 |
245 | # Visualization
246 | viz_group = parser.add_argument_group("Visualization")
247 | viz_group.add_argument(
248 | "--no_viz", action="store_true", help="Disable live visualization during generation (requires rich library)"
249 | )
250 |
251 | # Other settings
252 | other_group = parser.add_argument_group("Other Settings")
253 | other_group.add_argument("--seed", type=int, default=1234, help="Random seed for reproducibility")
254 |
255 | moe_backend_group = parser.add_argument_group("MoE Backend")
256 | moe_backend_group.add_argument(
257 | "--moe_backend",
258 | type=str,
259 | default="hf",
260 | choices=["hf", "vllm", "sglang", "flashinfer"],
261 | help="MoE backend to use for sparse mixture of experts layers",
262 | )
263 | add_eos_at_end_group = parser.add_argument_group("EOS Token")
264 | add_eos_at_end_group.add_argument(
265 | "--add_eos_at_end",
266 | action="store_true",
267 | help="Add End of Sequence (EOS) token at the end of the sequence. "
268 | "This can be useful to force the model to generate a complete sentence.",
269 | )
270 |
271 | args = parser.parse_args()
272 |
273 | if args.temperature < 0:
274 | parser.error("Temperature must be non-negative")
275 | if args.top_k is not None and args.top_k <= 0:
276 | parser.error("Top-k must be positive")
277 | if args.top_p is not None and (args.top_p <= 0 or args.top_p > 1):
278 | parser.error("Top-p must be between 0 and 1")
279 |
280 | print("\n" + "=" * 60)
281 | print("RND1 Diffusion Language Model Demo")
282 | print("=" * 60)
283 | print("Configuration:")
284 | print(f" Model: {args.model_path}")
285 | if args.checkpoint:
286 | print(f" Checkpoint: {args.checkpoint}")
287 | print(f" Device: {args.device}")
288 | print(f" Precision: {'FP32' if args.fp32 else 'BF16'}")
289 | print(
290 | f" Mode: {args.mode.upper()} ({'Q&A format for instructions' if args.mode == 'task' else 'Text continuation'})"
291 | )
292 | print(f" Random seed: {args.seed}")
293 | print(f" Diffusion steps: {args.num_steps}")
294 | print(f" Max new tokens: {args.max_new_tokens}")
295 | print(" Algorithm: Entropy-based selection")
296 | print(f" Temperature: {args.temperature}")
297 | if args.top_k:
298 | print(f" Top-k: {args.top_k}")
299 | if args.top_p:
300 | print(f" Top-p: {args.top_p}")
301 | print(f" MoE Backend: {args.moe_backend}")
302 | print(f" Visualization: {'Enabled' if not args.no_viz else 'Disabled'}")
303 | print("=" * 60 + "\n")
304 |
305 | demo_completion(
306 | model_path=args.model_path,
307 | checkpoint_path=args.checkpoint,
308 | device=args.device,
309 | use_bfloat16=not args.fp32,
310 | show_visualization=not args.no_viz,
311 | num_steps=args.num_steps,
312 | max_new_tokens=args.max_new_tokens,
313 | custom_prompt=args.prompt,
314 | temperature=args.temperature,
315 | top_k=args.top_k,
316 | top_p=args.top_p,
317 | mask_token_id=args.mask_token_id,
318 | seed=args.seed,
319 | moe_backend=args.moe_backend,
320 | mode=args.mode,
321 | add_eos_at_end=args.add_eos_at_end,
322 | )
323 |
324 |
325 | if __name__ == "__main__":
326 | main()
327 |
--------------------------------------------------------------------------------
/rnd/modeling_rnd.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Radical Numerics Inc.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0, found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | RND1 model implementation.
8 |
9 | This module implements the RND1 architecture with bidirectional attention for
10 | diffusion-based language modeling. Includes support for Mixture of Experts (MoE)
11 | with multiple backend options (HF, vLLM, SGLang, FlashInfer).
12 |
13 | Based on the Qwen3Moe architecture:
14 | https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
15 | """
16 |
17 | from __future__ import annotations
18 |
19 | import os
20 |
21 | import torch
22 | import torch.nn.functional as F
23 |
24 | from torch import nn
25 | from transformers.cache_utils import Cache
26 | from transformers.configuration_utils import PretrainedConfig
27 | from transformers.generation import GenerationConfig
28 | from transformers.modeling_outputs import MaskedLMOutput, MoeModelOutputWithPast
29 | from transformers.modeling_utils import PreTrainedModel
30 | from transformers.models.qwen3_moe.modeling_qwen3_moe import (
31 | Qwen3MoeMLP,
32 | Qwen3MoeRMSNorm,
33 | Qwen3MoeRotaryEmbedding,
34 | apply_rotary_pos_emb,
35 | )
36 | from transformers.utils import logging
37 |
38 | from .configuration_rnd import RND1Config
39 | from .generation_utils import RND1GenerationMixin
40 |
41 | vllm_import_error = None
42 | try:
43 | from vllm.model_executor.layers.fused_moe.fused_moe import (
44 | fused_experts as fused_experts_vllm,
45 | fused_topk as fused_topk_vllm,
46 | )
47 | from vllm.model_executor.layers.layernorm import RMSNorm as VLLMRMSNorm
48 | except ImportError as e:
49 |     fused_experts_vllm = fused_topk_vllm = None
50 |     VLLMRMSNorm = None  # keep the name defined even when vllm is absent
51 |     vllm_import_error = e
52 |
53 | sglang_import_error = None
54 | try:
55 | from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe
56 |
57 | # from sglang.srt.layers.layernorm import RMSNorm as SGLangRMSNorm # TODO: buggy atm
58 | from sglang.srt.layers.moe.topk import StandardTopKOutput
59 | except ImportError as e:
60 | sglang_fused_moe = None
61 | StandardTopKOutput = None
62 | sglang_import_error = e
63 |
64 | flashinfer_import_error = None
65 | try:
66 | import flashinfer.fused_moe as fused_moe
67 | ## TODO: below needs flashinfer>=0.4.0, but has some bug atm
68 | # from flashinfer.norm import rmsnorm as flashinfer_rmsnorm
69 | # class FlashInferRMSNorm(Qwen3MoeRMSNorm):
70 | # """Wrapper around FlashInfer RMSNorm to match Qwen3MoeRMSNorm interface"""
71 | # def forward(self, hidden_states):
72 | # return flashinfer_rmsnorm(hidden_states, self.weight, self.variance_epsilon)
73 | except ImportError as e:
74 | fused_moe = None
75 | flashinfer_import_error = e
76 | logger = logging.get_logger(__name__)
77 |
78 |
79 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
80 | """Expand key/value heads to match query heads for grouped-query attention."""
81 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
82 | if n_rep == 1:
83 | return hidden_states
84 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
85 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
86 |
87 |
88 | class RND1Attention(nn.Module):
89 | """RND1 attention layer with bidirectional attention for diffusion modeling."""
90 |
91 | def __init__(self, config: RND1Config, layer_idx: int):
92 | super().__init__()
93 | self.config = config
94 | self.layer_idx = layer_idx
95 |
96 | self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
97 | self.num_heads = config.num_attention_heads
98 | self.num_key_value_heads = config.num_key_value_heads
99 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
100 |
101 | self.scaling = self.head_dim**-0.5
102 | self.attention_dropout = config.attention_dropout
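        # Diffusion LM: attention is fully bidirectional, so no causal mask is applied anywhere below.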
103 | self.is_causal = False
104 |
105 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
106 | self.k_proj = nn.Linear(
107 | config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
108 | )
109 | self.v_proj = nn.Linear(
110 | config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
111 | )
112 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
113 |
114 | if config.moe_backend == "vllm":
115 | RMSNormClass = VLLMRMSNorm
116 | else:
117 | RMSNormClass = Qwen3MoeRMSNorm
118 | self.q_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
119 | self.k_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
120 |
121 | self.sliding_window = getattr(config, "sliding_window", None)
122 |
123 | self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
124 |
125 | def forward(
126 | self,
127 | hidden_states: torch.Tensor,
128 | attention_mask: torch.Tensor | None = None,
129 | position_ids: torch.LongTensor | None = None,
130 | past_key_values: Cache | tuple[torch.Tensor, torch.Tensor] | None = None,
131 | cache_position: torch.LongTensor | None = None,
132 | position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
133 | dual_cache: bool | None = False,
134 | replace_position: torch.Tensor | None = None,
135 | **kwargs,
136 |     ) -> tuple[torch.Tensor, torch.Tensor | None]:
137 | bsz, q_len, _ = hidden_states.size()
138 | input_shape = hidden_states.shape[:-1]
139 | hidden_shape = (*input_shape, -1, self.head_dim)
140 |
141 | query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
142 | key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
143 | value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
144 |
145 | cos, sin = position_embeddings
146 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
147 |
148 | key_states = repeat_kv(key_states, self.num_key_value_groups)
149 | value_states = repeat_kv(value_states, self.num_key_value_groups)
150 |
151 | use_sdpa = getattr(self.config, "_attn_implementation", "eager") == "sdpa"
152 |
153 | if use_sdpa:
154 | if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
155 | if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]:
156 | attention_mask = attention_mask.to(dtype=query_states.dtype)
157 |
158 | assert not self.is_causal, f"Attention layer {self.layer_idx} is causal"
159 | attn_out = torch.nn.functional.scaled_dot_product_attention(
160 | query_states,
161 | key_states,
162 | value_states,
163 | attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
164 | dropout_p=self.attention_dropout if self.training else 0.0,
165 | is_causal=self.is_causal,
166 | )
167 | attn_out = attn_out.transpose(1, 2).contiguous()
168 | attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim)
169 | attn_out = self.o_proj(attn_out)
170 | return attn_out, None
171 |
172 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
173 |
174 | if attention_mask is not None:
175 | attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
176 |
177 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
178 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
179 |
180 | attn_out = torch.matmul(attn_weights, value_states)
181 | attn_out = attn_out.transpose(1, 2).contiguous().view(hidden_states.size(0), hidden_states.size(1), -1)
182 | attn_out = self.o_proj(attn_out)
183 |
184 | return attn_out, None
185 |
186 |
187 | class RND1DecoderLayer(nn.Module):
188 | """RND1 decoder layer with bidirectional attention for diffusion language modeling."""
189 |
190 | def __init__(self, config: RND1Config, layer_idx: int):
191 | super().__init__()
192 | self.self_attn = RND1Attention(config, layer_idx)
193 | self.mlp = RND1SparseMoeBlock(config)
194 | if config.moe_backend == "vllm":
195 | RMSNormClass = VLLMRMSNorm
196 | else:
197 | RMSNormClass = Qwen3MoeRMSNorm
198 | self.input_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
199 | self.post_attention_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
200 |
201 | def forward(
202 | self,
203 | hidden_states: torch.Tensor,
204 | attention_mask: torch.Tensor | None = None,
205 | position_ids: torch.LongTensor | None = None,
206 | position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
207 | replace_position: torch.Tensor | None = None,
208 | **kwargs,
209 | ) -> tuple[torch.FloatTensor, torch.Tensor | None]:
210 | residual = hidden_states
211 | hidden_states = self.input_layernorm(hidden_states)
212 |
213 | attn_out, attn_weights = self.self_attn(
214 | hidden_states,
215 | attention_mask=attention_mask,
216 | position_ids=position_ids,
217 | position_embeddings=position_embeddings,
218 | replace_position=replace_position,
219 | )
220 | hidden_states = residual + attn_out
221 |
222 | residual = hidden_states
223 | hidden_states = self.post_attention_layernorm(hidden_states)
224 | ff_out = self.mlp(hidden_states)
225 | if isinstance(ff_out, tuple):
226 | ff_out = ff_out[0]
227 | hidden_states = residual + ff_out
228 |
229 | return hidden_states, attn_weights
230 |
231 |
232 | class RND1SparseMoeBlock(nn.Module):
233 | """RND1 Sparse MoE block with multiple backend support (HF, vLLM, SGLang, FlashInfer)."""
234 |
235 | def __init__(self, config: RND1Config):
236 | super().__init__()
237 | self.config = config
238 | self.backend = getattr(config, "moe_backend", "hf")
239 | self.num_experts = config.num_experts
240 | self.top_k = config.num_experts_per_tok
241 | self.norm_topk_prob = config.norm_topk_prob
242 | self.hidden_size = config.hidden_size
243 | self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size)
244 |
245 | self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
246 | self.experts = nn.ModuleList(
247 | [Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
248 | )
249 |
250 | # Cached weight tensors for optimized backends
251 | self._w1 = None
252 | self._w2 = None
253 |
254 | @torch.no_grad()
255 | def _initialize_weights(
256 | self,
257 | free_experts: bool = True,
258 | mode: str = "vllm",
259 | ) -> None:
260 | logger.info(f"Initializing weights for {mode} backend")
261 | # Stack directly on device where weights already reside (loaded by HF)
262 | gate_list: list[torch.Tensor] = []
263 | up_list: list[torch.Tensor] = []
264 | down_list: list[torch.Tensor] = []
265 |
266 | # Collect weight references without any device moves
267 | for expert in self.experts:
268 | gate_list.append(expert.gate_proj.weight.data)
269 | up_list.append(expert.up_proj.weight.data)
270 | down_list.append(expert.down_proj.weight.data)
271 |
272 | gate_w_stacked = torch.stack(gate_list, dim=0).contiguous()
273 | up_w_stacked = torch.stack(up_list, dim=0).contiguous()
274 | down_w_stacked = torch.stack(down_list, dim=0).contiguous()
275 |
276 | if mode == "flashinfer":
277 | w1 = torch.cat([up_w_stacked, gate_w_stacked], dim=1) # FlashInfer expects [up; gate] ordering
278 | else:
279 | w1 = torch.cat([gate_w_stacked, up_w_stacked], dim=1)
280 | w2 = down_w_stacked
281 | self._w1 = w1
282 | self._w2 = w2
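        # Resulting fused shapes: w1 is [num_experts, 2 * intermediate_size, hidden_size] (gate and up
        # projections concatenated; FlashInfer expects [up; gate] order) and w2 is
        # [num_experts, hidden_size, intermediate_size].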
283 |
284 | if free_experts:
285 | # Free per-expert modules to reclaim memory
286 | logger.info(f"Freeing experts for {mode} backend")
287 | del self.experts
288 | self.experts = None
289 |
290 | def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
291 | """Forward pass with expert routing and computation."""
292 | batch_size, sequence_length, hidden_dim = hidden_states.shape
293 | x = hidden_states.view(-1, hidden_dim)
294 |
295 | # Expert routing
296 | router_logits = self.gate(x)
297 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
298 |
299 | if self.backend == "vllm":
300 | routing_weights, selected_experts, _ = fused_topk_vllm(
301 | hidden_states=x,
302 | gating_output=router_logits,
303 | topk=self.top_k,
304 | renormalize=self.norm_topk_prob,
305 | )
306 | else:
307 | routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
308 | if self.norm_topk_prob:
309 | routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
310 |
311 | if self.backend == "hf":
312 | final_hidden_states = torch.zeros(
313 | (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
314 | )
315 |
316 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
317 | expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
318 |
319 | for expert_idx in expert_hit:
320 | expert_layer = self.experts[expert_idx]
321 | idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
322 | current_state = x[top_x]
323 | current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
324 | final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
325 | out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
326 | return out, router_logits.view(batch_size, sequence_length, -1)
327 |
328 | elif self.backend == "flashinfer":
329 | # if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None:
330 | # self._initialize_flashinfer_weights()
331 | if self._w1 is None or self._w2 is None:
332 | self._initialize_weights(mode="flashinfer")
333 |
334 | result = fused_moe.cutlass_fused_moe(
335 | input=x,
336 | token_selected_experts=selected_experts.to(torch.int),
337 | token_final_scales=routing_weights.to(torch.float32),
338 | fc1_expert_weights=self._w1,
339 | fc2_expert_weights=self._w2,
340 | output_dtype=x.dtype,
341 | quant_scales=None,
342 | )
343 | if isinstance(result, (list, tuple)):
344 | out_flat = result[0]
345 | else:
346 | out_flat = result
347 | out = out_flat.view(batch_size, sequence_length, hidden_dim)
348 | return out, router_logits.view(batch_size, sequence_length, -1)
349 |
350 | elif self.backend == "sglang":
351 | if self._w1 is None or self._w2 is None:
352 | self._initialize_weights(mode="sglang")
353 |
354 | topk_output = StandardTopKOutput(
355 | topk_weights=routing_weights,
356 | topk_ids=selected_experts,
357 | router_logits=router_logits,
358 | )
359 |
360 | out_flat = sglang_fused_moe(
361 | hidden_states=x,
362 | w1=self._w1,
363 | w2=self._w2,
364 | topk_output=topk_output,
365 | )
366 | out = out_flat.view(batch_size, sequence_length, hidden_dim)
367 | return out, router_logits.view(batch_size, sequence_length, -1)
368 |
369 | elif self.backend == "vllm":
370 | if self._w1 is None or self._w2 is None:
371 | self._initialize_weights()
372 |
373 | out_flat = fused_experts_vllm(
374 | x,
375 | self._w1,
376 | self._w2,
377 | routing_weights,
378 | selected_experts,
379 | )
380 | out = out_flat.view(batch_size, sequence_length, hidden_dim)
381 | return out, router_logits.view(batch_size, sequence_length, -1)
382 |
383 | else:
384 | raise ValueError(f"Invalid backend: {self.backend}")
385 |
386 |
387 | class RND1PreTrainedModel(PreTrainedModel):
388 | """Base class for RND1 models with weight initialization and loading support."""
389 |
390 | config_class = RND1Config
391 | base_model_prefix = "model"
392 | supports_gradient_checkpointing = True
393 | _no_split_modules = ["RND1DecoderLayer"]
394 | _skip_keys_device_placement = "past_key_values"
395 | _supports_flash_attn_2 = True
396 | _supports_sdpa = True
397 | _supports_cache_class = True
398 | _supports_quantized_cache = True
399 | _supports_static_cache = True
400 |
401 | def _init_weights(self, module):
402 | """Initialize weights using normal distribution."""
403 | std = self.config.initializer_range
404 | if isinstance(module, nn.Linear):
405 | module.weight.data.normal_(mean=0.0, std=std)
406 | if module.bias is not None:
407 | module.bias.data.zero_()
408 | elif isinstance(module, nn.Embedding):
409 | module.weight.data.normal_(mean=0.0, std=std)
410 | if module.padding_idx is not None:
411 | module.weight.data[module.padding_idx].zero_()
412 |
413 | @classmethod
414 | def from_pretrained(
415 | cls,
416 | pretrained_model_name_or_path: str | os.PathLike | None,
417 | *model_args,
418 | config: PretrainedConfig | str | os.PathLike | None = None,
419 | cache_dir: str | os.PathLike | None = None,
420 | ignore_mismatched_sizes: bool = False,
421 | force_download: bool = False,
422 | local_files_only: bool = False,
423 | token: str | bool | None = None,
424 | revision: str = "main",
425 | use_safetensors: bool | None = None,
426 | weights_only: bool = True,
427 | **kwargs,
428 | ):
429 | """Load pretrained model with generation config."""
430 |
431 | # Catch backend errors early
432 | backend = getattr(config, "moe_backend", "hf")
433 | if backend == "sglang" and sglang_import_error is not None:
434 | raise RuntimeError(f"sglang is not available. Import error: {sglang_import_error}")
435 | elif backend == "flashinfer" and flashinfer_import_error is not None:
436 | raise RuntimeError(f"flashinfer is not available. Import error: {flashinfer_import_error}")
437 | elif backend == "vllm" and vllm_import_error is not None:
438 | raise RuntimeError(f"vllm is not available. Import error: {vllm_import_error}")
439 |
440 | _model = super().from_pretrained(
441 | pretrained_model_name_or_path,
442 | *model_args,
443 | config=config,
444 | cache_dir=cache_dir,
445 | ignore_mismatched_sizes=ignore_mismatched_sizes,
446 | force_download=force_download,
447 | local_files_only=local_files_only,
448 | token=token,
449 | revision=revision,
450 | use_safetensors=use_safetensors,
451 | weights_only=weights_only,
452 | **kwargs,
453 | )
454 |
455 | resume_download = kwargs.get("resume_download", None)
456 | proxies = kwargs.get("proxies", None)
457 | subfolder = kwargs.get("subfolder", "")
458 | from_auto_class = kwargs.get("_from_auto", False)
459 | from_pipeline = kwargs.get("_from_pipeline", None)
460 |
461 | _model.generation_config = GenerationConfig.from_pretrained(
462 | pretrained_model_name_or_path,
463 | cache_dir=cache_dir,
464 | force_download=force_download,
465 | resume_download=resume_download,
466 | proxies=proxies,
467 | local_files_only=local_files_only,
468 | token=token,
469 | revision=revision,
470 | subfolder=subfolder,
471 | _from_auto=from_auto_class,
472 | _from_pipeline=from_pipeline,
473 | )
474 |
475 | # If configured to use a fused backend, pack fused tensors once after load
476 | try:
477 | if backend in ("sglang", "vllm"):
478 | # Walk decoder layers and initialize fused weights
479 | model_core = getattr(_model, "model", _model)
480 | layers = getattr(model_core, "layers", None)
481 | if isinstance(layers, nn.ModuleList):
482 | for layer in layers:
483 | mlp = getattr(layer, "mlp", None)
484 | if hasattr(mlp, "_initialize_weights"):
485 | mlp._initialize_weights(
486 | free_experts=True,
487 | mode=backend,
488 | )
489 | except Exception as _e:
490 | logger.warning(f"Backend {backend} weight processing skipped: {_e}")
491 |
492 | return _model
493 |
494 |
495 | class RND1Model(RND1PreTrainedModel):
496 | """RND1 transformer model with bidirectional attention for diffusion language modeling."""
497 |
498 | def __init__(self, config: RND1Config):
499 | super().__init__(config)
500 |
501 | self.padding_idx = config.pad_token_id
502 | self.vocab_size = config.vocab_size
503 |
504 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
505 | self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
506 | if config.moe_backend == "vllm":
507 | RMSNormClass = VLLMRMSNorm
508 | else:
509 | RMSNormClass = Qwen3MoeRMSNorm
510 | self.norm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
511 |
512 | self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
513 |
514 | self.post_init()
515 |
516 | def forward(
517 | self,
518 | input_ids: torch.LongTensor | None = None,
519 | attention_mask: torch.Tensor | None = None,
520 | position_ids: torch.LongTensor | None = None,
521 | inputs_embeds: torch.FloatTensor | None = None,
522 | **kwargs,
523 | ) -> MoeModelOutputWithPast:
524 | """Forward pass through the RND1 model."""
525 |
526 | if (input_ids is None) == (inputs_embeds is None):
527 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
528 |
529 | if inputs_embeds is None:
530 | inputs_embeds = self.embed_tokens(input_ids)
531 |
532 | if position_ids is None:
533 | position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
534 |
535 | position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
536 |
537 | hidden_states = inputs_embeds
538 |
539 | for layer in self.layers:
540 | hidden_states, _ = layer(
541 | hidden_states,
542 | attention_mask=attention_mask,
543 | position_ids=position_ids,
544 | position_embeddings=position_embeddings,
545 | )
546 |
547 | hidden_states = self.norm(hidden_states)
548 |
549 | return MoeModelOutputWithPast(
550 | last_hidden_state=hidden_states,
551 | router_logits=None,
552 | )
553 |
554 |
555 | class RND1LM(RND1PreTrainedModel, RND1GenerationMixin):
556 | """Radical Numerics Diffusion Language Model with bidirectional attention."""
557 |
558 | def __init__(self, config: RND1Config):
559 | super().__init__(config)
560 | self.model = RND1Model(config)
561 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
562 | self.post_init()
563 |
564 | def get_input_embeddings(self):
565 | """Get the input embeddings layer."""
566 | return self.model.embed_tokens
567 |
568 | def set_input_embeddings(self, value):
569 | """Set the input embeddings layer."""
570 | self.model.embed_tokens = value
571 |
572 | def get_output_embeddings(self):
573 | """Get the output embeddings layer (lm_head)."""
574 | return self.lm_head
575 |
576 | def set_output_embeddings(self, new_embeddings):
577 | """Set the output embeddings layer (lm_head)."""
578 | self.lm_head = new_embeddings
579 |
580 | @classmethod
581 | def can_generate(cls) -> bool:
582 | """Indicates this model can generate text."""
583 | return True
584 |
585 | def forward(
586 | self,
587 | input_ids: torch.LongTensor | None = None,
588 | attention_mask: torch.Tensor | None = None,
589 | position_ids: torch.LongTensor | None = None,
590 | inputs_embeds: torch.FloatTensor | None = None,
591 | labels: torch.LongTensor | None = None,
592 | **kwargs,
593 | ) -> MaskedLMOutput:
594 | """Forward pass with optional loss computation."""
595 | outputs = self.model(
596 | input_ids=input_ids,
597 | attention_mask=attention_mask,
598 | position_ids=position_ids,
599 | inputs_embeds=inputs_embeds,
600 | **kwargs,
601 | )
602 | logits = self.lm_head(outputs.last_hidden_state)
603 |
604 | loss = None
605 | if labels is not None:
606 | loss_fct = nn.CrossEntropyLoss()
607 | loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
608 |
609 | return MaskedLMOutput(
610 | loss=loss,
611 | logits=logits,
612 | )
613 |
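# Typical loading path (mirrors demo_rnd_generation.py; shown here only as a sketch):
#   cfg = RND1Config.from_pretrained(path)
#   cfg.moe_backend = "hf"  # or "vllm" / "sglang" / "flashinfer" if the corresponding package is installed
#   model = RND1LM.from_pretrained(path, config=cfg, dtype=torch.bfloat16).eval()
#   output = model.generate(inputs=input_ids, generation_config=gen_config)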
--------------------------------------------------------------------------------