├── 01_quantization ├── .gitignore ├── README.md ├── export.py ├── model.py └── train.py ├── 02_inference ├── .gitignore ├── CMakeLists.txt ├── README.md ├── cmake │ └── fetch │ │ ├── fmt.cmake │ │ ├── nlohmann.cmake │ │ ├── spdlog.cmake │ │ └── stb.cmake ├── include │ ├── CMakeLists.txt │ ├── model.hpp │ ├── operator.hpp │ ├── operator_factory.hpp │ ├── operators │ │ ├── conv2d.hpp │ │ ├── dequant_stub.hpp │ │ ├── linear.hpp │ │ ├── maxpool2d.hpp │ │ ├── padding.hpp │ │ ├── quant_stub.hpp │ │ └── relu.hpp │ └── tensor.hpp └── tutorials │ ├── CMakeLists.txt │ └── demo.cc ├── 03_hardware ├── .gitignore ├── README.md ├── axi_dma_ctrl │ ├── README.md │ ├── axi_dma_read_ctrl.sv │ ├── axi_dma_write_ctrl.sv │ ├── tb_axi_dma_read_ctrl.sv │ └── tb_axi_dma_write_ctrl.sv ├── axi_ir │ └── README.md └── systolic_array │ ├── README.md │ ├── systolic_array.sv │ └── systolic_array_tb.sv ├── 04_software ├── .gitignore ├── README.md ├── driver │ ├── Makefile │ ├── include │ │ ├── accel.h │ │ ├── accel_config.h │ │ └── accel_types.h │ ├── src │ │ ├── accel.c │ │ └── accel_config.c │ └── test │ │ ├── accel_test.h │ │ ├── test_accel.c │ │ └── test_config.c ├── hal │ ├── Makefile │ ├── include │ │ ├── hal.h │ │ ├── hal_base.h │ │ ├── hal_config.h │ │ ├── hal_io.h │ │ └── hal_mem.h │ ├── src │ │ ├── hal_base.c │ │ ├── hal_config.c │ │ ├── hal_io.c │ │ └── hal_mem.c │ └── test │ │ ├── hal_test.h │ │ ├── test_hal_init.c │ │ ├── test_hal_io.c │ │ └── test_hal_mem.c └── runtime │ ├── CMakeLists.txt │ ├── cmake │ └── accel_driver.cmake │ ├── include │ ├── CMakeLists.txt │ ├── accel.hpp │ └── accel │ │ ├── buffer.hpp │ │ ├── runtime.hpp │ │ └── types.hpp │ └── tutorials │ ├── CMakeLists.txt │ └── demo.cc ├── LICENSE ├── README-CN.md ├── README.md └── imgs ├── dma_read_block_design.png ├── dma_read_simulation.png ├── dma_write_block_design.png ├── dma_write_simulation.png ├── logo.svg └── overview.png /01_quantization/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # IDE and editor files 174 | .idea/ 175 | .vscode/ 176 | *.swp 177 | *.swo -------------------------------------------------------------------------------- /01_quantization/README.md: -------------------------------------------------------------------------------- 1 | # A Quantization Aware Training Example (LeNet) 2 | This project utilizes PyTorch to perform quantization-aware training on a LeNet model. 3 | 4 | ## How to Use 5 | ```bash 6 | python train.py 7 | python export.py 8 | ``` 9 | 10 | ## Project Structure 11 | - `model.py` - Model definition 12 | - `train.py` - Training script 13 | - `export.py` - Export script 14 | 15 | ## References 16 | - [PyTorch Quantization](https://pytorch.org/docs/1.4.0/quantization.html) 17 | -------------------------------------------------------------------------------- /01_quantization/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from typing import Dict, List, Optional, Any, Union 4 | 5 | import torch 6 | from model import LeNet 7 | 8 | 9 | class ModelExporter: 10 | """A class to handle the export of PyTorch models to JSON format. 11 | 12 | This class provides methods to extract model architecture information and weights, 13 | handling both regular and quantized models. 14 | 15 | Attributes: 16 | model: The PyTorch model to be exported. 17 | state_dict: The state dictionary containing model parameters. 18 | """ 19 | 20 | def __init__(self, model: torch.nn.Module, state_dict: Dict[str, torch.Tensor]): 21 | """Initializes the ModelExporter with a model and its state dictionary. 22 | 23 | Args: 24 | model: PyTorch model to be exported. 25 | state_dict: State dictionary containing model parameters. 26 | """ 27 | self.model = model 28 | self.state_dict = state_dict 29 | 30 | def _process_layer(self, name: str, module: torch.nn.Module) -> Optional[Dict[str, Any]]: 31 | """Processes a single layer and returns its information. 32 | 33 | Args: 34 | name: Name of the layer. 35 | module: The layer module to process. 36 | 37 | Returns: 38 | A dictionary containing layer information or None if layer type is unknown. 
39 | """ 40 | if isinstance(module, torch.quantization.stubs.QuantStub): 41 | print(f"Found QuantStub: {name}") 42 | return { 43 | "name": name, 44 | "type": "QuantStub" 45 | } 46 | 47 | if isinstance(module, torch.nn.Conv2d): 48 | print(f"Found Conv2d: {name}") 49 | return { 50 | "name": name, 51 | "type": "Conv2d", 52 | "in_channels": module.in_channels, 53 | "out_channels": module.out_channels, 54 | "kernel_size": module.kernel_size[0], # Assuming square kernel 55 | "stride": module.stride[0], # Assuming square stride 56 | "padding": module.padding[0], # Assuming square padding 57 | } 58 | 59 | if isinstance(module, torch.nn.Linear): 60 | print(f"Found Linear: {name}") 61 | return { 62 | "name": name, 63 | "type": "Linear", 64 | "in_features": module.in_features, 65 | "out_features": module.out_features, 66 | } 67 | 68 | if isinstance(module, torch.nn.ReLU): 69 | print(f"Found ReLU: {name}") 70 | return { 71 | "name": name, 72 | "type": "ReLU", 73 | "inplace": module.inplace 74 | } 75 | 76 | if isinstance(module, torch.nn.MaxPool2d): 77 | print(f"Found MaxPool2d: {name}") 78 | return { 79 | "name": name, 80 | "type": "MaxPool2d", 81 | "kernel_size": self._get_scalar_value(module.kernel_size), 82 | "stride": self._get_scalar_value(module.stride), 83 | "padding": self._get_scalar_value(module.padding), 84 | } 85 | 86 | if isinstance(module, torch.quantization.stubs.DeQuantStub): 87 | print(f"Found DeQuantStub: {name}") 88 | return { 89 | "name": name, 90 | "type": "DeQuantStub" 91 | } 92 | 93 | print(f"Skipping unknown layer type: {type(module)}") 94 | return None 95 | 96 | @staticmethod 97 | def _get_scalar_value(param: Union[int, tuple]) -> int: 98 | """Extracts scalar value from potentially tuple parameters. 99 | 100 | Args: 101 | param: Parameter that could be either int or tuple. 102 | 103 | Returns: 104 | The scalar value from the parameter. 105 | """ 106 | return param if isinstance(param, int) else param[0] 107 | 108 | def _get_model_info(self) -> List[Dict[str, Any]]: 109 | """Extracts model structure information. 110 | 111 | Returns: 112 | List of dictionaries containing layer information. 113 | """ 114 | model_info = [] 115 | 116 | print("\nExtracting model structure:") 117 | # First, process QuantStub 118 | for name, module in self.model.named_modules(): 119 | if isinstance(module, torch.quantization.stubs.QuantStub): 120 | layer_info = self._process_layer(name, module) 121 | if layer_info: 122 | model_info.append(layer_info) 123 | break 124 | 125 | # Process other layers 126 | for name, module in self.model.named_children(): 127 | if isinstance(module, torch.nn.Sequential): 128 | print(f"Processing Sequential block: {name}") 129 | for sub_name, sub_module in module.named_children(): 130 | if not isinstance(sub_module, torch.quantization.stubs.QuantStub): 131 | layer_info = self._process_layer(f"{name}.{sub_name}", sub_module) 132 | if layer_info: 133 | model_info.append(layer_info) 134 | elif not isinstance(module, torch.quantization.stubs.QuantStub): 135 | layer_info = self._process_layer(name, module) 136 | if layer_info: 137 | model_info.append(layer_info) 138 | 139 | return model_info 140 | 141 | def _process_tensor(self, tensor: torch.Tensor) -> Dict[str, Any]: 142 | """Processes a tensor and returns its information including quantization details. 143 | 144 | Args: 145 | tensor: The tensor to process. 146 | 147 | Returns: 148 | Dictionary containing tensor information and values. 
149 | """ 150 | tensor_info = { 151 | "shape": list(tensor.shape), 152 | "dtype": str(tensor.dtype) 153 | } 154 | 155 | if tensor.dtype in [torch.qint8, torch.quint8]: 156 | if tensor.qscheme() == torch.per_tensor_affine: 157 | values = tensor.int_repr().detach().numpy() 158 | tensor_info.update({ 159 | "quantization": "per_tensor", 160 | "scale": float(tensor.q_scale()), 161 | "values": values.tolist() 162 | }) 163 | elif tensor.qscheme() == torch.per_channel_affine: 164 | values = tensor.int_repr().detach().numpy() 165 | tensor_info.update({ 166 | "quantization": "per_channel", 167 | "scales": tensor.q_per_channel_scales().detach().numpy().tolist(), 168 | "axis": int(tensor.q_per_channel_axis()), 169 | "values": values.tolist() 170 | }) 171 | else: 172 | values = tensor.detach().numpy() 173 | tensor_info.update({ 174 | "quantization": "none", 175 | "values": values.tolist() 176 | }) 177 | 178 | return tensor_info 179 | 180 | def export_to_json(self, output_path: str) -> None: 181 | """Exports model structure and weights to a JSON file. 182 | 183 | Args: 184 | output_path: Path where the JSON file will be saved. 185 | """ 186 | model_info = self._get_model_info() 187 | 188 | # Process weights and biases for each layer 189 | last_scale = None 190 | for layer in model_info: 191 | name = layer["name"] 192 | self._process_layer_parameters(layer, name) 193 | 194 | # Store the scale from the last regular layer 195 | if layer["type"] not in ["QuantStub", "DeQuantStub"] and "scale" in layer: 196 | last_scale = layer["scale"] 197 | 198 | # Use the last layer's scale as dequant scale 199 | if layer["type"] == "DeQuantStub" and last_scale is not None: 200 | layer["scale"] = last_scale 201 | 202 | # Create final data structure and save 203 | data = {"layers": model_info} 204 | with open(output_path, 'w') as f: 205 | json.dump(data, f, indent=2) 206 | 207 | def _process_layer_parameters(self, layer: Dict[str, Any], name: str) -> None: 208 | """Processes and adds parameter information to a layer. 209 | 210 | Args: 211 | layer: Dictionary containing layer information. 212 | name: Name of the layer. 
213 | """ 214 | # Handle weights 215 | weight_key = f"{name}.weight" 216 | packed_weight_key = f"{name}._packed_params.weight" 217 | 218 | if weight_key in self.state_dict: 219 | layer["weight"] = self._process_tensor(self.state_dict[weight_key]) 220 | elif packed_weight_key in self.state_dict: 221 | print(f"Found packed weight for {name}") 222 | layer["weight"] = self._process_tensor(self.state_dict[packed_weight_key]) 223 | 224 | # Handle biases 225 | bias_key = f"{name}.bias" 226 | packed_bias_key = f"{name}._packed_params.bias" 227 | 228 | if bias_key in self.state_dict: 229 | layer["bias"] = self._process_tensor(self.state_dict[bias_key]) 230 | elif packed_bias_key in self.state_dict: 231 | print(f"Found packed bias for {name}") 232 | layer["bias"] = self._process_tensor(self.state_dict[packed_bias_key]) 233 | 234 | # Handle QuantStub and DeQuantStub parameters 235 | scale_key = f"{name}.scale" 236 | 237 | if scale_key in self.state_dict: 238 | layer["scale"] = float(self.state_dict[scale_key]) 239 | 240 | def main(): 241 | """Main function to handle command line arguments and model export.""" 242 | parser = argparse.ArgumentParser(description='Export model structure and weights') 243 | parser.add_argument('--ckpt_path', 244 | default='./ckpt/quantized_model.pth', 245 | help='path to checkpoint file') 246 | parser.add_argument('--output_path', 247 | default='./LeNet.json', 248 | help='path to output JSON file') 249 | args = parser.parse_args() 250 | 251 | # Create raw model and load checkpoint 252 | model = LeNet() 253 | print(f"Loading checkpoint from {args.ckpt_path}") 254 | checkpoint = torch.load(args.ckpt_path, map_location='cpu') 255 | 256 | # Export model 257 | exporter = ModelExporter(model, checkpoint['model']) 258 | exporter.export_to_json(args.output_path) 259 | print(f"Model exported to {args.output_path}") 260 | 261 | 262 | if __name__ == '__main__': 263 | main() -------------------------------------------------------------------------------- /01_quantization/model.py: -------------------------------------------------------------------------------- 1 | """LeNet model implementation with quantization support.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class LeNet(nn.Module): 8 | """LeNet model for MNIST digit classification with optional quantization.""" 9 | 10 | def __init__(self, use_quantization: bool = False): 11 | """Initializes the LeNet model. 12 | 13 | Args: 14 | use_quantization: If True, enables quantization-aware training. 15 | """ 16 | super().__init__() 17 | self.use_quantization = use_quantization 18 | 19 | self.conv1 = nn.Sequential( 20 | nn.Conv2d(1, 6, 5, padding=2), 21 | nn.ReLU(), 22 | nn.MaxPool2d(2) 23 | ) 24 | 25 | self.conv2 = nn.Sequential( 26 | nn.Conv2d(6, 16, 5), 27 | nn.ReLU(), 28 | nn.MaxPool2d(2) 29 | ) 30 | 31 | self.fc1 = nn.Sequential( 32 | nn.Linear(16 * 5 * 5, 120), 33 | nn.ReLU() 34 | ) 35 | 36 | self.fc2 = nn.Sequential( 37 | nn.Linear(120, 84), 38 | nn.ReLU() 39 | ) 40 | 41 | self.fc3 = nn.Linear(84, 10) 42 | 43 | # Quantization layers 44 | self.quant = torch.quantization.QuantStub() 45 | self.dequant = torch.quantization.DeQuantStub() 46 | 47 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 48 | """Performs forward pass through the network. 
49 | 50 | Args: 51 | input_tensor: Input image tensor of shape [batch_size, 1, 28, 28] 52 | 53 | Returns: 54 | Output logits of shape [batch_size, 10] 55 | """ 56 | if self.use_quantization: 57 | features = self.quant(input_tensor) 58 | else: 59 | features = input_tensor 60 | 61 | # First conv block 62 | conv1_out = self.conv1(features) 63 | 64 | # Second conv block 65 | conv2_out = self.conv2(conv1_out) 66 | 67 | # Flatten for FC layers 68 | batch_size = conv2_out.size(0) 69 | flattened = conv2_out.view(batch_size, -1) 70 | 71 | # Fully connected blocks 72 | fc1_out = self.fc1(flattened) 73 | fc2_out = self.fc2(fc1_out) 74 | logits = self.fc3(fc2_out) 75 | 76 | if self.use_quantization: 77 | logits = self.dequant(logits) 78 | return logits 79 | 80 | def fuse_model(self): 81 | """Fuses Conv2d+ReLU and Linear+ReLU layers for quantization.""" 82 | # Fuse conv blocks 83 | torch.quantization.fuse_modules( 84 | self.conv1, ['0', '1'], inplace=True) 85 | torch.quantization.fuse_modules( 86 | self.conv2, ['0', '1'], inplace=True) 87 | 88 | # Fuse fc blocks 89 | torch.quantization.fuse_modules( 90 | self.fc1, ['0', '1'], inplace=True) 91 | torch.quantization.fuse_modules( 92 | self.fc2, ['0', '1'], inplace=True) -------------------------------------------------------------------------------- /01_quantization/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torchvision as tv 9 | import torchvision.transforms as transforms 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from model import LeNet 14 | 15 | 16 | def train_epoch( 17 | model: nn.Module, 18 | loader: DataLoader, 19 | criterion: nn.Module, 20 | optimizer: optim.Optimizer, 21 | device: torch.device, 22 | desc: str = "Training" 23 | ) -> float: 24 | """Trains model for one epoch. 25 | 26 | Args: 27 | model: Neural network model. 28 | loader: DataLoader for training data. 29 | criterion: Loss function. 30 | optimizer: Optimization algorithm. 31 | device: Device to run training on. 32 | desc: Description for progress bar. 33 | 34 | Returns: 35 | Average loss for the epoch. 36 | """ 37 | model.train() 38 | total_loss = 0.0 39 | 40 | pbar = tqdm(loader, desc=desc, leave=False) 41 | for i, (inputs, labels) in enumerate(pbar): 42 | inputs, labels = inputs.to(device), labels.to(device) 43 | 44 | optimizer.zero_grad() 45 | outputs = model(inputs) 46 | loss = criterion(outputs, labels) 47 | loss.backward() 48 | optimizer.step() 49 | 50 | total_loss += loss.item() 51 | pbar.set_postfix({'Loss': total_loss / (i + 1)}) 52 | 53 | return total_loss / len(loader) 54 | 55 | 56 | def evaluate( 57 | model: nn.Module, 58 | loader: DataLoader, 59 | device: torch.device 60 | ) -> float: 61 | """Evaluates model accuracy. 62 | 63 | Args: 64 | model: Neural network model. 65 | loader: DataLoader for test data. 66 | device: Device to run evaluation on. 67 | 68 | Returns: 69 | Accuracy as a percentage. 
70 | """ 71 | model.eval() 72 | correct = total = 0 73 | 74 | with torch.no_grad(): 75 | pbar = tqdm(loader, desc="Evaluating", leave=False) 76 | for inputs, labels in pbar: 77 | inputs, labels = inputs.to(device), labels.to(device) 78 | outputs = model(inputs) 79 | _, predicted = outputs.max(1) 80 | 81 | total += labels.size(0) 82 | correct += (predicted == labels).sum().item() 83 | 84 | accuracy = 100.0 * correct / total 85 | pbar.set_postfix({'Accuracy': f"{accuracy:.2f}%"}) 86 | 87 | return accuracy 88 | 89 | 90 | def train_model( 91 | model: nn.Module, 92 | train_loader: DataLoader, 93 | test_loader: DataLoader, 94 | optimizer: optim.Optimizer, 95 | criterion: nn.Module, 96 | device: torch.device, 97 | num_epochs: int, 98 | desc_prefix: str = "Epoch" 99 | ) -> nn.Module: 100 | """Trains the model. 101 | 102 | Args: 103 | model: Neural network model. 104 | train_loader: DataLoader for training data. 105 | test_loader: DataLoader for test data. 106 | optimizer: Optimization algorithm. 107 | criterion: Loss function. 108 | device: Device to run on. 109 | num_epochs: Number of epochs to train. 110 | desc_prefix: Prefix for progress bar description. 111 | 112 | Returns: 113 | Trained model. 114 | """ 115 | for epoch in range(num_epochs): 116 | train_epoch( 117 | model, train_loader, criterion, optimizer, device, 118 | desc=f"{desc_prefix} {epoch+1}/{num_epochs}") 119 | evaluate(model, test_loader, device) 120 | return model 121 | 122 | 123 | def qat( 124 | model: nn.Module, 125 | train_loader: DataLoader, 126 | test_loader: DataLoader, 127 | optimizer: optim.Optimizer, 128 | criterion: nn.Module, 129 | device: torch.device, 130 | num_epochs: int 131 | ) -> nn.Module: 132 | """Performs quantization-aware training and converts to quantized model. 133 | 134 | Args: 135 | model: Model to quantize. 136 | train_loader: DataLoader for training data. 137 | test_loader: DataLoader for test data. 138 | optimizer: Optimization algorithm. 139 | criterion: Loss function. 140 | device: Device to run on. 141 | num_epochs: Number of QAT epochs. 142 | 143 | Returns: 144 | Quantized model. 145 | """ 146 | model.use_quantization = True 147 | model.fuse_model() 148 | model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') 149 | model = torch.quantization.prepare_qat(model, inplace=True) 150 | 151 | model = train_model(model, train_loader, test_loader, optimizer, criterion, 152 | device, num_epochs, desc_prefix="QAT Epoch") 153 | 154 | model = model.cpu().eval() 155 | return torch.quantization.convert(model, inplace=True) 156 | 157 | 158 | def get_dataloaders( 159 | batch_size: int 160 | ) -> Tuple[DataLoader, DataLoader]: 161 | """Creates MNIST training and test dataloaders. 162 | 163 | Args: 164 | batch_size: Number of samples per batch. 165 | 166 | Returns: 167 | Tuple of (train_loader, test_loader). 
168 | """ 169 | transform = transforms.ToTensor() 170 | train_set = tv.datasets.MNIST( 171 | root='./data/', train=True, download=True, transform=transform) 172 | test_set = tv.datasets.MNIST( 173 | root='./data/', train=False, download=True, transform=transform) 174 | 175 | train_loader = DataLoader( 176 | train_set, batch_size=batch_size, shuffle=True) 177 | test_loader = DataLoader( 178 | test_set, batch_size=batch_size, shuffle=False) 179 | 180 | return train_loader, test_loader 181 | 182 | 183 | def main(): 184 | """Main training function.""" 185 | parser = argparse.ArgumentParser(description='Train LeNet on MNIST') 186 | parser.add_argument('--ckpt_dir', default='./ckpt', 187 | help='checkpoint directory') 188 | parser.add_argument('--epochs', type=int, default=8, 189 | help='number of epochs') 190 | parser.add_argument('--batch_size', type=int, default=128, 191 | help='batch size') 192 | parser.add_argument('--lr', type=float, default=0.001, 193 | help='learning rate') 194 | parser.add_argument('--qat_epochs', type=int, default=2, 195 | help='quantization-aware training epochs') 196 | args = parser.parse_args() 197 | 198 | # Create checkpoint directory if it doesn't exist 199 | os.makedirs(args.ckpt_dir, exist_ok=True) 200 | 201 | # Setup 202 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 203 | train_loader, test_loader = get_dataloaders(args.batch_size) 204 | model = LeNet().to(device) 205 | criterion = nn.CrossEntropyLoss() 206 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 207 | 208 | # Regular training 209 | print("Starting regular training...") 210 | model = train_model( 211 | model, train_loader, test_loader, optimizer, criterion, 212 | device, args.epochs 213 | ) 214 | 215 | # Quantization-aware training 216 | print("\nStarting quantization-aware training...") 217 | quantized_model = qat( 218 | model, train_loader, test_loader, optimizer, criterion, 219 | device, args.qat_epochs 220 | ) 221 | 222 | # Save model 223 | torch.save( 224 | {'model': quantized_model.state_dict()}, 225 | f"{args.ckpt_dir}/quantized_model.pth" 226 | ) 227 | 228 | 229 | if __name__ == "__main__": 230 | main() -------------------------------------------------------------------------------- /02_inference/.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | 34 | # Ignore all files in the build directory 35 | build/ 36 | bin/ 37 | 38 | # IDE 39 | .vscode/ -------------------------------------------------------------------------------- /02_inference/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | project(qnn 4 | VERSION 1.0 5 | LANGUAGES CXX 6 | ) 7 | 8 | set(CMAKE_CXX_STANDARD 17) 9 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 10 | set(CMAKE_CXX_EXTENSIONS ON) 11 | 12 | if(CMAKE_BUILD_TYPE MATCHES "Debug") 13 | message(STATUS "Building in debug mode") 14 | add_compile_definitions(BUILD_DEBUG) 15 | elseif(CMAKE_BUILD_TYPE MATCHES "Release") 16 | add_compile_definitions(BUILD_RELEASE) 17 | endif() 18 | 19 | include(FetchContent) 20 
| 21 | list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) 22 | 23 | include(fetch/fmt) 24 | include(fetch/nlohmann) 25 | include(fetch/spdlog) 26 | include(fetch/stb) 27 | 28 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/bin) 29 | 30 | add_subdirectory(include) 31 | add_subdirectory(tutorials) -------------------------------------------------------------------------------- /02_inference/README.md: -------------------------------------------------------------------------------- 1 | # DNN Inference (INT8) in C++ 2 | 3 | This project utilizes C++ to implement INT8 Deep Neural Network (DNN) inference, aimed at deepening understanding of subsequent Verilog implementations of DNN accelerators. 4 | 5 | To preserve clarity and simplicity in the code, advanced optimizations (e.g., SIMD, OpenMP) have not been employed. 6 | 7 | ## Prerequisites 8 | 9 | - C++ compiler supporting C++17 10 | - CMake version 3.14 or higher 11 | 12 | ## How to Use 13 | ### Build 14 | ```bash 15 | cmake -B build 16 | cmake --build build 17 | ``` 18 | 19 | ### Run 20 | ```bash 21 | ./inference 22 | ``` 23 | 24 | ## Third Party Libraries 25 | 26 | This project relies on the following third-party libraries: 27 | 28 | - [fmt](https://github.com/fmtlib/fmt) 29 | - A modern formatting library. 30 | 31 | - [nlohmann/json](https://github.com/nlohmann/json) 32 | - A robust library tailored for JSON manipulation in Modern C++. 33 | 34 | - [gabime/spdlog](https://github.com/gabime/spdlog) 35 | - A high-performance C++ logging library. 36 | 37 | - [nothings/stb](https://github.com/nothings/stb) 38 | - A single-file public domain library for image loading and writing. 39 | 40 | All dependencies are automatically fetched and configured through CMake FetchContent. 41 | 42 | ## Project Structure 43 | - **cmake** - CMake modules directory 44 | - **fetch** - Scripts for fetching external dependencies 45 | - `nlohmann.cmake` - Fetch script for JSON library 46 | - `spdlog.cmake` - Fetch script for logging library 47 | - `stb.cmake` - Fetch script for image loading library 48 | - **include** 49 | - **operators** - Directory for operator implementations 50 | - `conv2d.hpp` - Convolution 2D operator 51 | - `linear.hpp` - Linear/Fully connected layer 52 | - `maxpool2d.hpp` - Max pooling 2D operator 53 | - `padding.hpp` - Padding operations 54 | - `quant_stub.hpp` - Quantization stub 55 | - `relu.hpp` - ReLU activation function 56 | - `model.hpp` - Model class definition 57 | - `operator.hpp` - Base operator interface 58 | - `operator_factory.hpp` - Operator factory pattern 59 | - `tensor.hpp` - Tensor class definition 60 | - `CMakeLists.txt` 61 | - **tutorials** 62 | - `demo.cc` 63 | - `CMakeLists.txt` 64 | - `CMakeLists.txt` -------------------------------------------------------------------------------- /02_inference/cmake/fetch/fmt.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare( 2 | fmt 3 | GIT_REPOSITORY https://github.com/fmtlib/fmt.git 4 | GIT_TAG 6.1.2 5 | ) 6 | FetchContent_MakeAvailable(fmt) -------------------------------------------------------------------------------- /02_inference/cmake/fetch/nlohmann.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(json 2 | URL https://github.com/nlohmann/json/releases/download/v3.7.3/include.zip 3 | DOWNLOAD_EXTRACT_TIMESTAMP TRUE 4 | ) 5 | FetchContent_MakeAvailable(json) 6 | 7 | add_library(nlohmann_json INTERFACE) 8 |
target_include_directories(nlohmann_json INTERFACE ${json_SOURCE_DIR}/include) 9 | add_library(nlohmann_json::nlohmann_json ALIAS nlohmann_json) -------------------------------------------------------------------------------- /02_inference/cmake/fetch/spdlog.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare( 2 | spdlog 3 | GIT_REPOSITORY https://github.com/gabime/spdlog.git 4 | GIT_TAG v1.5.0 5 | ) 6 | 7 | FetchContent_MakeAvailable(spdlog) -------------------------------------------------------------------------------- /02_inference/cmake/fetch/stb.cmake: -------------------------------------------------------------------------------- 1 | # Fetch nothings/stb 2 | FetchContent_Declare( 3 | stb 4 | GIT_REPOSITORY https://github.com/nothings/stb.git 5 | GIT_TAG master 6 | ) 7 | 8 | FetchContent_MakeAvailable(stb) -------------------------------------------------------------------------------- /02_inference/include/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(libqnn INTERFACE) 2 | 3 | target_include_directories(libqnn INTERFACE 4 | ${CMAKE_CURRENT_SOURCE_DIR} 5 | ) 6 | 7 | target_link_libraries(libqnn INTERFACE 8 | nlohmann_json::nlohmann_json 9 | spdlog::spdlog 10 | fmt::fmt 11 | ) -------------------------------------------------------------------------------- /02_inference/include/model.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file model.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Model class 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include "operator.hpp" 20 | #include "operator_factory.hpp" 21 | 22 | namespace qnn { 23 | 24 | /** 25 | * @brief Neural network model container 26 | * 27 | * This class represents a neural network model that can be loaded from a JSON 28 | * file and executed. It maintains a sequence of operators that form the model's 29 | * computation graph. 30 | */ 31 | class Model { 32 | public: 33 | /** 34 | * @brief Constructor for the Model class 35 | * 36 | * This constructor sets the logging level to debug if the build type is 37 | * debug. 38 | */ 39 | Model() { 40 | #ifdef BUILD_DEBUG 41 | spdlog::set_level(spdlog::level::debug); 42 | #endif 43 | // Remove resize from constructor 44 | } 45 | 46 | /** 47 | * @brief Parse a layer from JSON and create appropriate operator 48 | * 49 | * @param layer_json JSON object containing layer configuration 50 | * @return Unique pointer to created operator 51 | * @throws json::exception If layer configuration is invalid 52 | * @throws std::runtime_error If layer type is unsupported 53 | */ 54 | static std::variant, OperatorPtr, 55 | OperatorPtr> 56 | parseLayer(const json& layer_json) { 57 | const auto& type = layer_json["type"].get(); 58 | const auto& dtype = layer_json.contains("dtype") 59 | ? 
layer_json["dtype"].get() 60 | : "torch.qint8"; 61 | 62 | if (type == "QuantStub") { 63 | return QuantStub::LoadFromJson(layer_json); 64 | } 65 | 66 | if (type == "DeQuantStub") { 67 | return DeQuantStub::LoadFromJson(layer_json); 68 | } 69 | 70 | auto createOp = [&]() { 71 | if (type == "Conv2d") return Conv2d::LoadFromJson(layer_json); 72 | if (type == "Linear") return Linear::LoadFromJson(layer_json); 73 | if (type == "MaxPool2d") return MaxPool2d::LoadFromJson(layer_json); 74 | if (type == "ReLU") return ReLU::LoadFromJson(layer_json); 75 | throw std::runtime_error("Unknown operator type: " + type); 76 | }; 77 | 78 | if (dtype == "torch.qint8") { 79 | return createOp.template operator()(); 80 | } 81 | throw std::runtime_error("Unknown supported dtype: " + dtype); 82 | } 83 | 84 | /** 85 | * @brief Creates a Model instance from a JSON file 86 | * 87 | * @param filename Path to the JSON file containing model configuration 88 | * @return Model instance initialized with operators from the JSON 89 | * @throws std::runtime_error If file cannot be opened or parsed 90 | * @throws json::exception If JSON structure is invalid 91 | * 92 | * The JSON file should contain: 93 | * - List of operators with their configurations 94 | * - Operator connections/graph structure 95 | * - Model metadata (optional) 96 | */ 97 | static Model loadModel(const std::string& filename) { 98 | Model model; 99 | 100 | // Read and parse JSON file 101 | std::ifstream file(filename); 102 | if (!file.is_open()) { 103 | throw std::runtime_error("Failed to open model file: " + filename); 104 | } 105 | 106 | json j; 107 | file >> j; 108 | 109 | // Parse layers from JSON 110 | const auto& layers = j["layers"]; 111 | 112 | // Reserve space for operators 113 | model.operators_.reserve(layers.size()); 114 | 115 | // Resize intermediate tensors based on number of layers 116 | model.intermediate_tensors_.resize(layers.size()); 117 | 118 | for (const auto& layer : layers) { 119 | try { 120 | model.operators_.push_back(parseLayer(layer)); 121 | } catch (const std::exception& e) { 122 | throw std::runtime_error("Failed to parse layer: " + 123 | std::string(e.what())); 124 | } 125 | } 126 | 127 | return model; 128 | } 129 | 130 | /** 131 | * @brief Performs forward pass through the model 132 | * 133 | * @param input Input tensor to the model 134 | * @param output Output tensor where results will be stored 135 | * @throws std::runtime_error If model has no operators or computation fails 136 | * 137 | * This method: 138 | * 1. Validates input dimensions 139 | * 2. Executes each operator in sequence 140 | * 3. Stores final result in output tensor 141 | * 142 | * The input and output tensors must have compatible dimensions with 143 | * the model's expected input/output shapes. 
144 | */ 145 | std::variant, Tensor> forward( 146 | const Tensor& input) { 147 | if (operators_.empty()) { 148 | throw std::runtime_error("No operators in model"); 149 | } 150 | 151 | // Initialize input tensor 152 | input_tensor_ = std::move(input); 153 | 154 | // Process each operator 155 | for (size_t i = 0; i < operators_.size(); ++i) { 156 | const auto& op_variant = operators_[i]; 157 | std::visit( 158 | [&](const auto& op) { 159 | using Op = std::remove_reference_t; 160 | using OpInputT = typename Op::input_type; 161 | using OpOutputT = typename Op::output_type; 162 | 163 | spdlog::debug("Layer: {} ({})", op->name, op->type); 164 | 165 | if constexpr (std::is_same_v && 166 | std::is_same_v) { 167 | op->Forward(input_tensor_, intermediate_tensors_[i]); 168 | } else if constexpr (std::is_same_v && 169 | std::is_same_v) { 170 | op->Forward(intermediate_tensors_[i - 1], 171 | intermediate_tensors_[i]); 172 | } else if constexpr (std::is_same_v && 173 | std::is_same_v) { 174 | op->Forward(intermediate_tensors_[i - 1], output_tensor_); 175 | } else { 176 | throw std::runtime_error("Unsupported operator type: " + 177 | op->type); 178 | } 179 | }, 180 | op_variant); 181 | } 182 | 183 | // Return the final output 184 | return std::variant, Tensor>( 185 | std::move(output_tensor_)); 186 | } 187 | 188 | private: 189 | /** @brief Vector of operators that form the model's computation graph */ 190 | std::vector< 191 | std::variant, OperatorPtr, 192 | OperatorPtr>> 193 | operators_; 194 | 195 | // Tensors for input, output and intermediate results 196 | Tensor input_tensor_; 197 | Tensor output_tensor_; 198 | std::vector> intermediate_tensors_; 199 | }; 200 | 201 | } // namespace qnn -------------------------------------------------------------------------------- /02_inference/include/operator.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file operator.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Operator class 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "tensor.hpp" 19 | 20 | namespace qnn { 21 | 22 | using json = nlohmann::json; 23 | 24 | /** 25 | * @brief Weight information for quantized operators 26 | * 27 | * This class holds quantization parameters and weight values for neural network 28 | * operators. It supports both per-tensor and per-channel quantization schemes. 
29 | */ 30 | class WeightInfo { 31 | public: 32 | /** 33 | * @brief Constructs WeightInfo from JSON data 34 | * 35 | * @param j JSON object containing weight parameters 36 | * @return WeightInfo instance initialized with the parameters 37 | * @throws json::exception If required parameters are missing 38 | * @throws std::runtime_error If dtype is unsupported 39 | */ 40 | template 41 | static WeightInfo LoadFromJson(const json& j) { 42 | WeightInfo info; 43 | 44 | // Parse shape 45 | info.shape_ = j["shape"].get>(); 46 | 47 | // Parse dtype and validate 48 | std::string dtype = j["dtype"].get(); 49 | if (dtype == "torch.qint8" && !std::is_same_v) { 50 | throw std::runtime_error( 51 | "Type mismatch: JSON specifies qint8 but template parameter is " 52 | "different"); 53 | } else if (dtype == "torch.float32" && !std::is_same_v) { 54 | throw std::runtime_error( 55 | "Type mismatch: JSON specifies float32 but template parameter is " 56 | "different"); 57 | } 58 | 59 | // Parse quantization type 60 | if (j.contains("quantization")) { 61 | info.quantization_ = j["quantization"].get(); 62 | } 63 | 64 | // Parse values based on dtype 65 | if (j.contains("values")) { 66 | // Calculate total size from shape 67 | size_t total_size = 1; 68 | for (auto dim : info.shape_) { 69 | total_size *= dim; 70 | } 71 | 72 | // Create tensor with the right size 73 | info.values_ = QTensor(info.shape_); 74 | auto* data = info.values_.data(); 75 | 76 | if (dtype == "torch.qint8") { 77 | // Flatten nested arrays recursively 78 | std::function flatten_array = 79 | [&flatten_array](const json& arr, int8_t*& out) { 80 | if (arr.is_array()) { 81 | for (const auto& elem : arr) { 82 | flatten_array(elem, out); 83 | } 84 | } else { 85 | *out++ = arr.get(); 86 | } 87 | }; 88 | 89 | int8_t* ptr = data; 90 | flatten_array(j["values"], ptr); 91 | 92 | } else if (dtype == "torch.float32") { 93 | throw std::runtime_error("Float tensor support not implemented yet"); 94 | } else { 95 | throw std::runtime_error("Unsupported dtype: " + dtype); 96 | } 97 | } 98 | 99 | // Parse per-tensor quantization parameters 100 | if (j.contains("scale")) { 101 | info.scale_ = j["scale"].get(); 102 | } 103 | 104 | // Parse per-channel quantization parameters 105 | if (j.contains("scales")) { 106 | info.scales_ = j["scales"].get>(); 107 | } 108 | 109 | if (j.contains("axis")) { 110 | info.axis_ = j["axis"].get(); 111 | } 112 | 113 | return info; 114 | } 115 | 116 | /** @return Shape of the weight tensor */ 117 | const std::vector& shape() const { return shape_; } 118 | 119 | /** @return Quantization type ("per_tensor" or "per_channel") */ 120 | const std::string& quantization() const { return quantization_; } 121 | 122 | /** @return Quantized weight values */ 123 | const QTensor& values() const { return values_; } 124 | 125 | /** @return Scale factor for per-tensor quantization */ 126 | float scale() const { return scale_; } 127 | 128 | /** @return Scale factors for per-channel quantization */ 129 | const std::vector& scales() const { return scales_; } 130 | 131 | /** @return Axis along which per-channel quantization is performed */ 132 | int axis() const { return axis_; } 133 | 134 | private: 135 | std::vector shape_; 136 | std::string quantization_; 137 | QTensor values_; 138 | float scale_{0.0f}; 139 | std::vector scales_; 140 | int axis_{0}; 141 | }; 142 | 143 | /** 144 | * @brief Base class for all neural network operators 145 | * 146 | * @tparam T Data type of the tensor elements (e.g., float, int8_t) 147 | */ 148 | template 149 | class Operator 
{ 150 | public: 151 | /** @brief Virtual destructor for proper cleanup of derived classes */ 152 | virtual ~Operator() = default; 153 | 154 | /** 155 | * @brief Performs the forward computation of the operator 156 | * 157 | * @param input Input tensor to the operator 158 | * @param output Output tensor where results will be stored 159 | * @throws std::runtime_error If computation fails or dimensions mismatch 160 | */ 161 | virtual void Forward(const Tensor& input, 162 | Tensor& output) = 0; 163 | 164 | /** @brief Name identifier of the operator */ 165 | std::string name; 166 | 167 | /** @brief Type identifier of the operator */ 168 | std::string type; 169 | 170 | /** @brief Type definition for input tensor */ 171 | using input_type = InputT; 172 | 173 | /** @brief Type definition for output tensor */ 174 | using output_type = OutputT; 175 | }; 176 | 177 | /** @brief Smart pointer type for operator instances */ 178 | template 179 | using OperatorPtr = std::unique_ptr>; 180 | 181 | // Common type aliases 182 | using QuantOperator = Operator; 183 | using QuantOperatorPtr = OperatorPtr; 184 | 185 | // Add mixed precision operator types if needed 186 | using QuantStubOperator = Operator; 187 | using QuantStubOperatorPtr = OperatorPtr; 188 | 189 | using DeQuantStubOperator = Operator; 190 | using DeQuantStubOperatorPtr = OperatorPtr; 191 | 192 | } // namespace qnn 193 | -------------------------------------------------------------------------------- /02_inference/include/operator_factory.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file operator_factory.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Operator factory 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "operators/conv2d.hpp" 13 | #include "operators/dequant_stub.hpp" 14 | #include "operators/linear.hpp" 15 | #include "operators/maxpool2d.hpp" 16 | #include "operators/padding.hpp" 17 | #include "operators/quant_stub.hpp" 18 | #include "operators/relu.hpp" -------------------------------------------------------------------------------- /02_inference/include/operators/conv2d.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file conv2d.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief 2D Convolution operator 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | #include "operator.hpp" 12 | #include "operators/padding.hpp" 13 | 14 | namespace qnn { 15 | 16 | /** 17 | * @brief 2D Convolution operator 18 | * 19 | * Implements 2D convolution with optional bias addition. 20 | * Supports both float and quantized computation. 
21 | * 22 | * @tparam InputT Data type of the input tensor elements (e.g., float, int8_t) 23 | * @tparam OutputT Data type of the output tensor elements (e.g., float, int8_t) 24 | */ 25 | template 26 | class Conv2d : public Operator { 27 | public: 28 | /** 29 | * @brief Creates Conv2d operator from JSON configuration 30 | * 31 | * @param j JSON object containing operator parameters 32 | * @return Unique pointer to created operator 33 | * @throws json::exception If required parameters are missing 34 | */ 35 | static OperatorPtr LoadFromJson(const json& j) { 36 | auto op = std::make_unique>(); 37 | 38 | // Parse basic parameters 39 | op->name = j["name"].get(); 40 | op->type = "Conv2d"; 41 | op->in_channels_ = j["in_channels"].get(); 42 | op->out_channels_ = j["out_channels"].get(); 43 | op->kernel_size_ = j["kernel_size"].get(); 44 | op->stride_ = j["stride"].get(); 45 | op->padding_ = j["padding"].get(); 46 | 47 | // Parse weights and bias 48 | if (j.contains("weight")) { 49 | op->weight_ = WeightInfo::LoadFromJson(j["weight"]); 50 | } 51 | 52 | // Parse bias 53 | if (j.contains("bias")) { 54 | const auto& bias = j["bias"]; 55 | const auto& values = bias["values"].get>(); 56 | op->bias_.assign(values.begin(), values.end()); 57 | } 58 | 59 | // Parse quantization parameters 60 | if (j.contains("scale")) { 61 | op->scale_ = j["scale"].get(); 62 | } 63 | 64 | return op; 65 | } 66 | 67 | /** 68 | * @brief Performs 2D convolution computation 69 | * 70 | * @param input Input tensor of shape [N, C, H, W] 71 | * @param output Output tensor of shape [N, out_channels_, H_out, W_out] 72 | * @throws std::runtime_error If input dimensions are invalid 73 | */ 74 | void Forward(const Tensor& input, Tensor& output) override { 75 | const auto& in_shape = input.shape(); 76 | if (in_shape.size() != 4) { 77 | throw std::runtime_error("Input tensor must be 4D [N,C,H,W]"); 78 | } 79 | 80 | // Create and setup padding operator if padding is needed 81 | Tensor padded_input; 82 | padded_input.set_scale(input.scale()); 83 | 84 | if (padding_ > 0) { 85 | auto padding_op = std::make_unique>(); 86 | padding_op->set_pad_height(padding_); 87 | padding_op->set_pad_width(padding_); 88 | padding_op->set_pad_value(0); 89 | padding_op->Forward(input, padded_input); 90 | } else { 91 | padded_input = input; 92 | } 93 | 94 | // Get dimensions after padding 95 | const auto& padded_shape = padded_input.shape(); 96 | size_t batch = padded_shape[0]; 97 | size_t in_height = padded_shape[2]; 98 | size_t in_width = padded_shape[3]; 99 | size_t out_height = (in_height - kernel_size_) / stride_ + 1; 100 | size_t out_width = (in_width - kernel_size_) / stride_ + 1; 101 | 102 | // Convert to vector for resize 103 | std::vector out_shape = {batch, static_cast(out_channels_), 104 | out_height, out_width}; 105 | output.resize(out_shape); 106 | output.set_scale(scale_); 107 | 108 | #ifdef BUILD_DEBUG 109 | spdlog::debug("--------------------------------"); 110 | spdlog::debug("Conv-2D Operator Forward"); 111 | spdlog::debug("Input Shape: [{}]", fmt::join(input.shape(), ", ")); 112 | spdlog::debug("Output Shape: [{}]", fmt::join(output.shape(), ", ")); 113 | spdlog::debug("Input Scale: {}", input.scale()); 114 | spdlog::debug("Output Scale: {}", output.scale()); 115 | spdlog::debug("Kernel size: {}", kernel_size_); 116 | spdlog::debug("In channels: {}", in_channels_); 117 | spdlog::debug("Out channels: {}", out_channels_); 118 | spdlog::debug("Stride: {}", stride_); 119 | spdlog::debug("Padding: {}", padding_); 120 | 
spdlog::debug("--------------------------------"); 121 | #endif 122 | 123 | // Perform convolution on padded input 124 | for (size_t n = 0; n < batch; n++) { 125 | for (size_t oc = 0; oc < static_cast(out_channels_); oc++) { 126 | for (size_t oh = 0; oh < out_height; oh++) { 127 | for (size_t ow = 0; ow < out_width; ow++) { 128 | float acc = 0.0f; 129 | 130 | for (size_t ic = 0; ic < static_cast(in_channels_); ic++) { 131 | for (size_t kh = 0; kh < kernel_size_; kh++) { 132 | for (size_t kw = 0; kw < kernel_size_; kw++) { 133 | size_t ih = oh * stride_ + kh; 134 | size_t iw = ow * stride_ + kw; 135 | 136 | size_t in_idx = 137 | ((n * in_channels_ + ic) * in_height + ih) * in_width + 138 | iw; 139 | size_t weight_idx = 140 | ((oc * in_channels_ + ic) * kernel_size_ + kh) * 141 | kernel_size_ + 142 | kw; 143 | 144 | acc += static_cast(padded_input.data()[in_idx]) * 145 | static_cast(weight_.values()[weight_idx]); 146 | } 147 | } 148 | } 149 | 150 | if (!bias_.empty()) { 151 | acc += bias_[oc] / (weight_.scales()[oc] * padded_input.scale()); 152 | } 153 | 154 | // Apply output scale 155 | acc = acc * (weight_.scales()[oc] * padded_input.scale()) / 156 | output.scale(); 157 | 158 | // Clamp the result to the int8 range 159 | acc = std::min(std::max(acc, -128.0f), 127.0f); 160 | 161 | // Store the result 162 | size_t out_idx = 163 | ((n * out_channels_ + oc) * out_height + oh) * out_width + ow; 164 | output.data()[out_idx] = static_cast(std::round(acc)); 165 | } 166 | } 167 | } 168 | } 169 | } 170 | 171 | private: 172 | /** @brief Number of input channels */ 173 | int in_channels_; 174 | 175 | /** @brief Number of output channels */ 176 | int out_channels_; 177 | 178 | /** @brief Size of the convolution kernel */ 179 | int kernel_size_; 180 | 181 | /** @brief Stride of the convolution */ 182 | int stride_; 183 | 184 | /** @brief Padding size */ 185 | int padding_; 186 | 187 | /** @brief Convolution weights */ 188 | WeightInfo weight_; 189 | 190 | /** @brief Optional bias terms */ 191 | std::vector bias_; 192 | 193 | /** @brief Scale for quantization */ 194 | float scale_; 195 | }; 196 | } // namespace qnn -------------------------------------------------------------------------------- /02_inference/include/operators/dequant_stub.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file dequant_stub.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Dequantization stub operator 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | #include 12 | 13 | #include "operator.hpp" 14 | 15 | namespace qnn { 16 | 17 | /** 18 | * @brief Quantization stub operator 19 | * 20 | * Handles quantization of floating-point tensors to integer tensors. 21 | * Used at the beginning of quantized neural networks. 
22 | */ 23 | class DeQuantStub : public Operator { 24 | public: 25 | /** 26 | * @brief Creates DeQuantStub operator from JSON configuration 27 | * 28 | * @param j JSON object containing operator parameters 29 | * @return Unique pointer to created operator 30 | * @throws json::exception If required parameters are missing 31 | */ 32 | static OperatorPtr LoadFromJson(const json& j) { 33 | auto op = std::make_unique(); 34 | op->scale_ = j["scale"].get(); 35 | op->name = j["name"].get(); 36 | op->type = "DeQuantStub"; 37 | return op; 38 | } 39 | 40 | /** 41 | * @brief Performs tensor dequantization 42 | * 43 | * @param input Quantized input tensor 44 | * @param output Floating-point output tensor 45 | * @throws std::runtime_error If dequantization parameters are invalid 46 | */ 47 | void Forward(const Tensor& input, Tensor& output) override { 48 | // Convert shape types properly 49 | std::vector out_shape(input.shape().begin(), input.shape().end()); 50 | output.resize(out_shape); 51 | 52 | #ifdef BUILD_DEBUG 53 | spdlog::debug("--------------------------------"); 54 | spdlog::debug("DeQuantStub Operator Forward"); 55 | spdlog::debug("Input Shape: [{}]", fmt::join(input.shape(), ", ")); 56 | spdlog::debug("Output Shape: [{}]", fmt::join(output.shape(), ", ")); 57 | spdlog::debug("Scale: {}", scale_); 58 | spdlog::debug("--------------------------------"); 59 | #endif 60 | 61 | for (size_t i = 0; i < input.size(); ++i) { 62 | output.data()[i] = input.data()[i] * scale_; 63 | } 64 | } 65 | 66 | private: 67 | /** @brief Dequantization scale factor */ 68 | float scale_; 69 | }; 70 | 71 | } // namespace qnn -------------------------------------------------------------------------------- /02_inference/include/operators/linear.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file linear.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Linear operator 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | #include "operator.hpp" 12 | 13 | namespace qnn { 14 | 15 | /** 16 | * @brief Linear (fully connected) operator 17 | * 18 | * Implements a linear transformation: y = xW^T + b 19 | * Supports both float and quantized computation.
20 | * 21 | * @tparam InputT Data type of the input tensor elements (e.g., float, int8_t) 22 | * @tparam OutputT Data type of the output tensor elements (e.g., float, int8_t) 23 | */ 24 | template 25 | class Linear : public Operator { 26 | public: 27 | /** 28 | * @brief Creates Linear operator from JSON configuration 29 | * 30 | * @param j JSON object containing operator parameters 31 | * @return Unique pointer to created operator 32 | * @throws json::exception If required parameters are missing 33 | */ 34 | static OperatorPtr LoadFromJson(const json& j) { 35 | auto op = std::make_unique>(); 36 | op->name = j["name"].get(); 37 | op->type = "Linear"; 38 | 39 | // Parse weights and bias 40 | if (j.contains("weight")) { 41 | op->weight_ = WeightInfo::LoadFromJson(j["weight"]); 42 | } 43 | 44 | // Parse bias 45 | if (j.contains("bias")) { 46 | const auto& bias = j["bias"]; 47 | const auto& values = bias["values"].get>(); 48 | op->bias_.assign(values.begin(), values.end()); 49 | } 50 | 51 | // Parse quantization parameters 52 | if (j.contains("scale")) { 53 | op->scale_ = j["scale"].get(); 54 | } 55 | 56 | return op; 57 | } 58 | 59 | /** 60 | * @brief Performs linear transformation 61 | * 62 | * @param input Input tensor of shape [batch_size, in_features_] 63 | * @param output Output tensor of shape [batch_size, out_features_] 64 | * @throws std::runtime_error If input dimensions are invalid 65 | */ 66 | void Forward(const Tensor& input, Tensor& output) override { 67 | // First dimension is always batch size 68 | const size_t batch_size = input.shape()[0]; 69 | 70 | // Calculate flattened features (multiply all dimensions except batch) 71 | size_t in_features = 1; 72 | for (size_t i = 1; i < input.shape().size(); ++i) { 73 | in_features *= input.shape()[i]; 74 | } 75 | 76 | // Validate input dimensions 77 | if (in_features != weight_.shape()[1]) { 78 | throw std::runtime_error( 79 | "Input features dimension doesn't match weight matrix"); 80 | } 81 | 82 | const size_t out_features = weight_.shape()[0]; 83 | 84 | // Resize output tensor to [batch_size, out_features] 85 | output.resize(std::vector{batch_size, out_features}); 86 | output.set_scale(scale_); 87 | 88 | #ifdef BUILD_DEBUG 89 | spdlog::debug("--------------------------------"); 90 | spdlog::debug("Linear Operator Forward"); 91 | spdlog::debug("Input Shape: [{}]", fmt::join(input.shape(), ", ")); 92 | spdlog::debug("Output Shape: [{}]", fmt::join(output.shape(), ", ")); 93 | spdlog::debug("Input Scale: {}", input.scale()); 94 | spdlog::debug("Output Scale: {}", output.scale()); 95 | spdlog::debug("In Features: {}", in_features_); 96 | spdlog::debug("Out Features: {}", out_features_); 97 | spdlog::debug("--------------------------------"); 98 | #endif 99 | 100 | // Perform matrix multiplication: y = xW^T + b 101 | for (size_t b = 0; b < batch_size; ++b) { 102 | for (size_t o = 0; o < out_features; ++o) { 103 | float acc = 0.0f; 104 | 105 | // Compute dot product 106 | for (size_t i = 0; i < in_features; ++i) { 107 | acc += static_cast(input.data()[b * in_features + i]) * 108 | static_cast(weight_.values()[o * in_features + i]); 109 | } 110 | 111 | if (!bias_.empty()) { 112 | // Bias is adjusted by its scale 113 | acc += bias_[o] / (weight_.scales()[o] * input.scale()); 114 | } 115 | 116 | // Apply output scale 117 | acc = acc * (weight_.scales()[o] * input.scale()) / output.scale(); 118 | 119 | // Clamp the result to the int8 range 120 | acc = std::min(std::max(acc, -128.0f), 127.0f); 121 | 122 | // Store the result 123 | output.data()[b 
* out_features + o] = 124 | static_cast(std::round(acc)); 125 | } 126 | } 127 | } 128 | 129 | private: 130 | /** @brief Number of input features */ 131 | int in_features_; 132 | 133 | /** @brief Number of output features */ 134 | int out_features_; 135 | 136 | /** @brief Weight matrix */ 137 | WeightInfo weight_; 138 | 139 | /** @brief Optional bias terms */ 140 | std::vector bias_; 141 | 142 | /** @brief Quantization scale */ 143 | float scale_; 144 | }; 145 | } // namespace qnn -------------------------------------------------------------------------------- /02_inference/include/operators/maxpool2d.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file maxpool2d.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief 2D Max Pooling operator 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | #include "operator.hpp" 12 | 13 | namespace qnn { 14 | 15 | /** 16 | * @brief 2D Max Pooling operator 17 | * 18 | * Performs max pooling over an input tensor, reducing spatial dimensions 19 | * by selecting maximum values in pooling windows. 20 | * 21 | * @tparam T Data type of the tensor elements (e.g., float, int8_t) 22 | */ 23 | template 24 | class MaxPool2d : public Operator { 25 | public: 26 | /** 27 | * @brief Creates MaxPool2d operator from JSON configuration 28 | * 29 | * @param j JSON object containing operator parameters 30 | * @return Unique pointer to created operator 31 | * @throws json::exception If required parameters are missing 32 | */ 33 | static OperatorPtr LoadFromJson(const json& j) { 34 | auto op = std::make_unique>(); 35 | op->name = j["name"].get(); 36 | op->type = "MaxPool2d"; 37 | op->kernel_size_ = j["kernel_size"].get(); 38 | op->stride_ = j["stride"].get(); 39 | op->padding_ = j.value("padding", 0); 40 | return op; 41 | } 42 | 43 | /** 44 | * @brief Performs max pooling computation 45 | * 46 | * @param input Input tensor of shape [N, C, H, W] 47 | * @param output Output tensor of shape [N, C, H_out, W_out] 48 | * @throws std::runtime_error If input dimensions are invalid 49 | */ 50 | void Forward(const Tensor& input, Tensor& output) override { 51 | const auto& in_shape = input.shape(); 52 | if (in_shape.size() != 4) { 53 | throw std::runtime_error("Input tensor must be 4D [N,C,H,W]"); 54 | } 55 | 56 | // Get input dimensions 57 | size_t batch = in_shape[0]; 58 | size_t channels = in_shape[1]; 59 | size_t in_height = in_shape[2]; 60 | size_t in_width = in_shape[3]; 61 | 62 | // Calculate output dimensions 63 | size_t out_height = (in_height - kernel_size_) / stride_ + 1; 64 | size_t out_width = (in_width - kernel_size_) / stride_ + 1; 65 | 66 | // Resize output tensor 67 | std::vector out_shape = {batch, channels, out_height, out_width}; 68 | output.resize(out_shape); 69 | output.set_scale(input.scale()); 70 | 71 | #ifdef BUILD_DEBUG 72 | spdlog::debug("--------------------------------"); 73 | spdlog::debug("MaxPool-2D Operator Forward"); 74 | spdlog::debug("Input Shape: [{}]", fmt::join(input.shape(), ", ")); 75 | spdlog::debug("Output Shape: [{}]", fmt::join(output.shape(), ", ")); 76 | spdlog::debug("Kernel Size: {}", kernel_size_); 77 | spdlog::debug("Stride: {}", stride_); 78 | spdlog::debug("Padding: {}", padding_); 79 | spdlog::debug("--------------------------------"); 80 | #endif 81 | 82 | // Perform max pooling 83 | for (size_t n = 0; n < batch; n++) { 84 | for (size_t c = 0; c < channels; c++) { 85 | for (size_t oh = 0; oh < out_height; oh++) { 86 | for (size_t ow = 0; ow < 
out_width; ow++) { 87 | // Initialize with minimum value for the type 88 | InputT max_val = std::numeric_limits::lowest(); 89 | 90 | // Find maximum in the pooling window 91 | for (size_t kh = 0; kh < kernel_size_; kh++) { 92 | for (size_t kw = 0; kw < kernel_size_; kw++) { 93 | size_t ih = oh * stride_ + kh; 94 | size_t iw = ow * stride_ + kw; 95 | 96 | size_t in_idx = 97 | ((n * channels + c) * in_height + ih) * in_width + iw; 98 | max_val = std::max(max_val, input.data()[in_idx]); 99 | } 100 | } 101 | 102 | // Store the maximum value 103 | size_t out_idx = 104 | ((n * channels + c) * out_height + oh) * out_width + ow; 105 | output.data()[out_idx] = static_cast(max_val); 106 | } 107 | } 108 | } 109 | } 110 | } 111 | 112 | private: 113 | /** @brief Size of the pooling window */ 114 | int kernel_size_; 115 | 116 | /** @brief Stride of the pooling operation */ 117 | int stride_; 118 | 119 | /** @brief Padding size */ 120 | int padding_; 121 | }; 122 | 123 | } // namespace qnn -------------------------------------------------------------------------------- /02_inference/include/operators/padding.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file padding.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Padding operator 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | #include "operator.hpp" 12 | 13 | namespace qnn { 14 | 15 | template 16 | class Padding : public Operator { 17 | public: 18 | void set_pad_height(int height) { pad_height_ = height; } 19 | void set_pad_width(int width) { pad_width_ = width; } 20 | void set_pad_value(InputT value) { pad_value_ = value; } 21 | 22 | /** 23 | * @brief Loads Padding operator from JSON configuration 24 | * 25 | * @param j JSON object containing operator parameters 26 | * @return Unique pointer to created operator 27 | */ 28 | static OperatorPtr LoadFromJson(const json& j) { 29 | auto op = std::make_unique>(); 30 | op->name = j["name"].get(); 31 | op->type = "Padding"; 32 | op->set_pad_height(j["pad_height"].get()); 33 | op->set_pad_width(j["pad_width"].get()); 34 | op->set_pad_value(j.value("pad_value", 0)); 35 | return op; 36 | } 37 | 38 | /** 39 | * @brief Performs padding of input tensor 40 | * 41 | * @param input Input tensor of shape [N, C, H, W] 42 | * @param output Output tensor of shape [N, C, H + pad_height * 2, W + 43 | * pad_width * 2] 44 | * @throws std::runtime_error If input dimensions are invalid 45 | */ 46 | void Forward(const Tensor& input, Tensor& output) override { 47 | const auto& in_shape = input.shape(); 48 | if (in_shape.size() != 4) { 49 | throw std::runtime_error("Input tensor must be 4D [N,C,H,W]"); 50 | } 51 | 52 | // Calculate output shape 53 | std::vector out_shape = { 54 | in_shape[0], // N 55 | in_shape[1], // C 56 | in_shape[2] + pad_height_ * 2, // H + pad_height 57 | in_shape[3] + pad_width_ * 2 // W + pad_width 58 | }; 59 | output.resize(out_shape); 60 | 61 | #ifdef BUILD_DEBUG 62 | spdlog::debug("--------------------------------"); 63 | spdlog::debug("Padding Operator Forward"); 64 | spdlog::debug("Input Shape: [{}]", fmt::join(input.shape(), ", ")); 65 | spdlog::debug("Output Shape: [{}]", fmt::join(output.shape(), ", ")); 66 | spdlog::debug("Pad Height: {}", pad_height_); 67 | spdlog::debug("Pad Width: {}", pad_width_); 68 | spdlog::debug("Pad Value: {}", pad_value_); 69 | spdlog::debug("--------------------------------"); 70 | #endif 71 | 72 | // Calculate padding on each side 73 | int pad_top = pad_height_; 74 | int 
pad_bottom = pad_height_; 75 | int pad_left = pad_width_; 76 | int pad_right = pad_width_; 77 | 78 | // Get dimensions 79 | size_t batch = in_shape[0]; 80 | size_t channels = in_shape[1]; 81 | size_t in_height = in_shape[2]; 82 | size_t in_width = in_shape[3]; 83 | size_t out_height = out_shape[2]; 84 | size_t out_width = out_shape[3]; 85 | 86 | // Fill output with padding value 87 | std::fill(output.data(), output.data() + output.size(), 88 | static_cast(pad_value_)); 89 | 90 | // Copy input data to the padded output 91 | for (size_t n = 0; n < batch; n++) { 92 | for (size_t c = 0; c < channels; c++) { 93 | for (size_t h = 0; h < in_height; h++) { 94 | for (size_t w = 0; w < in_width; w++) { 95 | size_t in_idx = ((n * channels + c) * in_height + h) * in_width + w; 96 | size_t out_idx = 97 | ((n * channels + c) * out_height + (h + pad_top)) * out_width + 98 | (w + pad_left); 99 | output.data()[out_idx] = static_cast(input.data()[in_idx]); 100 | } 101 | } 102 | } 103 | } 104 | } 105 | 106 | private: 107 | int pad_height_; 108 | int pad_width_; 109 | InputT pad_value_; 110 | }; 111 | 112 | } // namespace qnn -------------------------------------------------------------------------------- /02_inference/include/operators/quant_stub.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file quant_stub.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Quantization stub operator 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | #include 12 | 13 | #include "operator.hpp" 14 | 15 | namespace qnn { 16 | 17 | /** 18 | * @brief Quantization stub operator 19 | * 20 | * Handles quantization of floating-point tensors to integer tensors. 21 | * Used at the beginning of quantized neural networks. 
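 *
 * Example: with scale = 0.5, an input of 3.2 maps to
 * clamp(round(3.2 / 0.5), -128, 127) = 6, while 100.0 saturates to 127.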
22 | */ 23 | class QuantStub : public Operator { 24 | public: 25 | /** 26 | * @brief Creates QuantStub operator from JSON configuration 27 | * 28 | * @param j JSON object containing operator parameters 29 | * @return Unique pointer to created operator 30 | * @throws json::exception If required parameters are missing 31 | */ 32 | static OperatorPtr LoadFromJson(const json& j) { 33 | auto op = std::make_unique(); 34 | op->scale_ = j["scale"].get(); 35 | op->name = j["name"].get(); 36 | op->type = "QuantStub"; 37 | return op; 38 | } 39 | 40 | /** 41 | * @brief Performs tensor quantization 42 | * 43 | * @param input Floating-point input tensor 44 | * @param output Quantized output tensor 45 | * @throws std::runtime_error If quantization parameters are invalid 46 | */ 47 | void Forward(const Tensor& input, Tensor& output) override { 48 | // Convert shape types properly 49 | std::vector out_shape(input.shape().begin(), input.shape().end()); 50 | output.resize(out_shape); 51 | output.set_scale(scale_); 52 | 53 | #ifdef BUILD_DEBUG 54 | spdlog::debug("--------------------------------"); 55 | spdlog::debug("QuantStub Operator Forward"); 56 | spdlog::debug("Input Shape: [{}]", fmt::join(input.shape(), ", ")); 57 | spdlog::debug("Output Shape: [{}]", fmt::join(output.shape(), ", ")); 58 | spdlog::debug("Scale: {}", scale_); 59 | spdlog::debug("--------------------------------"); 60 | #endif 61 | 62 | for (size_t i = 0; i < input.size(); ++i) { 63 | float temp = (input.data()[i] / scale_); 64 | output.data()[i] = 65 | std::clamp(static_cast(std::round(temp)), 66 | static_cast(-128), static_cast(127)); 67 | } 68 | } 69 | 70 | private: 71 | /** @brief Quantization scale factor */ 72 | float scale_; 73 | }; 74 | 75 | } // namespace qnn -------------------------------------------------------------------------------- /02_inference/include/operators/relu.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file relu.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief ReLU operator 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | #include "operator.hpp" 12 | 13 | namespace qnn { 14 | 15 | /** 16 | * @brief ReLU (Rectified Linear Unit) operator 17 | * 18 | * Implements element-wise ReLU activation function. 19 | * 20 | * @tparam T Data type of the tensor elements (e.g., float, int8_t) 21 | */ 22 | template 23 | class ReLU : public Operator { 24 | public: 25 | /** 26 | * @brief Creates ReLU operator from JSON configuration 27 | * 28 | * @param j JSON object containing operator parameters 29 | * @return Unique pointer to created operator 30 | * @throws json::exception If required parameters are missing 31 | */ 32 | static OperatorPtr LoadFromJson(const json& j) { 33 | auto op = std::make_unique>(); 34 | op->name = j["name"].get(); 35 | op->type = "ReLU"; 36 | return op; 37 | } 38 | 39 | /** 40 | * @brief Performs ReLU activation 41 | * 42 | * @param input Input tensor 43 | * @param output Output tensor of same shape as input 44 | * @throws std::runtime_error If input dimensions are invalid 45 | */ 46 | void Forward(const Tensor& input, Tensor& output) override { 47 | // Resize output tensor to match input shape 48 | output.resize(input.shape()); 49 | output.set_scale(input.scale()); 50 | 51 | // Apply ReLU: max(0,x) 52 | for (size_t i = 0; i < input.size(); ++i) { 53 | output.data()[i] = 54 | static_cast(input.data()[i] > 0 ? 
input.data()[i] : 0); 55 | } 56 | } 57 | }; 58 | } // namespace qnn -------------------------------------------------------------------------------- /02_inference/include/tensor.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file tensor.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Tensor class 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #pragma once 11 | #include 12 | #include 13 | #include 14 | 15 | namespace qnn { 16 | 17 | /** 18 | * @brief Generic tensor class for storing n-dimensional arrays 19 | * @tparam T Data type of tensor elements 20 | */ 21 | template 22 | class Tensor { 23 | public: 24 | /** @brief Default constructor */ 25 | Tensor() = default; 26 | 27 | /** 28 | * @brief Constructs tensor with specified shape 29 | * @param shape Vector of dimensions 30 | */ 31 | explicit Tensor(const std::vector& shape) { resize(shape); } 32 | 33 | /** 34 | * @brief Copy constructor 35 | * @param other Tensor to copy from 36 | */ 37 | Tensor(const Tensor& other) : data_(other.data_), shape_(other.shape_) {} 38 | 39 | /** 40 | * @brief Move constructor 41 | * @param other Tensor to move from 42 | */ 43 | Tensor(Tensor&& other) noexcept 44 | : data_(std::move(other.data_)), shape_(std::move(other.shape_)) {} 45 | 46 | /** 47 | * @brief Assignment operator 48 | * @param other Tensor to assign from 49 | * @return Reference to this tensor 50 | */ 51 | Tensor& operator=(const Tensor& other) { 52 | if (this != &other) { 53 | data_ = other.data_; 54 | shape_ = other.shape_; 55 | } 56 | return *this; 57 | } 58 | 59 | /** 60 | * @brief Move assignment operator 61 | * @param other Tensor to assign from 62 | * @return Reference to this tensor 63 | */ 64 | Tensor& operator=(Tensor&& other) noexcept { 65 | if (this != &other) { 66 | data_ = std::move(other.data_); 67 | shape_ = std::move(other.shape_); 68 | scale_ = other.scale_; 69 | } 70 | return *this; 71 | } 72 | 73 | /** 74 | * @brief Resize with int64_t vector 75 | * @param shape Vector of dimensions 76 | */ 77 | void resize(const std::vector& shape) { 78 | shape_.clear(); 79 | shape_.reserve(shape.size()); 80 | for (const auto& dim : shape) { 81 | if (dim < 0) { 82 | throw std::invalid_argument("Negative dimension size"); 83 | } 84 | shape_.push_back(static_cast(dim)); 85 | } 86 | size_t total = 1; 87 | for (const auto& dim : shape_) { 88 | total *= dim; 89 | } 90 | data_.resize(total); 91 | } 92 | 93 | /** 94 | * @brief Resize with size_t vector 95 | * @param shape Vector of dimensions 96 | */ 97 | void resize(const std::vector& shape) { 98 | shape_ = shape; 99 | size_t total = 1; 100 | for (const auto& dim : shape_) { 101 | total *= dim; 102 | } 103 | data_.resize(total); 104 | } 105 | 106 | /** @return Size of tensor */ 107 | size_t size() const { 108 | size_t total = 1; 109 | for (const auto& dim : shape_) { 110 | total *= dim; 111 | } 112 | return total; 113 | } 114 | 115 | /** @return Reference to tensor shape */ 116 | const std::vector& shape() const { return shape_; } 117 | 118 | /** @return Pointer to raw data */ 119 | T* data() { return data_.data(); } 120 | 121 | /** @return Const pointer to raw data */ 122 | const T* data() const { return data_.data(); } 123 | 124 | /** 125 | * @brief Access tensor element 126 | * @param index Index of element 127 | * @return Reference to element 128 | */ 129 | T& operator[](size_t index) { 130 | if (index >= data_.size()) { 131 | throw std::out_of_range("Tensor index out of range"); 132 | } 133 | return data_[index]; 
134 | } 135 | 136 | /** 137 | * @brief Access tensor element 138 | * @param index Index of element 139 | * @return Const reference to element 140 | */ 141 | const T& operator[](size_t index) const { 142 | if (index >= data_.size()) { 143 | throw std::out_of_range("Tensor index out of range"); 144 | } 145 | return data_[index]; 146 | } 147 | 148 | /** 149 | * @brief Get tensor scale (for quantized tensors) 150 | * @return Scale value 151 | */ 152 | float scale() const { return scale_; } 153 | 154 | /** 155 | * @brief Set tensor scale (for quantized tensors) 156 | * @param scale Scale value 157 | */ 158 | void set_scale(float scale) { scale_ = scale; } 159 | 160 | private: 161 | std::vector data_; 162 | std::vector shape_; 163 | float scale_ = 1.0f; 164 | }; 165 | 166 | /** @brief Specialized tensor types */ 167 | using QTensor = Tensor; 168 | using FTensor = Tensor; 169 | 170 | } // namespace qnn -------------------------------------------------------------------------------- /02_inference/tutorials/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Add executable 2 | add_executable(inference demo.cc) 3 | 4 | # Add include directories 5 | target_include_directories(inference 6 | PRIVATE 7 | ${PROJECT_SOURCE_DIR}/include 8 | ${stb_SOURCE_DIR} 9 | ) 10 | 11 | # Link libraries 12 | target_link_libraries(inference 13 | PRIVATE 14 | libqnn 15 | fmt::fmt 16 | ) -------------------------------------------------------------------------------- /02_inference/tutorials/demo.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @file demo.cc 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief A demo for the INT8 inference of LeNet-5 6 | * @version 1.0.0 7 | * @date 2020-01-18 8 | */ 9 | 10 | #define STB_IMAGE_IMPLEMENTATION 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | #include "model.hpp" 17 | #include "tensor.hpp" 18 | 19 | /** 20 | * @brief Loads an image and converts it to a tensor 21 | * 22 | * @tparam T Data type of the output tensor (e.g., float, int8_t) 23 | * @param img_path Path to the image file 24 | * @param normalize Whether to normalize pixel values to [0,1] 25 | * @return Tensor containing the image data in NCHW format [1, C, H, W] 26 | * @throws std::runtime_error If image loading fails 27 | */ 28 | template 29 | qnn::Tensor load_image(const std::string& img_path) { 30 | int height, width, channels; 31 | stbi_uc* img = stbi_load(img_path.c_str(), &width, &height, &channels, 1); 32 | 33 | if (!img) { 34 | throw std::runtime_error("Failed to load image: " + img_path); 35 | } 36 | 37 | // Create tensor with shape [1, C, H, W] 38 | qnn::Tensor tensor({1, 1, height, width}); 39 | T* data = tensor.data(); 40 | 41 | // Convert image data to tensor 42 | size_t idx = 0; 43 | for (int h = 0; h < height; ++h) { 44 | for (int w = 0; w < width; ++w) { 45 | T pixel = static_cast(img[h * width + w]); 46 | data[idx++] = pixel; 47 | } 48 | } 49 | 50 | stbi_image_free(img); 51 | return tensor; 52 | } 53 | 54 | int main(int argc, char* argv[]) { 55 | if (argc != 3) { 56 | spdlog::error("Usage: {} ", argv[0]); 57 | return 1; 58 | } 59 | 60 | try { 61 | // Load model 62 | auto model = qnn::Model::loadModel(argv[1]); 63 | 64 | spdlog::info("Model loaded: {}", argv[1]); 65 | 66 | // Load image and create input tensor 67 | std::string image_path = argv[2]; 68 | auto input = load_image(image_path); 69 | spdlog::info("Image loaded: {}", image_path); 70 | 71 | // Normalize input tensor / 255.0 72 | 
std::transform(input.data(), input.data() + input.size(), input.data(), 73 | [](float x) { return x / 255.0f; }); 74 | 75 | // Forward pass 76 | auto output = model.forward(input); 77 | 78 | // Output the argmax prediction 79 | std::visit( 80 | [](const auto& tensor) { 81 | size_t max_index = std::distance( 82 | tensor.data(), 83 | std::max_element(tensor.data(), tensor.data() + tensor.size())); 84 | spdlog::info("Prediction: {}", max_index); 85 | }, 86 | output); 87 | } catch (const std::exception& e) { 88 | spdlog::error("Error: {}", e.what()); 89 | return 1; 90 | } 91 | 92 | return 0; 93 | } -------------------------------------------------------------------------------- /03_hardware/.gitignore: -------------------------------------------------------------------------------- 1 | ######################################################################################################### 2 | ## This is an example .gitignore file for Vivado, please treat it as an example as 3 | ## it might not be complete. In addition, XAPP 1165 should be followed. 4 | ######################################################################################################### 5 | ######### 6 | #Exclude all 7 | ######### 8 | * 9 | !*/ 10 | !.gitignore 11 | !*.md 12 | ########################################################################### 13 | ## VIVADO 14 | ########################################################################### 15 | ######### 16 | #Source files: 17 | ######### 18 | #Do NOT ignore VHDL, Verilog, SystemVerilog, block diagrams or EDIF files. 19 | !*.vhd 20 | !*.v 21 | !*.sv 22 | !*.bd 23 | !*.edif 24 | ######### 25 | #IP files 26 | ######### 27 | #.xci: synthesis and implemented not possible - you need to return back to the previous version to generate output products 28 | #.xci + .dcp: implementation possible but not re-synthesis 29 | #*.xci(www.spiritconsortium.org) 30 | !*.xci 31 | #*.dcp(checkpoint files) 32 | !*.dcp 33 | !*.vds 34 | !*.pb 35 | #All bd comments and layout coordinates are stored within .ui 36 | !*.ui 37 | !*.ooc 38 | ######### 39 | #System Generator 40 | ######### 41 | !*.mdl 42 | !*.slx 43 | !*.bxml 44 | ######### 45 | #Simulation logic analyzer 46 | ######### 47 | !*.wcfg 48 | !*.coe 49 | ######### 50 | #MIG 51 | ######### 52 | !*.prj 53 | !*.mem 54 | ######### 55 | #Project files 56 | ######### 57 | #XPR + *.XML ? XPR (Files are merged into a single XPR file for 2014.1 version) 58 | #Do NOT ignore *.xpr files 59 | !*.xpr 60 | #Include *.xml files for 2013.4 or earlier version 61 | !*.xml 62 | ######### 63 | #Constraint files 64 | ######### 65 | #Do NOT ignore *.xdc files 66 | !*.xdc 67 | ######### 68 | #TCL - files 69 | ######### 70 | !*.tcl 71 | ######### 72 | #Journal - files 73 | ######### 74 | !*.jou 75 | ######### 76 | #Reports 77 | ######### 78 | !*.rpt 79 | !*.txt 80 | !*.vdi 81 | ######### 82 | #C-files 83 | ######### 84 | !*.c 85 | !*.h 86 | !*.elf 87 | !*.bmm 88 | !*.xmp 89 | -------------------------------------------------------------------------------- /03_hardware/README.md: -------------------------------------------------------------------------------- 1 | # DNN Accelerator Implementation 2 | 3 | This project presents the Verilog implementation of a Deep Neural Network (DNN) accelerator. 4 | 5 | ## Architecture 6 | 7 |
8 | ![Overview](../imgs/overview.png) 9 |
10 | 11 | ## Project Structure 12 | - **axi_dma_ctrl** 13 | 14 | - **axi_ir** 15 | 16 | - **img2col** 17 | 18 | - **systolic_array** 19 | -------------------------------------------------------------------------------- /03_hardware/axi_dma_ctrl/README.md: -------------------------------------------------------------------------------- 1 | # AXI DMA Controller Implementation 2 | 3 | This project develops AXI DMA controllers to facilitate data transfers between the CPU/PS and FPGA/PL using DDR memory. 4 | 5 | It follows the programming guide for AXI DMA as outlined in [PG021 AXI DMA](https://www.xilinx.com/support/documentation/ip_documentation/axi_dma/v7_1/pg021_axi_dma.pdf) 6 | 7 | ## Design Overview 8 | The design uses state machines to control AXI read/write operations following the official programming sequence: 9 | 10 | ### DMA Read Operation: 11 | - Verify DMA idle status 12 | - Configure interrupt settings (completion interrupt enabled) 13 | - Set source address (MM2S_SA register) 14 | - Set transfer length (MM2S_LENGTH register) 15 | 16 | ### DMA Write Operation: 17 | - Similar sequence with different register addresses 18 | - Follows S2MM channel programming sequence 19 | 20 | ### Implementation Details 21 | - State Machine: Controls AXI operation states 22 | - Counter: Manages AXI write addresses and content 23 | - Address Alignment: Supports 4-byte alignment when DRE disabled 24 | 25 | ## Simulation Environment 26 | 27 | ### Test Setup 28 | The simulation environment includes: 29 | - AXI Smart Connect 30 | - AXI DMA 31 | - AXI4-Stream Data FIFO 32 | - AXI BRAM Controller 33 | - Block RAM Generator (BRAM Controller mode) 34 | 35 | ### Block Diagrams 36 | #### DMA Read Configuration 37 | ![DMA Read Design](../../imgs/dma_read_block_design.png) 38 | 39 | #### DMA Write Configuration 40 | ![DMA Write Design](../../imgs/dma_write_block_design.png) 41 | 42 | ### Simulation Results 43 | #### DMA Read Operation 44 | ![DMA Read Simulation](../../imgs/dma_read_simulation.png) 45 | 46 | #### DMA Write Operation 47 | ![DMA Write Simulation](../../imgs/dma_write_simulation.png) 48 | -------------------------------------------------------------------------------- /03_hardware/axi_dma_ctrl/axi_dma_read_ctrl.sv: -------------------------------------------------------------------------------- 1 | `timescale 1 ns / 1 ps 2 | 3 | module axi_dma_read_ctrl # 4 | ( 5 | parameter C_M_START_DATA_VALUE = 32'h0000_0000, 6 | parameter C_M_TARGET_SLAVE_BASE_ADDR = 32'h0000_0000, 7 | parameter integer C_M_AXI_ADDR_WIDTH = 32, 8 | parameter integer C_M_AXI_DATA_WIDTH = 32, 9 | parameter integer C_M_TRANSACTIONS_NUM = 4 10 | ) 11 | ( 12 | // User Define 13 | input wire [31:0] DMA_SA_CONFIG, 14 | input wire [25:0] DMA_LENGTH_CONFIG, 15 | input wire DMA_READ_IRQ, 16 | input wire DMA_READ_VALID, 17 | output wire DMA_IDLE, 18 | // AXI Lite Defination 19 | input wire M_AXI_ACLK, 20 | input wire M_AXI_ARESETN, 21 | output wire [C_M_AXI_ADDR_WIDTH-1 : 0] M_AXI_AWADDR, 22 | output wire [2 : 0] M_AXI_AWPROT, 23 | output wire M_AXI_AWVALID, 24 | input wire M_AXI_AWREADY, 25 | output wire [C_M_AXI_DATA_WIDTH-1 : 0] M_AXI_WDATA, 26 | output wire [C_M_AXI_DATA_WIDTH/8-1 : 0] M_AXI_WSTRB, 27 | output wire M_AXI_WVALID, 28 | input wire M_AXI_WREADY, 29 | input wire [1 : 0] M_AXI_BRESP, 30 | input wire M_AXI_BVALID, 31 | output wire M_AXI_BREADY, 32 | output wire [C_M_AXI_ADDR_WIDTH-1 : 0] M_AXI_ARADDR, 33 | output wire [2 : 0] M_AXI_ARPROT, 34 | output wire M_AXI_ARVALID, 35 | input wire M_AXI_ARREADY, 36 | input wire 
[C_M_AXI_DATA_WIDTH-1 : 0] M_AXI_RDATA, 37 | input wire [1 : 0] M_AXI_RRESP, 38 | input wire M_AXI_RVALID, 39 | output wire M_AXI_RREADY 40 | ); 41 | // AXI4LITE signals 42 | //write address valid 43 | reg axi_awvalid; 44 | //write data valid 45 | reg axi_wvalid; 46 | //read address valid 47 | reg axi_arvalid; 48 | //read data acceptance 49 | reg axi_rready; 50 | //write response acceptance 51 | reg axi_bready; 52 | //write address 53 | reg [C_M_AXI_ADDR_WIDTH-1 : 0] axi_awaddr; 54 | //write data 55 | reg [C_M_AXI_DATA_WIDTH-1 : 0] axi_wdata; 56 | //read addresss 57 | reg [C_M_AXI_ADDR_WIDTH-1 : 0] axi_araddr; 58 | //A pulse to initiate a write transaction 59 | reg start_single_write; 60 | //A pulse to initiate a read transaction 61 | reg start_single_read; 62 | // User Logic 63 | // State definitions 64 | typedef enum logic [3:0] { 65 | ST_IDLE = 4'b0000, // Idle state waiting for DMA_READ_VALID 66 | ST_CHECK_IDLE = 4'b0001, // Check if DMA is idle via AXI read 67 | ST_CONTROL_DMA = 4'b0010, // Configure DMA registers 68 | ST_WAITING_IRQ = 4'b0100, // Wait for DMA completion interrupt 69 | ST_CLEAR_IRQ = 4'b1000 // Clear the DMA completion interrupt 70 | } dma_state_t; 71 | 72 | // Write state definitions 73 | typedef enum logic [3:0] { 74 | WR_IRQ = 4'b0000, // Write IRQ configuration 75 | WR_SA = 4'b0001, // Write source address 76 | WR_LENGTH = 4'b0010, // Write transfer length 77 | WR_COMPLETE = 4'b0011 // Write sequence complete 78 | } write_state_t; 79 | 80 | // Register addresses 81 | localparam IDLE_REG = 32'h0000_0004; // Idle status register (read bit 0) 82 | localparam IRQ_REG = 32'h0000_0000; // IRQ control register 83 | localparam SA_REG = 32'h0000_0018; // Source address register 84 | localparam LENGTH_REG = 32'h0000_0028; // Transfer length register 85 | localparam CLEAR_REG = 32'h0000_0004; // IRQ clear register 86 | 87 | // Register values 88 | localparam IRQ_DATA = 32'h00011003; // IRQ enable configuration 89 | localparam CLEAR_DATA = 32'h00011001; // IRQ clear value 90 | 91 | // Internal registers 92 | dma_state_t state; 93 | write_state_t write_state; 94 | logic read_complete; 95 | logic dma_idle; 96 | 97 | // Configuration registers with enable control 98 | logic [31:0] dma_sa_config; 99 | logic [25:0] dma_length_config; 100 | assign dma_sa_config = DMA_READ_VALID ? DMA_SA_CONFIG : dma_sa_config; 101 | assign dma_length_config = DMA_READ_VALID ? DMA_LENGTH_CONFIG : dma_length_config; 102 | 103 | //Adding the offset address to the base addr of the slave 104 | assign M_AXI_AWADDR = C_M_TARGET_SLAVE_BASE_ADDR + axi_awaddr; 105 | //AXI 4 write data 106 | assign M_AXI_WDATA = axi_wdata; 107 | assign M_AXI_AWPROT = 3'b000; 108 | assign M_AXI_AWVALID = axi_awvalid; 109 | //Write Data(W) 110 | assign M_AXI_WVALID = axi_wvalid; 111 | //Set all byte strobes in this example 112 | assign M_AXI_WSTRB = 4'b1111; 113 | //Write Response (B) 114 | assign M_AXI_BREADY = axi_bready; 115 | //Read Address (AR) 116 | assign M_AXI_ARADDR = C_M_TARGET_SLAVE_BASE_ADDR + axi_araddr; 117 | assign M_AXI_ARVALID = axi_arvalid; 118 | assign M_AXI_ARPROT = 3'b001; 119 | //Read and Read Response (R) 120 | assign M_AXI_RREADY = axi_rready; 121 | // User Logic 122 | wire [31:0] dma_sa_config ; 123 | wire [25:0] dma_length_config ; 124 | assign dma_sa_config = DMA_READ_VALID == 1 ? DMA_SA_CONFIG : dma_sa_config; 125 | assign dma_length_config = DMA_READ_VALID == 1 ? 
DMA_LENGTH_CONFIG : dma_length_config; 126 | //Write Address Channel 127 | always @(posedge M_AXI_ACLK) begin 128 | if (M_AXI_ARESETN == 0) begin 129 | axi_awvalid <= 1'b0; 130 | end 131 | else begin 132 | if (start_single_write) begin 133 | axi_awvalid <= 1'b1; 134 | end 135 | else if (M_AXI_AWREADY && axi_awvalid) begin 136 | axi_awvalid <= 1'b0; 137 | end 138 | end 139 | end 140 | //Write Data Channel 141 | always @(posedge M_AXI_ACLK) begin 142 | if (M_AXI_ARESETN == 0 ) begin 143 | axi_wvalid <= 1'b0; 144 | end 145 | else if (start_single_write) begin 146 | axi_wvalid <= 1'b1; 147 | end 148 | else if (M_AXI_WREADY && axi_wvalid) begin 149 | axi_wvalid <= 1'b0; 150 | end 151 | end 152 | //Write Response (B) Channel 153 | always @(posedge M_AXI_ACLK) begin 154 | if (M_AXI_ARESETN == 0) begin 155 | axi_bready <= 1'b0; 156 | end 157 | else if (M_AXI_BVALID && ~axi_bready) begin 158 | axi_bready <= 1'b1; 159 | end 160 | else if (axi_bready) begin 161 | axi_bready <= 1'b0; 162 | end 163 | else begin 164 | axi_bready <= axi_bready; 165 | end 166 | end 167 | // Write Response (B) Valid 168 | always @(posedge M_AXI_ACLK) begin 169 | if (M_AXI_ARESETN == 0) begin 170 | axi_arvalid <= 1'b0; 171 | end 172 | else if (start_single_read) begin 173 | axi_arvalid <= 1'b1; 174 | end 175 | else if (M_AXI_ARREADY && axi_arvalid) begin 176 | axi_arvalid <= 1'b0; 177 | end 178 | end 179 | //Read Data (and Response) Channel 180 | always @(posedge M_AXI_ACLK) begin 181 | if (M_AXI_ARESETN == 0 ) begin 182 | axi_rready <= 1'b0; 183 | end 184 | else if (M_AXI_RVALID && ~axi_rready) begin 185 | axi_rready <= 1'b1; 186 | end 187 | else if (axi_rready) begin 188 | axi_rready <= 1'b0; 189 | end 190 | end 191 | //-------------------------------- 192 | //User Logic 193 | //-------------------------------- 194 | // DMA Read_Example By CNILeo 195 | // 顶层状态机 196 | always @(posedge M_AXI_ACLK) begin 197 | if(M_AXI_ARESETN == 0) begin 198 | state <= ST_IDLE; 199 | end 200 | else begin 201 | case(state) 202 | ST_IDLE: begin 203 | if(DMA_READ_VALID) begin 204 | state <= ST_CHECK_IDLE; 205 | end 206 | end 207 | ST_CHECK_IDLE: begin 208 | if(M_AXI_RDATA[0] == 1 || M_AXI_RDATA[1] == 1) begin 209 | state <= ST_CONTROL_DMA; 210 | end 211 | end 212 | ST_CONTROL_DMA: begin 213 | if(write_state == WR_COMPLETE) begin 214 | state <= ST_WAITING_IRQ; 215 | end 216 | end 217 | ST_WAITING_IRQ: begin 218 | if(DMA_READ_IRQ) begin 219 | state <= ST_CLEAR_IRQ; 220 | end 221 | end 222 | ST_CLEAR_IRQ: begin 223 | if(~DMA_READ_IRQ) begin 224 | state <= ST_IDLE; 225 | end 226 | end 227 | default: 228 | state <= ST_IDLE; 229 | endcase 230 | end 231 | end 232 | //Control DMA_IDLE 233 | always @(posedge M_AXI_ACLK) begin 234 | if (M_AXI_ARESETN == 0) begin 235 | dma_idle <= 1'b0; 236 | end 237 | else if(state == ST_IDLE) begin 238 | dma_idle <= 1'b1; 239 | end 240 | else begin 241 | dma_idle <= 1'b0; 242 | end 243 | end 244 | 245 | assign DMA_IDLE = dma_idle & ~DMA_READ_VALID; 246 | //Read Addresses 247 | always @(posedge M_AXI_ACLK) begin 248 | if (M_AXI_ARESETN == 0) begin 249 | axi_araddr <= 0; 250 | end 251 | else if (start_single_read) begin 252 | axi_araddr <= IDLE_REG; 253 | end 254 | else begin 255 | axi_araddr <= 0; 256 | end 257 | end 258 | //Read Complete 259 | always @(posedge M_AXI_ACLK) begin 260 | if(M_AXI_ARESETN == 0) begin 261 | read_complete <= 1'b1; 262 | end 263 | else begin 264 | if(axi_rready == 1'b1) 265 | read_complete <= 1'b1; 266 | else if(start_single_read == 1'b1) 267 | read_complete <= 1'b0; 268 | else 269 | 
read_complete <= read_complete; 270 | end 271 | end 272 | //ENA Check_IDLE 273 | always @(posedge M_AXI_ACLK)begin 274 | if(M_AXI_ARESETN == 0) begin 275 | start_single_read <= 1'b0; 276 | end 277 | else begin 278 | if(state == ST_CHECK_IDLE) begin 279 | if(~axi_arvalid && ~M_AXI_RVALID && ~start_single_read && read_complete) begin 280 | start_single_read <= 1'b1; 281 | end 282 | else begin 283 | start_single_read <= 1'b0; //Negate to generate a pulse 284 | end 285 | end 286 | else begin 287 | start_single_read <= 1'b0; 288 | end 289 | end 290 | end 291 | // 计数器 - 控制写地址 292 | // 4'b0001 - IRQ_Set - 32'h0000_0000 - 293 | // 4'b0010 - Source Address Set - 32'h0000_0018 294 | // 4'b0011 - Length_Set - 32'h0000_0028 295 | // 4'b0100 - Clear_IRQ - 32'h0000_0004 296 | always @(posedge M_AXI_ACLK)begin 297 | if(M_AXI_ARESETN == 0) begin 298 | write_state <= WR_IRQ; 299 | end 300 | else begin 301 | if (state == ST_CONTROL_DMA) begin 302 | if(M_AXI_AWREADY && axi_awvalid) begin 303 | write_state <= write_state + 1'b1; 304 | end 305 | else begin 306 | write_state <= write_state; 307 | end 308 | end 309 | else begin 310 | write_state <= WR_IRQ; 311 | end 312 | end 313 | end 314 | // Write Addresses 315 | always @(posedge M_AXI_ACLK) begin 316 | if (M_AXI_ARESETN == 0) begin 317 | axi_awaddr <= 32'h0000_0000; 318 | end 319 | else if(state == ST_CONTROL_DMA) begin 320 | case(write_state) 321 | WR_IRQ: axi_awaddr <= IRQ_REG; 322 | WR_SA: axi_awaddr <= SA_REG; 323 | WR_LENGTH: axi_awaddr <= LENGTH_REG; 324 | default: axi_awaddr <= 32'h0000_0000; 325 | endcase 326 | end 327 | else if(state == ST_CLEAR_IRQ) begin 328 | axi_awaddr <= CLEAR_REG; 329 | end 330 | else begin 331 | axi_awaddr <= 0; 332 | end 333 | end 334 | //start single write 335 | always @(posedge M_AXI_ACLK) begin 336 | if(M_AXI_ARESETN == 0) begin 337 | start_single_write <= 0; 338 | end 339 | else begin 340 | if (state == ST_CONTROL_DMA || state == ST_CLEAR_IRQ) begin 341 | if (~axi_awvalid && ~axi_wvalid && ~M_AXI_BVALID && ~start_single_write) begin 342 | start_single_write <= 1'b1; 343 | end 344 | else begin 345 | start_single_write <= 0; 346 | end 347 | end 348 | else begin 349 | start_single_write <= 0; 350 | end 351 | end 352 | end 353 | // Write Data Control 354 | always @(posedge M_AXI_ACLK) begin 355 | if (M_AXI_ARESETN == 0) begin 356 | axi_wdata <= C_M_START_DATA_VALUE; 357 | end 358 | else if (state == ST_CONTROL_DMA) begin 359 | case(write_state) 360 | WR_IRQ: axi_wdata <= C_M_START_DATA_VALUE + IRQ_DATA; 361 | WR_SA: axi_wdata <= C_M_START_DATA_VALUE + dma_sa_config; 362 | WR_LENGTH: axi_wdata <= C_M_START_DATA_VALUE + dma_length_config; 363 | default: axi_wdata <= C_M_START_DATA_VALUE; 364 | endcase 365 | end 366 | else if(state == ST_CLEAR_IRQ)begin 367 | axi_wdata <= CLEAR_DATA; 368 | end 369 | else begin 370 | axi_wdata <= C_M_START_DATA_VALUE; 371 | end 372 | end 373 | endmodule -------------------------------------------------------------------------------- /03_hardware/axi_dma_ctrl/axi_dma_write_ctrl.sv: -------------------------------------------------------------------------------- 1 | `timescale 1 ns / 1 ps 2 | 3 | module axi_dma_write_ctrl # 4 | ( 5 | parameter C_M_START_DATA_VALUE = 32'h0000_0000, 6 | parameter C_M_TARGET_SLAVE_BASE_ADDR = 32'h0000_0000, 7 | parameter integer C_M_AXI_ADDR_WIDTH = 32, 8 | parameter integer C_M_AXI_DATA_WIDTH = 32, 9 | parameter integer C_M_TRANSACTIONS_NUM = 4 10 | ) 11 | ( 12 | // User Define 13 | input wire [31:0] DMA_DA_CONFIG, 14 | input wire [25:0] DMA_LENGTH_CONFIG, 15 | 
input wire DMA_WRITE_IRQ, 16 | input wire DMA_WRITE_VALID, 17 | output wire DMA_WRITE_IDLE, 18 | // AXI Lite Defination 19 | input wire M_AXI_ACLK, 20 | input wire M_AXI_ARESETN, 21 | output wire [C_M_AXI_ADDR_WIDTH-1 : 0] M_AXI_AWADDR, 22 | output wire [2 : 0] M_AXI_AWPROT, 23 | output wire M_AXI_AWVALID, 24 | input wire M_AXI_AWREADY, 25 | output wire [C_M_AXI_DATA_WIDTH-1 : 0] M_AXI_WDATA, 26 | output wire [C_M_AXI_DATA_WIDTH/8-1 : 0] M_AXI_WSTRB, 27 | output wire M_AXI_WVALID, 28 | input wire M_AXI_WREADY, 29 | input wire [1 : 0] M_AXI_BRESP, 30 | input wire M_AXI_BVALID, 31 | output wire M_AXI_BREADY, 32 | output wire [C_M_AXI_ADDR_WIDTH-1 : 0] M_AXI_ARADDR, 33 | output wire [2 : 0] M_AXI_ARPROT, 34 | output wire M_AXI_ARVALID, 35 | input wire M_AXI_ARREADY, 36 | input wire [C_M_AXI_DATA_WIDTH-1 : 0] M_AXI_RDATA, 37 | input wire [1 : 0] M_AXI_RRESP, 38 | input wire M_AXI_RVALID, 39 | output wire M_AXI_RREADY 40 | ); 41 | // AXI4LITE signals 42 | //write address valid 43 | reg axi_awvalid; 44 | //write data valid 45 | reg axi_wvalid; 46 | //read address valid 47 | reg axi_arvalid; 48 | //read data acceptance 49 | reg axi_rready; 50 | //write response acceptance 51 | reg axi_bready; 52 | //write address 53 | reg [C_M_AXI_ADDR_WIDTH-1 : 0] axi_awaddr; 54 | //write data 55 | reg [C_M_AXI_DATA_WIDTH-1 : 0] axi_wdata; 56 | //read addresss 57 | reg [C_M_AXI_ADDR_WIDTH-1 : 0] axi_araddr; 58 | //A pulse to initiate a write transaction 59 | reg start_single_write; 60 | //A pulse to initiate a read transaction 61 | reg start_single_read; 62 | // User Logic 63 | // Internal signals 64 | logic [3:0] state; // Main FSM state 65 | logic [3:0] write_state; // Write sequence state 66 | logic read_complete; // Read transaction complete flag 67 | logic dma_write_idle; // DMA idle status 68 | 69 | // State machine parameters 70 | typedef enum logic [3:0] { 71 | ST_IDLE = 4'b0000, // Wait for DMA write valid 72 | ST_CHECK_IDLE = 4'b0001, // Check if DMA is idle 73 | ST_CTRL_DMA = 4'b0010, // Configure DMA registers 74 | ST_WAIT_IRQ = 4'b0100, // Wait for interrupt 75 | ST_CLEAR_IRQ = 4'b1000 // Clear interrupt flag 76 | } state_t; 77 | 78 | // Write sequence states 79 | typedef enum logic [3:0] { 80 | WR_IRQ = 4'b0000, // Write IRQ configuration 81 | WR_DA = 4'b0001, // Write destination address 82 | WR_LENGTH = 4'b0010, // Write transfer length 83 | WR_COMPLETE = 4'b0011 // Write sequence complete 84 | } write_state_t; 85 | 86 | // Register addresses 87 | localparam ADDR_IDLE = 32'h0000_0034; // Idle status register 88 | localparam ADDR_IRQ = 32'h0000_0030; // IRQ control register 89 | localparam ADDR_DA = 32'h0000_0048; // Destination address register 90 | localparam ADDR_LENGTH = 32'h0000_0058; // Transfer length register 91 | localparam ADDR_CLEAR = 32'h0000_0034; // IRQ clear register 92 | 93 | // Register values 94 | localparam VAL_IRQ = 32'h00011003; // IRQ enable configuration 95 | localparam VAL_CLEAR = 32'h00011001; // IRQ clear value 96 | 97 | // Configuration signals 98 | logic [31:0] dma_da_config; // Destination address config 99 | logic [25:0] dma_length_config; // Transfer length config 100 | 101 | // Register updates on valid 102 | assign dma_da_config = DMA_WRITE_VALID ? DMA_DA_CONFIG : dma_da_config; 103 | assign dma_length_config = DMA_WRITE_VALID ? 
DMA_LENGTH_CONFIG : dma_length_config; 104 | 105 | //Adding the offset address to the base addr of the slave 106 | assign M_AXI_AWADDR = C_M_TARGET_SLAVE_BASE_ADDR + axi_awaddr; 107 | //AXI 4 write data 108 | assign M_AXI_WDATA = axi_wdata; 109 | assign M_AXI_AWPROT = 3'b000; 110 | assign M_AXI_AWVALID = axi_awvalid; 111 | //Write Data(W) 112 | assign M_AXI_WVALID = axi_wvalid; 113 | //Set all byte strobes in this example 114 | assign M_AXI_WSTRB = 4'b1111; 115 | //Write Response (B) 116 | assign M_AXI_BREADY = axi_bready; 117 | //Read Address (AR) 118 | assign M_AXI_ARADDR = C_M_TARGET_SLAVE_BASE_ADDR + axi_araddr; 119 | assign M_AXI_ARVALID = axi_arvalid; 120 | assign M_AXI_ARPROT = 3'b001; 121 | //Read and Read Response (R) 122 | assign M_AXI_RREADY = axi_rready; 123 | // User Logic 124 | wire [31:0] dma_da_config ; 125 | wire [25:0] dma_length_config ; 126 | assign dma_da_config = DMA_WRITE_VALID == 1 ? DMA_DA_CONFIG : dma_da_config; 127 | assign dma_length_config = DMA_WRITE_VALID == 1 ? DMA_LENGTH_CONFIG : dma_length_config; 128 | //Write Address Channel 129 | always @(posedge M_AXI_ACLK) begin 130 | if (M_AXI_ARESETN == 0) begin 131 | axi_awvalid <= 1'b0; 132 | end 133 | else begin 134 | if (start_single_write) begin 135 | axi_awvalid <= 1'b1; 136 | end 137 | else if (M_AXI_AWREADY && axi_awvalid) begin 138 | axi_awvalid <= 1'b0; 139 | end 140 | end 141 | end 142 | //Write Data Channel 143 | always @(posedge M_AXI_ACLK) begin 144 | if (M_AXI_ARESETN == 0 ) begin 145 | axi_wvalid <= 1'b0; 146 | end 147 | else if (start_single_write) begin 148 | axi_wvalid <= 1'b1; 149 | end 150 | else if (M_AXI_WREADY && axi_wvalid) begin 151 | axi_wvalid <= 1'b0; 152 | end 153 | end 154 | //Write Response (B) Channel 155 | always @(posedge M_AXI_ACLK) begin 156 | if (M_AXI_ARESETN == 0) begin 157 | axi_bready <= 1'b0; 158 | end 159 | else if (M_AXI_BVALID && ~axi_bready) begin 160 | axi_bready <= 1'b1; 161 | end 162 | else if (axi_bready) begin 163 | axi_bready <= 1'b0; 164 | end 165 | else begin 166 | axi_bready <= axi_bready; 167 | end 168 | end 169 | // Write Response (B) Valid 170 | always @(posedge M_AXI_ACLK) begin 171 | if (M_AXI_ARESETN == 0) begin 172 | axi_arvalid <= 1'b0; 173 | end 174 | else if (start_single_read) begin 175 | axi_arvalid <= 1'b1; 176 | end 177 | else if (M_AXI_ARREADY && axi_arvalid) begin 178 | axi_arvalid <= 1'b0; 179 | end 180 | end 181 | //Read Data (and Response) Channel 182 | always @(posedge M_AXI_ACLK) begin 183 | if (M_AXI_ARESETN == 0 ) begin 184 | axi_rready <= 1'b0; 185 | end 186 | else if (M_AXI_RVALID && ~axi_rready) begin 187 | axi_rready <= 1'b1; 188 | end 189 | else if (axi_rready) begin 190 | axi_rready <= 1'b0; 191 | end 192 | end 193 | //-------------------------------- 194 | //User Logic 195 | //-------------------------------- 196 | // DMA Read_Example By CNILeo 197 | // 顶层状态机 198 | always @(posedge M_AXI_ACLK) begin 199 | if(M_AXI_ARESETN == 0) begin 200 | state <= ST_IDLE; 201 | end 202 | else begin 203 | case(state) 204 | ST_IDLE: begin 205 | if(DMA_WRITE_VALID) begin 206 | state <= ST_CHECK_IDLE; 207 | end 208 | end 209 | ST_CHECK_IDLE: begin 210 | if(M_AXI_RDATA[0] == 1 || M_AXI_RDATA[1] == 1) begin 211 | state <= ST_CTRL_DMA; 212 | end 213 | end 214 | ST_CTRL_DMA: begin 215 | if(write_state == WR_COMPLETE) begin 216 | state <= ST_WAIT_IRQ; 217 | end 218 | end 219 | ST_WAIT_IRQ: begin 220 | if(DMA_WRITE_IRQ) begin 221 | state <= ST_CLEAR_IRQ; 222 | end 223 | end 224 | ST_CLEAR_IRQ: begin 225 | if(~DMA_WRITE_IRQ) begin 226 | state <= ST_IDLE; 
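// IRQ has deasserted after the clear write, so the S2MM transfer is complete and the FSM returns to idle.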
227 | end 228 | end 229 | default: 230 | state <= ST_IDLE; 231 | endcase 232 | end 233 | end 234 | //Control DMA_IDLE 235 | always @(posedge M_AXI_ACLK) begin 236 | if (M_AXI_ARESETN == 0) begin 237 | dma_write_idle <= 1'b0; 238 | end 239 | else if(state == ST_IDLE) begin 240 | dma_write_idle <= 1'b1; 241 | end 242 | else begin 243 | dma_write_idle <= 1'b0; 244 | end 245 | end 246 | 247 | assign DMA_WRITE_IDLE = dma_write_idle & ~DMA_WRITE_VALID; 248 | //Read Addresses 249 | always @(posedge M_AXI_ACLK) begin 250 | if (M_AXI_ARESETN == 0) begin 251 | axi_araddr <= 0; 252 | end 253 | else if (start_single_read) begin 254 | axi_araddr <= ADDR_IDLE; 255 | end 256 | else begin 257 | axi_araddr <= 0; 258 | end 259 | end 260 | //Read Complete 261 | always @(posedge M_AXI_ACLK) begin 262 | if(M_AXI_ARESETN == 0) begin 263 | read_complete <= 1'b1; 264 | end 265 | else begin 266 | if(axi_rready == 1'b1) 267 | read_complete <= 1'b1; 268 | else if(start_single_read == 1'b1) 269 | read_complete <= 1'b0; 270 | else 271 | read_complete <= read_complete; 272 | end 273 | end 274 | //ENA Check_IDLE 275 | always @(posedge M_AXI_ACLK)begin 276 | if(M_AXI_ARESETN == 0) begin 277 | start_single_read <= 1'b0; 278 | end 279 | else begin 280 | if(state == ST_CHECK_IDLE) begin 281 | if(~axi_arvalid && ~M_AXI_RVALID && ~start_single_read && read_complete) begin 282 | start_single_read <= 1'b1; 283 | end 284 | else begin 285 | start_single_read <= 1'b0; //Negate to generate a pulse 286 | end 287 | end 288 | else begin 289 | start_single_read <= 1'b0; 290 | end 291 | end 292 | end 293 | // 计数器 - 控制写地址 294 | // 4'b0001 - IRQ_Set - 32'h0000_0000 - 295 | // 4'b0010 - Source Address Set - 32'h0000_0018 296 | // 4'b0011 - Length_Set - 32'h0000_0028 297 | // 4'b0100 - Clear_IRQ - 32'h0000_0004 298 | always @(posedge M_AXI_ACLK)begin 299 | if(M_AXI_ARESETN == 0) begin 300 | write_state <= WR_IRQ; 301 | end 302 | else begin 303 | if (state == ST_CTRL_DMA) begin 304 | if(M_AXI_AWREADY && axi_awvalid) begin 305 | write_state <= write_state + 1'b1; 306 | end 307 | else begin 308 | write_state <= write_state; 309 | end 310 | end 311 | else begin 312 | write_state <= WR_IRQ; 313 | end 314 | end 315 | end 316 | // Write Addresses 317 | always @(posedge M_AXI_ACLK) begin 318 | if (M_AXI_ARESETN == 0) begin 319 | axi_awaddr <= 32'h0000_0000; 320 | end 321 | else if(state == ST_CTRL_DMA) begin 322 | case(write_state) 323 | WR_IRQ: axi_awaddr <= ADDR_IRQ; 324 | WR_DA: axi_awaddr <= ADDR_DA; 325 | WR_LENGTH: axi_awaddr <= ADDR_LENGTH; 326 | default: axi_awaddr <= 32'h0000_0000; 327 | endcase 328 | end 329 | else if(state == ST_CLEAR_IRQ) begin 330 | axi_awaddr <= ADDR_CLEAR; 331 | end 332 | else begin 333 | axi_awaddr <= 0; 334 | end 335 | end 336 | //start single write 337 | always @(posedge M_AXI_ACLK) begin 338 | if(M_AXI_ARESETN == 0) begin 339 | start_single_write <= 0; 340 | end 341 | else begin 342 | if (state == ST_CTRL_DMA || state == ST_CLEAR_IRQ) begin 343 | if (~axi_awvalid && ~axi_wvalid && ~M_AXI_BVALID && ~start_single_write) begin 344 | start_single_write <= 1'b1; 345 | end 346 | else begin 347 | start_single_write <= 0; 348 | end 349 | end 350 | else begin 351 | start_single_write <= 0; 352 | end 353 | end 354 | end 355 | // Write Data Control 356 | always @(posedge M_AXI_ACLK) begin 357 | if (M_AXI_ARESETN == 0) begin 358 | axi_wdata <= C_M_START_DATA_VALUE; 359 | end 360 | else if (state == ST_CTRL_DMA) begin 361 | case(write_state) 362 | WR_IRQ: axi_wdata <= C_M_START_DATA_VALUE + VAL_IRQ; 363 | WR_DA: axi_wdata <= 
C_M_START_DATA_VALUE + dma_da_config; 364 | WR_LENGTH: axi_wdata <= C_M_START_DATA_VALUE + dma_length_config; 365 | default: axi_wdata <= C_M_START_DATA_VALUE; 366 | endcase 367 | end 368 | else if(state == ST_CLEAR_IRQ)begin 369 | axi_wdata <= VAL_CLEAR; 370 | end 371 | else begin 372 | axi_wdata <= C_M_START_DATA_VALUE; 373 | end 374 | end 375 | endmodule 376 | -------------------------------------------------------------------------------- /03_hardware/axi_dma_ctrl/tb_axi_dma_read_ctrl.sv: -------------------------------------------------------------------------------- 1 | `timescale 1ns / 1ps 2 | 3 | module axi_dma_read_ctrl_tb(); 4 | 5 | // Clock period parameters 6 | localparam CLK_PERIOD = 20; // 50MHz clock 7 | localparam SA_INCREMENT = 32'h0000_0020; 8 | localparam MAX_SA = 32'h0000_0800; 9 | 10 | // Testbench signals 11 | logic clk; 12 | logic rst_n; // Active low reset 13 | 14 | // DUT interface signals 15 | logic [25:0] dma_length; 16 | logic [31:0] dma_start_addr; 17 | logic dma_read_valid; 18 | logic dma_idle; 19 | logic dma_irq; 20 | 21 | // DUT instantiation 22 | dma_read_ctrl dut ( 23 | .CLK (clk), 24 | .RST (rst_n), 25 | .dma_read_valid(dma_read_valid), 26 | .dma_length_config(dma_length), 27 | .dma_sa_config (dma_start_addr), 28 | .DMA_IRQ (dma_irq), 29 | .dma_idle (dma_idle) 30 | ); 31 | 32 | // Clock generation 33 | always begin 34 | clk = 1'b0; 35 | #(CLK_PERIOD/2); 36 | clk = 1'b1; 37 | #(CLK_PERIOD/2); 38 | end 39 | 40 | // DMA read valid control 41 | always_ff @(posedge clk) begin 42 | if (!rst_n) begin 43 | dma_read_valid <= 1'b0; 44 | end else begin 45 | dma_read_valid <= dma_idle; 46 | end 47 | end 48 | 49 | // Start address increment on each transaction 50 | always_ff @(negedge dma_read_valid) begin 51 | if (!rst_n) begin 52 | dma_start_addr <= '0; 53 | end else begin 54 | dma_start_addr <= dma_start_addr + SA_INCREMENT; 55 | end 56 | end 57 | 58 | // Test stimulus 59 | initial begin 60 | // Initialize signals 61 | rst_n = 1'b1; 62 | dma_length = 26'h0000040; 63 | 64 | // Apply reset 65 | #(CLK_PERIOD); 66 | rst_n = 1'b0; 67 | #(2*CLK_PERIOD); 68 | rst_n = 1'b1; 69 | 70 | // Wait for completion or timeout 71 | wait(dma_start_addr > MAX_SA); 72 | 73 | // Add some delay for final transaction 74 | #(10*CLK_PERIOD); 75 | $finish; 76 | end 77 | 78 | // Optional: Add assertions 79 | // pragma translate_off 80 | assert property (@(posedge clk) disable iff(!rst_n) 81 | dma_read_valid |-> !$isunknown(dma_start_addr)) 82 | else $error("Start address unknown when read_valid is high"); 83 | // pragma translate_on 84 | 85 | endmodule -------------------------------------------------------------------------------- /03_hardware/axi_dma_ctrl/tb_axi_dma_write_ctrl.sv: -------------------------------------------------------------------------------- 1 | `timescale 1ns / 1ps 2 | 3 | module axi_dma_write_ctrl_tb(); 4 | 5 | // Clock period parameters 6 | localparam CLK_PERIOD = 20; // 50MHz clock 7 | localparam DA_INCREMENT = 32'h0000_0020; 8 | localparam MAX_DA = 32'h0000_0800; 9 | 10 | // Testbench signals 11 | logic clk; 12 | logic rst_n; // Active low reset 13 | 14 | // DUT interface signals 15 | logic [25:0] dma_length; 16 | logic [31:0] dma_dest_addr; 17 | logic dma_write_valid; 18 | logic dma_idle; 19 | logic dma_irq; 20 | 21 | // DUT instantiation 22 | axi_dma_write_ctrl dut ( 23 | .CLK (clk), 24 | .RST (rst_n), 25 | .dma_write_valid(dma_write_valid), 26 | .dma_length_config(dma_length), 27 | .dma_da_config (dma_dest_addr), 28 | .DMA_Write_INT (dma_irq), 29 | 
.dma_write_idle(dma_idle) 30 | ); 31 | 32 | // Clock generation 33 | always begin 34 | clk = 1'b0; 35 | #(CLK_PERIOD/2); 36 | clk = 1'b1; 37 | #(CLK_PERIOD/2); 38 | end 39 | 40 | // DMA write valid control 41 | always_ff @(posedge clk) begin 42 | if (!rst_n) begin 43 | dma_write_valid <= 1'b0; 44 | end else begin 45 | dma_write_valid <= dma_idle; 46 | end 47 | end 48 | 49 | // Destination address increment on each transaction 50 | always_ff @(negedge dma_write_valid) begin 51 | if (!rst_n) begin 52 | dma_dest_addr <= '0; 53 | end else begin 54 | dma_dest_addr <= dma_dest_addr + DA_INCREMENT; 55 | end 56 | end 57 | 58 | // Test stimulus 59 | initial begin 60 | // Initialize signals 61 | rst_n = 1'b1; 62 | dma_length = 26'h0000040; 63 | 64 | // Apply reset 65 | #(CLK_PERIOD); 66 | rst_n = 1'b0; 67 | #(2*CLK_PERIOD); 68 | rst_n = 1'b1; 69 | 70 | // Wait for completion or timeout 71 | wait(dma_dest_addr > MAX_DA); 72 | 73 | // Add some delay for final transaction 74 | #(10*CLK_PERIOD); 75 | $finish; 76 | end 77 | 78 | // Optional: Add assertions 79 | // pragma translate_off 80 | assert property (@(posedge clk) disable iff(!rst_n) 81 | dma_write_valid |-> !$isunknown(dma_dest_addr)) 82 | else $error("Destination address unknown when write_valid is high"); 83 | // pragma translate_on 84 | 85 | endmodule -------------------------------------------------------------------------------- /03_hardware/axi_ir/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This code utilizes [AirHDL](https://airhdl.com/) to automatically generate code for configuring neural network accelerators via the AXI protocol, specifically through Instruction Registers. 4 | 5 | # Register Map 6 | 7 | ## Load Store Unit (LSU) 8 | 9 | ### Load from Memory 10 | | Register Name | Offset | Description | 11 | |--------------------------|--------|--------------------------------------------| 12 | | LSU_LD_SRC_ADDR_LOW | 0 | Lower 32 bits of source memory address | 13 | | LSU_LD_SRC_ADDR_HIGH | 1 | Higher 32 bits of source memory address | 14 | | LSU_LD_DST_ID | 2 | Destination ID for loading data | 15 | | LSU_LD_LENGTH | 3 | Length of data to be loaded | 16 | | LSU_LD_CONTROL | 4 | Control register for load operation | 17 | | LSU_LD_STATUS | 5 | Status register for load operation | 18 | 19 | ### Store to Memory 20 | | Register Name | Offset | Description | 21 | |--------------------------|--------|--------------------------------------------| 22 | | LSU_ST_SRC_ID | 0 | Source ID for storing data | 23 | | LSU_ST_SRC_ADDR | 1 | Source address for store operation | 24 | | LSU_ST_DST_ADDR_LOW | 2 | Lower 32 bits of destination memory address| 25 | | LSU_ST_DST_ADDR_HIGH | 3 | Higher 32 bits of destination memory address | 26 | | LSU_ST_LENGTH | 4 | Length of data to be stored | 27 | | LSU_ST_CONTROL | 5 | Control register for store operation | 28 | | LSU_ST_STATUS | 6 | Status register for store operation | 29 | 30 | ### Load from Cache 31 | | Register Name | Offset | Description | 32 | |--------------------------|--------|--------------------------------------------| 33 | | LSU_LD_CACHE_SRC_ID | 0 | Source ID for cache load | 34 | | LSU_LD_CACHE_SRC_ADDR | 1 | Source address in cache | 35 | | LSU_LD_CACHE_DST_ID | 2 | Destination ID for cache load | 36 | | LSU_LD_CACHE_DST_ADDR_LOW| 3 | Lower 32 bits of destination address | 37 | | LSU_LD_CACHE_DST_ADDR_HIGH | 4 | Higher 32 bits of destination address | 38 | | LSU_LD_CACHE_LENGTH | 5 | Length of data to be loaded from 
cache | 39 | | LSU_LD_CACHE_CONTROL | 6 | Control register for cache load operation | 40 | | LSU_LD_CACHE_STATUS | 7 | Status register for cache load operation | 41 | 42 | ### Store to Cache 43 | | Register Name | Offset | Description | 44 | |--------------------------|--------|--------------------------------------------| 45 | | LSU_ST_CACHE_SRC_ADDR_LOW| 0 | Lower 32 bits of source address | 46 | | LSU_ST_CACHE_SRC_ADDR_HIGH | 1 | Higher 32 bits of source address | 47 | | LSU_ST_CACHE_DST_ID | 2 | Destination ID for cache store | 48 | | LSU_ST_CACHE_DST_ADDR | 3 | Destination address in cache | 49 | | LSU_ST_CACHE_LENGTH | 4 | Length of data to be stored to cache | 50 | | LSU_ST_CACHE_CONTROL | 5 | Control register for cache store operation | 51 | | LSU_ST_CACHE_STATUS | 6 | Status register for cache store operation | 52 | 53 | ## Image-to-Column Conversion (IMG2COL) 54 | | Register Name | Offset | Description | 55 | |--------------------------|--------|--------------------------------------------| 56 | | IMG2COL_IN_HEIGHT | 0 | Input image height | 57 | | IMG2COL_IN_WIDTH | 1 | Input image width | 58 | | IMG2COL_IN_CHANNELS | 2 | Number of input channels | 59 | | IMG2COL_KERNEL_SIZE | 3 | Convolution kernel size | 60 | | IMG2COL_STRIDE | 4 | Convolution stride | 61 | | IMG2COL_PAD | 5 | Padding size | 62 | | IMG2COL_CONTROL | 6 | Control register for img2col operation | 63 | | IMG2COL_STATUS | 7 | Status register for img2col operation | 64 | 65 | ## Systolic Array 66 | | Register Name | Offset | Description | 67 | |--------------------------|--------|--------------------------------------------| 68 | | SYSTOLIC_ARRAY_IN_HEIGHT | 0 | Input matrix height | 69 | | SYSTOLIC_ARRAY_OUT_HEIGHT| 1 | Output matrix height | 70 | | SYSTOLIC_ARRAY_IN_WIDTH | 2 | Input matrix width | 71 | | SYSTOLIC_ARRAY_OUT_WIDTH | 3 | Output matrix width | 72 | | SYSTOLIC_ARRAY_IN_CHANNELS | 4 | Number of input channels | 73 | | SYSTOLIC_ARRAY_OUT_CHANNELS | 5 | Number of output channels | 74 | | SYSTOLIC_ARRAY_RELU_EN | 6 | ReLU activation enable | 75 | | SYSTOLIC_ARRAY_CONTROL | 7 | Control register for systolic array | 76 | | SYSTOLIC_ARRAY_STATUS | 8 | Status register for systolic array | -------------------------------------------------------------------------------- /03_hardware/systolic_array/README.md: -------------------------------------------------------------------------------- 1 | # Systolic Array Implementation 2 | This folder contains a parameterized systolic array implementation in SystemVerilog, designed for efficient matrix multiplication. 3 | 4 | The design uses a grid of Processing Elements (PEs) that perform multiply-accumulate (MAC) operations using DSP48E1 blocks. 
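
As a quick way to check simulation results against expected values, a small software golden model can compute the same product. The sketch below is only a behavioral reference under stated assumptions (signed 8-bit operands and plain C = A × B semantics as described in this README); the function name `golden_matmul` and the use of `std::array` are illustrative, not part of the RTL.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Behavioral golden model: C = A x B for signed 8-bit operands.
// int32_t accumulation comfortably covers the worst-case growth of an
// 8x8 dot product of 8-bit values (19 bits, as noted under "Example").
template <std::size_t N>
std::array<std::array<int32_t, N>, N> golden_matmul(
    const std::array<std::array<int8_t, N>, N>& a,
    const std::array<std::array<int8_t, N>, N>& b) {
  std::array<std::array<int32_t, N>, N> c{};  // zero-initialized accumulators
  for (std::size_t i = 0; i < N; ++i)
    for (std::size_t j = 0; j < N; ++j)
      for (std::size_t k = 0; k < N; ++k)
        c[i][j] += static_cast<int32_t>(a[i][k]) * static_cast<int32_t>(b[k][j]);
  return c;
}
```

In a testbench, the values collected from `c_outputs` after `valid_out` asserts can be compared entry by entry against this model.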
5 | 6 | ## Features 7 | - Parameterized NxN array size 8 | - Configurable data width 9 | - Pipelined architecture for high throughput 10 | - DSP48E1-based MAC operations 11 | - Valid signal propagation for result verification 12 | 13 | ## Architecture 14 | The systolic array consists of: 15 | - An NxN grid of Processing Elements (PEs) 16 | - Input data distribution network 17 | - Output collection network 18 | - Valid signal propagation chain 19 | 20 | ### Processing Element (PE) 21 | Each PE performs the following operations: 22 | - Receives input data from west (a_in) and north (b_in) 23 | - Performs MAC operation: c_out = c_in + (a_in * b_in) 24 | - Propagates data to east (a_out) and south (b_out) 25 | - Propagates valid signal 26 | 27 | ## Parameters 28 | - `ARRAY_SIZE`: Size of the NxN array (default: 8) 29 | - `DATA_WIDTH`: Bit width of input data (default: 8) 30 | 31 | 32 | ## Usage 33 | 1. Configure the array size and data width parameters as needed 34 | 2. Provide input matrices A and B through `a_inputs` and `b_inputs` 35 | 3. Assert `valid_in` when input data is valid 36 | 4. Wait for `valid_out` to indicate valid output data 37 | 5. Read result matrix from `c_outputs` 38 | 39 | ## Timing 40 | - Results appear after (2 * ARRAY_SIZE - 1) clock cycles 41 | - New input can be started every clock cycle 42 | - Output valid signal indicates when results are ready 43 | 44 | ## Example 45 | For an 8x8 array with 8-bit data width: 46 | - Matrix multiplication of 8x8 matrices 47 | - Input data range: -128 to 127 (8-bit) 48 | - Output data width: 19 bits (8*2 + 3 bits for accumulation) 49 | - Latency: 15 clock cycles for first result -------------------------------------------------------------------------------- /03_hardware/systolic_array/systolic_array.sv: -------------------------------------------------------------------------------- 1 | `timescale 1ns / 1ps 2 | 3 | // Parameterized Systolic Array for matrix multiplication 4 | // Uses DSP48E1 for MAC operations 5 | module systolic_array #( 6 | parameter ARRAY_SIZE = 8, // NxN array size 7 | parameter DATA_WIDTH = 8 // Input data width 8 | )( 9 | input logic clk, 10 | input logic rst_n, 11 | input logic valid_in, 12 | input logic [DATA_WIDTH-1:0] a_inputs [ARRAY_SIZE-1:0], // Input matrix A 13 | input logic [DATA_WIDTH-1:0] b_inputs [ARRAY_SIZE-1:0], // Input matrix B 14 | output logic [DATA_WIDTH*2+ARRAY_SIZE-1:0] c_outputs [ARRAY_SIZE-1:0], // Output matrix C 15 | output logic valid_out 16 | ); 17 | 18 | // Internal signals for PE connections 19 | logic [DATA_WIDTH-1:0] a_wires [ARRAY_SIZE:0][ARRAY_SIZE:0]; 20 | logic [DATA_WIDTH-1:0] b_wires [ARRAY_SIZE:0][ARRAY_SIZE:0]; 21 | logic [DATA_WIDTH*2+ARRAY_SIZE-1:0] c_wires [ARRAY_SIZE:0][ARRAY_SIZE:0]; 22 | 23 | // Valid signal propagation 24 | logic valid_regs [ARRAY_SIZE:0][ARRAY_SIZE:0]; 25 | 26 | // Input assignment 27 | always @(posedge clk or negedge rst_n) begin 28 | if (!rst_n) begin 29 | for (int i = 0; i < ARRAY_SIZE; i++) begin 30 | a_wires[0][i] <= '0; 31 | b_wires[i][0] <= '0; 32 | end 33 | end else begin 34 | for (int i = 0; i < ARRAY_SIZE; i++) begin 35 | a_wires[0][i] <= a_inputs[i]; 36 | b_wires[i][0] <= b_inputs[i]; 37 | end 38 | end 39 | end 40 | 41 | // Generate the PE array 42 | genvar i, j; 43 | generate 44 | for (i = 0; i < ARRAY_SIZE; i++) begin : ROW 45 | for (j = 0; j < ARRAY_SIZE; j++) begin : COL 46 | processing_element #( 47 | .DATA_WIDTH(DATA_WIDTH) 48 | ) pe ( 49 | .clk(clk), 50 | .rst_n(rst_n), 51 | .valid_in(valid_regs[i][j]), 52 | .a_in(a_wires[i][j]), 
53 | .b_in(b_wires[i][j]), 54 | .c_in(c_wires[i][j]), 55 | .a_out(a_wires[i+1][j]), 56 | .b_out(b_wires[i][j+1]), 57 | .c_out(c_wires[i+1][j+1]), 58 | .valid_out(valid_regs[i+1][j+1]) 59 | ); 60 | end 61 | end 62 | endgenerate 63 | 64 | // Output assignment 65 | always @(posedge clk or negedge rst_n) begin 66 | if (!rst_n) begin 67 | for (int i = 0; i < ARRAY_SIZE; i++) begin 68 | c_outputs[i] <= '0; 69 | end 70 | valid_out <= 1'b0; 71 | end else begin 72 | for (int i = 0; i < ARRAY_SIZE; i++) begin 73 | c_outputs[i] <= c_wires[ARRAY_SIZE][i+1]; 74 | end 75 | valid_out <= valid_regs[ARRAY_SIZE][ARRAY_SIZE]; 76 | end 77 | end 78 | 79 | endmodule 80 | 81 | // Processing Element (PE) module 82 | module processing_element #( 83 | parameter DATA_WIDTH = 8 84 | )( 85 | input logic clk, 86 | input logic rst_n, 87 | input logic valid_in, 88 | input logic [DATA_WIDTH-1:0] a_in, 89 | input logic [DATA_WIDTH-1:0] b_in, 90 | input logic [DATA_WIDTH*2+ARRAY_SIZE-1:0] c_in, 91 | output logic [DATA_WIDTH-1:0] a_out, 92 | output logic [DATA_WIDTH-1:0] b_out, 93 | output logic [DATA_WIDTH*2+ARRAY_SIZE-1:0] c_out, 94 | output logic valid_out 95 | ); 96 | 97 | // Registered outputs 98 | always @(posedge clk or negedge rst_n) begin 99 | if (!rst_n) begin 100 | a_out <= '0; 101 | b_out <= '0; 102 | c_out <= '0; 103 | valid_out <= 1'b0; 104 | end else begin 105 | a_out <= a_in; 106 | b_out <= b_in; 107 | // MAC operation using DSP48E1 108 | c_out <= c_in + (a_in * b_in); 109 | valid_out <= valid_in; 110 | end 111 | end 112 | 113 | endmodule 114 | -------------------------------------------------------------------------------- /03_hardware/systolic_array/systolic_array_tb.sv: -------------------------------------------------------------------------------- 1 | `timescale 1ns / 1ps 2 | 3 | module systolic_array_tb; 4 | // Parameters 5 | localparam ARRAY_SIZE = 8; 6 | localparam DATA_WIDTH = 8; 7 | localparam CLK_PERIOD = 10; 8 | 9 | // Signals 10 | logic clk; 11 | logic rst_n; 12 | logic valid_in; 13 | logic [DATA_WIDTH-1:0] a_inputs [ARRAY_SIZE-1:0]; 14 | logic [DATA_WIDTH-1:0] b_inputs [ARRAY_SIZE-1:0]; 15 | logic [DATA_WIDTH*2+ARRAY_SIZE-1:0] c_outputs [ARRAY_SIZE-1:0]; 16 | logic valid_out; 17 | 18 | // DUT instantiation 19 | systolic_array #( 20 | .ARRAY_SIZE(ARRAY_SIZE), 21 | .DATA_WIDTH(DATA_WIDTH) 22 | ) dut (.*); 23 | 24 | // Clock generation 25 | initial begin 26 | clk = 0; 27 | forever #(CLK_PERIOD/2) clk = ~clk; 28 | end 29 | 30 | // Test stimulus 31 | initial begin 32 | // Initialize test matrices 33 | logic [DATA_WIDTH-1:0] matrix_a [ARRAY_SIZE][ARRAY_SIZE]; 34 | logic [DATA_WIDTH-1:0] matrix_b [ARRAY_SIZE][ARRAY_SIZE]; 35 | 36 | // Reset 37 | rst_n = 0; 38 | valid_in = 0; 39 | foreach(a_inputs[i]) a_inputs[i] = '0; 40 | foreach(b_inputs[i]) b_inputs[i] = '0; 41 | repeat(3) @(posedge clk); 42 | rst_n = 1; 43 | 44 | // Initialize test matrices with simple values 45 | for (int i = 0; i < ARRAY_SIZE; i++) begin 46 | for (int j = 0; j < ARRAY_SIZE; j++) begin 47 | matrix_a[i][j] = i + 1; // 1-8 in each row 48 | matrix_b[i][j] = j + 1; // 1-8 in each column 49 | end 50 | end 51 | 52 | // Feed the matrices 53 | valid_in = 1; 54 | for (int cycle = 0; cycle < ARRAY_SIZE*2; cycle++) begin 55 | for (int i = 0; i < ARRAY_SIZE; i++) begin 56 | // Diagonal feeding pattern 57 | if (cycle - i >= 0 && cycle - i < ARRAY_SIZE) begin 58 | a_inputs[i] = matrix_a[i][cycle-i]; 59 | b_inputs[i] = matrix_b[cycle-i][i]; 60 | end else begin 61 | a_inputs[i] = '0; 62 | b_inputs[i] = '0; 63 | end 64 | end 65 | @(posedge clk); 
66 | end 67 | 68 | valid_in = 0; 69 | 70 | // Wait for computation to complete 71 | // Need to wait ARRAY_SIZE more cycles for the result to propagate 72 | repeat(ARRAY_SIZE + 5) @(posedge clk); 73 | 74 | // Print results 75 | $display("Matrix multiplication results:"); 76 | for (int i = 0; i < ARRAY_SIZE; i++) begin 77 | $write("Row %0d: ", i); 78 | for (int j = 0; j < ARRAY_SIZE; j++) begin 79 | $write("%d ", c_outputs[i]); 80 | end 81 | $display(""); 82 | end 83 | 84 | // End simulation 85 | #100; 86 | $finish; 87 | end 88 | 89 | // Optional: Waveform dumping 90 | initial begin 91 | $dumpfile("systolic_array_tb.vcd"); 92 | $dumpvars(0, systolic_array_tb); 93 | end 94 | 95 | endmodule -------------------------------------------------------------------------------- /04_software/.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Object files 5 | *.o 6 | *.ko 7 | *.obj 8 | *.elf 9 | 10 | # Linker output 11 | *.ilk 12 | *.map 13 | *.exp 14 | 15 | # Precompiled Headers 16 | *.gch 17 | *.pch 18 | 19 | # Libraries 20 | *.lib 21 | *.a 22 | *.la 23 | *.lo 24 | 25 | # Shared objects (inc. Windows DLLs) 26 | *.dll 27 | *.so 28 | *.so.* 29 | *.dylib 30 | 31 | # Executables 32 | *.exe 33 | *.out 34 | *.app 35 | *.i*86 36 | *.x86_64 37 | *.hex 38 | 39 | # Debug files 40 | *.dSYM/ 41 | *.su 42 | *.idb 43 | *.pdb 44 | 45 | # Kernel Module Compile Results 46 | *.mod* 47 | *.cmd 48 | .tmp_versions/ 49 | modules.order 50 | Module.symvers 51 | Mkfile.old 52 | dkms.conf 53 | 54 | # Compiled Object files 55 | *.slo 56 | 57 | # Fortran module files 58 | *.mod 59 | *.smod 60 | 61 | # Compiled Static libraries 62 | *.lai 63 | 64 | # Ignore all files in the build directory 65 | build/ 66 | bin/ 67 | 68 | # IDE 69 | .vscode/ 70 | -------------------------------------------------------------------------------- /04_software/README.md: -------------------------------------------------------------------------------- 1 | # Software For DNN Accelerator 2 | This project utilizes C++ to implement software for a DNN accelerator. 
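The HAL and driver layers expose a plain C API (declared in `driver/include/accel.h`), with the C++ runtime layered on top. The sketch below shows the intended call sequence through the driver; it is a minimal illustration rather than a complete program — the device node `/dev/accelerator0` and the buffer sizes simply mirror the values used in the unit tests, and real code would also fill the input and weight buffers before submitting.

```c
#include <stdio.h>
#include "accel.h"

/* Minimal end-to-end flow: init, allocate buffers, submit one matrix
 * multiply, wait for completion, release everything. */
int main(void) {
    if (accel_init("/dev/accelerator0") != ACCEL_STATUS_OK) {
        fprintf(stderr, "init failed: %s\n", accel_get_error());
        return 1;
    }

    accel_buffer_t* in  = accel_alloc_buffer(4096);
    accel_buffer_t* wts = accel_alloc_buffer(4096);
    accel_buffer_t* out = accel_alloc_buffer(4096);

    if (in && wts && out) {
        /* host code would populate in->host_addr and wts->host_addr here */
        accel_op_params_t op = {
            .op_type = ACCEL_OP_MATMUL,
            .input   = *in,
            .weights = *wts,
            .output  = *out,
            .flags   = 0,
        };
        if (accel_submit_op(&op) == ACCEL_STATUS_OK &&
            accel_wait_complete(1000) == ACCEL_STATUS_OK) {
            /* results are now visible through out->host_addr */
        } else {
            fprintf(stderr, "operation failed: %s\n", accel_get_error());
        }
    }

    accel_free_buffer(in);   /* NULL-safe in the driver implementation */
    accel_free_buffer(wts);
    accel_free_buffer(out);
    accel_cleanup();
    return 0;
}
```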
3 | 4 | ## Project Structure 5 | - **hal** 6 | 7 | - **driver** 8 | 9 | - **runtime** -------------------------------------------------------------------------------- /04_software/driver/Makefile: -------------------------------------------------------------------------------- 1 | # Compiler and flags 2 | CC = gcc 3 | CFLAGS = -Wall -Werror -O2 -fPIC 4 | CPPFLAGS = -I../hal/include -I./include 5 | 6 | # Library name and version 7 | LIB_NAME = libaccel_driver 8 | LIB_VERSION = 1.0.0 9 | SONAME = $(LIB_NAME).so.1 10 | 11 | # Directory structure 12 | SRC_DIR = src 13 | OBJ_DIR = obj 14 | TEST_DIR = test 15 | LIB_DIR = lib 16 | 17 | SRCS = $(wildcard $(SRC_DIR)/*.c) 18 | OBJS = $(SRCS:$(SRC_DIR)/%.c=$(OBJ_DIR)/%.o) 19 | TEST_SRCS = $(wildcard $(TEST_DIR)/*.c) 20 | TEST_BINS = $(TEST_SRCS:$(TEST_DIR)/%.c=$(TEST_DIR)/%) 21 | 22 | # Shared library files 23 | LIB_SO = $(LIB_DIR)/$(LIB_NAME).so.$(LIB_VERSION) 24 | LIB_SONAME = $(LIB_DIR)/$(LIB_NAME).so.1 25 | LIB_LINK = $(LIB_DIR)/$(LIB_NAME).so 26 | 27 | # Installation directories 28 | PREFIX = /usr/local 29 | LIBDIR = $(PREFIX)/lib 30 | INCLUDEDIR = $(PREFIX)/include/accel 31 | 32 | # HAL library path and file 33 | HAL_LIB = ../hal/lib/libhal_accelerator.a 34 | 35 | .PHONY: all clean install uninstall test 36 | 37 | all: $(LIB_SO) 38 | 39 | # Create directories 40 | $(OBJ_DIR): 41 | mkdir -p $(OBJ_DIR) 42 | 43 | $(LIB_DIR): 44 | mkdir -p $(LIB_DIR) 45 | 46 | # Compile source files 47 | $(OBJ_DIR)/%.o: $(SRC_DIR)/%.c | $(OBJ_DIR) 48 | $(CC) $(CPPFLAGS) $(CFLAGS) -c $< -o $@ 49 | 50 | # Create shared library 51 | $(LIB_SO): $(OBJS) | $(LIB_DIR) 52 | $(CC) -shared -Wl,-soname,$(SONAME) -o $@ $(OBJS) $(HAL_LIB) 53 | cd $(LIB_DIR) && ln -sf $(LIB_NAME).so.$(LIB_VERSION) $(LIB_NAME).so.1 54 | cd $(LIB_DIR) && ln -sf $(LIB_NAME).so.1 $(LIB_NAME).so 55 | 56 | # Build and run tests 57 | test: $(TEST_BINS) 58 | for test in $(TEST_BINS); do ./$$test; done 59 | 60 | $(TEST_DIR)/%: $(TEST_DIR)/%.c $(LIB_SO) 61 | $(CC) $(CPPFLAGS) $(CFLAGS) $< -o $@ -L$(LIB_DIR) -Wl,-rpath,$(shell pwd)/$(LIB_DIR) -laccel_driver ../hal/lib/libhal_accelerator.a 62 | 63 | # Install library and headers 64 | install: $(LIB_SO) 65 | install -d $(DESTDIR)$(LIBDIR) 66 | install -m 755 $(LIB_SO) $(DESTDIR)$(LIBDIR) 67 | ln -sf $(LIB_SO) $(DESTDIR)$(LIBDIR)/$(LIB_SONAME) 68 | ln -sf $(LIB_SONAME) $(DESTDIR)$(LIBDIR)/$(LIB_LINK) 69 | install -d $(DESTDIR)$(INCLUDEDIR) 70 | install -m 644 include/*.h $(DESTDIR)$(INCLUDEDIR) 71 | ldconfig 72 | 73 | # Uninstall library and headers 74 | uninstall: 75 | rm -f $(DESTDIR)$(LIBDIR)/$(LIB_SO) 76 | rm -f $(DESTDIR)$(LIBDIR)/$(LIB_SONAME) 77 | rm -f $(DESTDIR)$(LIBDIR)/$(LIB_LINK) 78 | rm -rf $(DESTDIR)$(INCLUDEDIR) 79 | 80 | # Clean build artifacts 81 | clean: 82 | rm -rf $(OBJ_DIR) 83 | rm -rf $(LIB_DIR) 84 | rm -f $(TEST_BINS) 85 | -------------------------------------------------------------------------------- /04_software/driver/include/accel.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file accel.h 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Main interface for accelerator driver 6 | * @version 1.0.0 7 | * @date 2020-03-30 8 | */ 9 | 10 | #ifndef ACCEL_H 11 | #define ACCEL_H 12 | 13 | #ifdef __cplusplus 14 | extern "C" { 15 | #endif 16 | 17 | #include "accel_config.h" 18 | #include "accel_types.h" 19 | /** 20 | * @brief Initialize the accelerator 21 | * @param device_path Path to device file 22 | * @return Status code 23 | */ 24 | accel_status_t accel_init(const char* 
device_path); 25 | 26 | /** 27 | * @brief Clean up accelerator resources 28 | */ 29 | void accel_cleanup(void); 30 | 31 | /** 32 | * @brief Allocate memory buffer for accelerator operations 33 | * @param size Buffer size in bytes 34 | * @return Buffer descriptor, NULL on failure 35 | */ 36 | accel_buffer_t* accel_alloc_buffer(uint32_t size); 37 | 38 | /** 39 | * @brief Free allocated memory buffer 40 | * @param buffer Buffer to free 41 | */ 42 | void accel_free_buffer(accel_buffer_t* buffer); 43 | 44 | /** 45 | * @brief Submit operation to accelerator 46 | * @param params Operation parameters 47 | * @return Status code 48 | */ 49 | accel_status_t accel_submit_op(const accel_op_params_t* params); 50 | 51 | /** 52 | * @brief Wait for operation completion 53 | * @param timeout_ms Timeout in milliseconds (0 for infinite) 54 | * @return Status code 55 | */ 56 | accel_status_t accel_wait_complete(uint32_t timeout_ms); 57 | 58 | /** 59 | * @brief Get last error message 60 | * @return Error message string 61 | */ 62 | const char* accel_get_error(void); 63 | #ifdef __cplusplus 64 | } 65 | #endif 66 | #endif /* ACCEL_H */ -------------------------------------------------------------------------------- /04_software/driver/include/accel_config.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file accel_config.h 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Configuration interface for accelerator driver 6 | * @version 1.0.0 7 | * @date 2020-03-30 8 | */ 9 | 10 | #ifndef ACCEL_CONFIG_H 11 | #define ACCEL_CONFIG_H 12 | 13 | #include "accel_types.h" 14 | 15 | /** 16 | * @brief Configuration flags 17 | */ 18 | #define ACCEL_CONFIG_ENABLE_DMA (1 << 0) 19 | #define ACCEL_CONFIG_SYNC_MODE (1 << 1) 20 | #define ACCEL_CONFIG_HIGH_PRIORITY (1 << 2) 21 | 22 | /** 23 | * @brief Device configuration structure 24 | */ 25 | typedef struct { 26 | uint32_t flags; /**< Configuration flags */ 27 | uint32_t num_channels; /**< Number of DMA channels */ 28 | uint32_t max_transfer; /**< Maximum transfer size */ 29 | uint32_t timeout_ms; /**< Operation timeout in milliseconds */ 30 | } accel_config_t; 31 | 32 | /** 33 | * @brief Configure the accelerator device 34 | * @param config Pointer to configuration structure 35 | * @return Status code 36 | */ 37 | accel_status_t accel_configure(const accel_config_t* config); 38 | 39 | /** 40 | * @brief Get current device configuration 41 | * @param config Pointer to store configuration 42 | * @return Status code 43 | */ 44 | accel_status_t accel_get_config(accel_config_t* config); 45 | 46 | /** 47 | * @brief Reset device configuration to defaults 48 | * @return Status code 49 | */ 50 | accel_status_t accel_reset_config(void); 51 | 52 | #endif /* ACCEL_CONFIG_H */ -------------------------------------------------------------------------------- /04_software/driver/include/accel_types.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file accel_types.h 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Common type definitions for accelerator driver 6 | * @version 1.0.0 7 | * @date 2020-03-30 8 | */ 9 | 10 | #ifndef ACCEL_TYPES_H 11 | #define ACCEL_TYPES_H 12 | 13 | #include 14 | 15 | /** 16 | * @brief Operation types supported by the accelerator 17 | */ 18 | typedef enum { 19 | ACCEL_OP_NONE = 0, 20 | ACCEL_OP_MATMUL, /**< Matrix multiplication */ 21 | ACCEL_OP_CONV2D, /**< 2D convolution */ 22 | } accel_op_type_t; 23 | 24 | /** 25 | * @brief Status codes for accelerator 
operations 26 | */ 27 | typedef enum { 28 | ACCEL_STATUS_OK = 0, 29 | ACCEL_STATUS_ERROR, 30 | ACCEL_STATUS_INVALID_PARAM, 31 | ACCEL_STATUS_NO_MEMORY, 32 | ACCEL_STATUS_TIMEOUT, 33 | ACCEL_STATUS_BUSY, 34 | ACCEL_STATUS_NOT_INITIALIZED 35 | } accel_status_t; 36 | 37 | /** 38 | * @brief Memory buffer descriptor 39 | */ 40 | typedef struct { 41 | void* host_addr; /**< Host virtual address */ 42 | uint64_t dev_addr; /**< Device physical address */ 43 | uint32_t size; /**< Buffer size in bytes */ 44 | } accel_buffer_t; 45 | 46 | /** 47 | * @brief Operation parameters 48 | */ 49 | typedef struct { 50 | accel_op_type_t op_type; /**< Operation type */ 51 | accel_buffer_t input; /**< Input buffer */ 52 | accel_buffer_t output; /**< Output buffer */ 53 | accel_buffer_t weights; /**< Weights buffer */ 54 | uint32_t flags; /**< Operation flags */ 55 | } accel_op_params_t; 56 | 57 | #endif /* ACCEL_TYPES_H */ -------------------------------------------------------------------------------- /04_software/driver/src/accel.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file accel.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Implementation of main accelerator driver interface 6 | * @version 1.0.0 7 | * @date 2020-03-30 8 | */ 9 | 10 | #include "accel.h" 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include "hal.h" 17 | #include "hal_config.h" 18 | #include "hal_io.h" 19 | #include "hal_mem.h" 20 | 21 | // Driver context 22 | struct { 23 | hal_context_t* hal; 24 | accel_config_t config; 25 | char last_error[256]; 26 | bool initialized; 27 | } g_ctx = {0}; 28 | 29 | // HAL operation codes mapping 30 | #define HAL_OP_MATMUL 0x01 31 | #define HAL_OP_CONV 0x02 32 | 33 | // Convert HAL status to driver status 34 | static accel_status_t convert_hal_status(uint32_t hal_status) { 35 | if (hal_status & HAL_STATUS_ERROR) { 36 | return ACCEL_STATUS_ERROR; 37 | } 38 | if (hal_status & HAL_STATUS_BUSY) { 39 | return ACCEL_STATUS_BUSY; 40 | } 41 | return ACCEL_STATUS_OK; 42 | } 43 | 44 | accel_status_t accel_init(const char* device_path) { 45 | if (!device_path) { 46 | return ACCEL_STATUS_INVALID_PARAM; 47 | } 48 | 49 | if (g_ctx.initialized) { 50 | return ACCEL_STATUS_OK; 51 | } 52 | 53 | // Initialize HAL 54 | g_ctx.hal = hal_init(device_path); 55 | if (!g_ctx.hal) { 56 | snprintf(g_ctx.last_error, sizeof(g_ctx.last_error), 57 | "Failed to initialize HAL"); 58 | return ACCEL_STATUS_ERROR; 59 | } 60 | 61 | g_ctx.initialized = true; 62 | return ACCEL_STATUS_OK; 63 | } 64 | 65 | void accel_cleanup(void) { 66 | if (g_ctx.initialized) { 67 | hal_cleanup(g_ctx.hal); 68 | memset(&g_ctx, 0, sizeof(g_ctx)); 69 | } 70 | } 71 | 72 | accel_buffer_t* accel_alloc_buffer(uint32_t size) { 73 | if (!g_ctx.initialized) { 74 | return NULL; 75 | } 76 | 77 | accel_buffer_t* buffer = malloc(sizeof(accel_buffer_t)); 78 | if (!buffer) { 79 | snprintf(g_ctx.last_error, sizeof(g_ctx.last_error), 80 | "Failed to allocate buffer descriptor"); 81 | return NULL; 82 | } 83 | 84 | buffer->host_addr = hal_mem_alloc(g_ctx.hal, size); 85 | if (!buffer->host_addr) { 86 | free(buffer); 87 | snprintf(g_ctx.last_error, sizeof(g_ctx.last_error), 88 | "Failed to allocate device memory"); 89 | return NULL; 90 | } 91 | 92 | buffer->dev_addr = hal_virt_to_phys(g_ctx.hal, buffer->host_addr); 93 | buffer->size = size; 94 | 95 | return buffer; 96 | } 97 | 98 | void accel_free_buffer(accel_buffer_t* buffer) { 99 | if (buffer && g_ctx.initialized) { 100 | if (buffer->host_addr) { 101 | 
hal_mem_free(g_ctx.hal, buffer->host_addr); 102 | } 103 | free(buffer); 104 | } 105 | } 106 | 107 | accel_status_t accel_submit_op(const accel_op_params_t* params) { 108 | if (!g_ctx.initialized) { 109 | return ACCEL_STATUS_NOT_INITIALIZED; 110 | } 111 | 112 | if (!params) { 113 | return ACCEL_STATUS_INVALID_PARAM; 114 | } 115 | 116 | hal_systolic_config_t systolic_cfg = {0}; 117 | hal_lsu_config_t lsu_cfg = {0}; 118 | 119 | switch (params->op_type) { 120 | case ACCEL_OP_MATMUL: 121 | systolic_cfg.opcode = HAL_OP_MATMUL; 122 | systolic_cfg.control = params->flags; 123 | if (!hal_configure_systolic(g_ctx.hal, &systolic_cfg)) { 124 | return ACCEL_STATUS_ERROR; 125 | } 126 | break; 127 | 128 | case ACCEL_OP_CONV2D: 129 | systolic_cfg.opcode = HAL_OP_CONV; 130 | systolic_cfg.control = params->flags; 131 | if (!hal_configure_systolic(g_ctx.hal, &systolic_cfg)) { 132 | return ACCEL_STATUS_ERROR; 133 | } 134 | break; 135 | 136 | default: 137 | return ACCEL_STATUS_INVALID_PARAM; 138 | } 139 | 140 | // Configure LSU for data transfer 141 | lsu_cfg.src_addr = params->input.dev_addr; 142 | lsu_cfg.dst_addr = params->output.dev_addr; 143 | lsu_cfg.length = params->input.size; 144 | if (!hal_configure_lsu(g_ctx.hal, &lsu_cfg)) { 145 | return ACCEL_STATUS_ERROR; 146 | } 147 | 148 | return ACCEL_STATUS_OK; 149 | } 150 | 151 | accel_status_t accel_wait_complete(uint32_t timeout_ms) { 152 | if (!g_ctx.initialized) { 153 | return ACCEL_STATUS_NOT_INITIALIZED; 154 | } 155 | 156 | if (!hal_wait_for_ready(g_ctx.hal)) { 157 | snprintf(g_ctx.last_error, sizeof(g_ctx.last_error), "Operation timed out"); 158 | return ACCEL_STATUS_TIMEOUT; 159 | } 160 | 161 | return convert_hal_status(hal_get_status(g_ctx.hal)); 162 | } 163 | 164 | const char* accel_get_error(void) { return g_ctx.last_error; } -------------------------------------------------------------------------------- /04_software/driver/src/accel_config.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file accel_config.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Implementation of configuration interface for accelerator driver 6 | * @version 1.0.0 7 | * @date 2020-03-30 8 | */ 9 | 10 | #include "accel_config.h" 11 | 12 | #include 13 | #include 14 | 15 | #include "hal.h" 16 | 17 | // Reference to global context from accel.c 18 | extern struct { 19 | hal_context_t* hal; 20 | accel_config_t config; 21 | char last_error[256]; 22 | bool initialized; 23 | } g_ctx; 24 | 25 | accel_status_t accel_configure(const accel_config_t* config) { 26 | if (!g_ctx.initialized) { 27 | return ACCEL_STATUS_NOT_INITIALIZED; 28 | } 29 | 30 | if (!config) { 31 | return ACCEL_STATUS_INVALID_PARAM; 32 | } 33 | 34 | // Store configuration 35 | memcpy(&g_ctx.config, config, sizeof(accel_config_t)); 36 | 37 | return ACCEL_STATUS_OK; 38 | } 39 | 40 | accel_status_t accel_get_config(accel_config_t* config) { 41 | if (!g_ctx.initialized) { 42 | return ACCEL_STATUS_NOT_INITIALIZED; 43 | } 44 | 45 | if (!config) { 46 | return ACCEL_STATUS_INVALID_PARAM; 47 | } 48 | 49 | // Return current configuration 50 | memcpy(config, &g_ctx.config, sizeof(accel_config_t)); 51 | 52 | return ACCEL_STATUS_OK; 53 | } 54 | 55 | accel_status_t accel_reset_config(void) { 56 | if (!g_ctx.initialized) { 57 | return ACCEL_STATUS_NOT_INITIALIZED; 58 | } 59 | 60 | // Reset to default configuration 61 | memset(&g_ctx.config, 0, sizeof(accel_config_t)); 62 | 63 | // Set default values 64 | g_ctx.config.flags = ACCEL_CONFIG_ENABLE_DMA; 65 | 
g_ctx.config.num_channels = 1; 66 | g_ctx.config.max_transfer = 0x1000000; // 16MB 67 | g_ctx.config.timeout_ms = 1000; // 1 second 68 | 69 | return ACCEL_STATUS_OK; 70 | } -------------------------------------------------------------------------------- /04_software/driver/test/accel_test.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file accel_test.h 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Test framework for accelerator driver unit testing 6 | * @version 1.0.0 7 | * @date 2020-03-30 8 | */ 9 | 10 | #ifndef ACCEL_TEST_H 11 | #define ACCEL_TEST_H 12 | 13 | #include 14 | 15 | // Test statistics 16 | static int total_tests = 0; 17 | static int passed_tests = 0; 18 | static int failed_tests = 0; 19 | 20 | // Basic assertions 21 | #define ACCEL_TEST_ASSERT(condition) \ 22 | do { \ 23 | total_tests++; \ 24 | if (condition) { \ 25 | passed_tests++; \ 26 | printf("PASS: %s:%d\n", __FILE__, __LINE__); \ 27 | } else { \ 28 | failed_tests++; \ 29 | printf("FAIL: %s:%d\n", __FILE__, __LINE__); \ 30 | } \ 31 | } while (0) 32 | 33 | // Equality assertions 34 | #define ACCEL_TEST_ASSERT_EQUAL(expected, actual) \ 35 | ACCEL_TEST_ASSERT((expected) == (actual)) 36 | 37 | // Pointer assertions 38 | #define ACCEL_TEST_ASSERT_NOT_NULL(ptr) ACCEL_TEST_ASSERT((ptr) != NULL) 39 | 40 | #define ACCEL_TEST_ASSERT_NULL(ptr) ACCEL_TEST_ASSERT((ptr) == NULL) 41 | 42 | // Test control 43 | #define ACCEL_TEST_BEGIN() \ 44 | do { \ 45 | total_tests = 0; \ 46 | passed_tests = 0; \ 47 | failed_tests = 0; \ 48 | printf("\nStarting tests...\n"); \ 49 | } while (0) 50 | 51 | #define ACCEL_TEST_END() \ 52 | do { \ 53 | printf("\nTest Summary:\n"); \ 54 | printf("Total: %d\n", total_tests); \ 55 | printf("Passed: %d\n", passed_tests); \ 56 | printf("Failed: %d\n", failed_tests); \ 57 | return failed_tests; \ 58 | } while (0) 59 | 60 | #define ACCEL_TEST_RUN(test_func) \ 61 | do { \ 62 | printf("\nRunning %s...\n", #test_func); \ 63 | test_func(); \ 64 | } while (0) 65 | 66 | #endif /* ACCEL_TEST_H */ -------------------------------------------------------------------------------- /04_software/driver/test/test_accel.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file test_accel.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Unit tests for basic accelerator driver functionality 6 | * @version 1.0.0 7 | * @date 2020-03-30 8 | */ 9 | 10 | #include "accel.h" 11 | #include "accel_test.h" 12 | 13 | static void test_init_cleanup(void) { 14 | // Test initialization 15 | accel_status_t status = accel_init("/dev/accelerator0"); 16 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 17 | 18 | // Test double initialization 19 | status = accel_init("/dev/accelerator0"); 20 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 21 | 22 | // Test invalid device path 23 | status = accel_init(NULL); 24 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_INVALID_PARAM); 25 | 26 | // Cleanup 27 | accel_cleanup(); 28 | } 29 | 30 | static void test_buffer_management(void) { 31 | // Initialize 32 | accel_status_t status = accel_init("/dev/accelerator0"); 33 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 34 | 35 | // Test buffer allocation 36 | accel_buffer_t* buffer = accel_alloc_buffer(4096); 37 | ACCEL_TEST_ASSERT_NOT_NULL(buffer); 38 | ACCEL_TEST_ASSERT_NOT_NULL(buffer->host_addr); 39 | ACCEL_TEST_ASSERT(buffer->dev_addr != 0); 40 | ACCEL_TEST_ASSERT_EQUAL(4096, buffer->size); 41 | 42 | // Test zero size allocation 43 | 
accel_buffer_t* zero_buffer = accel_alloc_buffer(0); 44 | ACCEL_TEST_ASSERT_NULL(zero_buffer); 45 | 46 | // Free buffer 47 | accel_free_buffer(buffer); 48 | 49 | // Test NULL free 50 | accel_free_buffer(NULL); 51 | 52 | accel_cleanup(); 53 | } 54 | 55 | static void test_operation_submission(void) { 56 | // Initialize 57 | accel_status_t status = accel_init("/dev/accelerator0"); 58 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 59 | 60 | // Allocate buffers 61 | accel_buffer_t* input = accel_alloc_buffer(1024); 62 | accel_buffer_t* output = accel_alloc_buffer(1024); 63 | accel_buffer_t* weights = accel_alloc_buffer(1024); 64 | ACCEL_TEST_ASSERT_NOT_NULL(input); 65 | ACCEL_TEST_ASSERT_NOT_NULL(output); 66 | ACCEL_TEST_ASSERT_NOT_NULL(weights); 67 | 68 | // Configure operation 69 | accel_op_params_t params = {.op_type = ACCEL_OP_MATMUL, 70 | .input = *input, 71 | .output = *output, 72 | .weights = *weights, 73 | .flags = 0}; 74 | 75 | // Submit operation 76 | status = accel_submit_op(¶ms); 77 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 78 | 79 | // Wait for completion 80 | status = accel_wait_complete(1000); 81 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 82 | 83 | // Test invalid parameters 84 | status = accel_submit_op(NULL); 85 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_INVALID_PARAM); 86 | 87 | // Cleanup 88 | accel_free_buffer(input); 89 | accel_free_buffer(output); 90 | accel_free_buffer(weights); 91 | accel_cleanup(); 92 | } 93 | 94 | static void test_error_handling(void) { 95 | // Test operations before initialization 96 | accel_buffer_t* buffer = accel_alloc_buffer(1024); 97 | ACCEL_TEST_ASSERT_NULL(buffer); 98 | 99 | accel_op_params_t params = {0}; 100 | accel_status_t status = accel_submit_op(¶ms); 101 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_NOT_INITIALIZED); 102 | 103 | // Initialize 104 | status = accel_init("/dev/accelerator0"); 105 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 106 | 107 | // Test error message 108 | const char* error = accel_get_error(); 109 | ACCEL_TEST_ASSERT_NOT_NULL(error); 110 | 111 | accel_cleanup(); 112 | } 113 | 114 | int main(void) { 115 | ACCEL_TEST_BEGIN(); 116 | 117 | ACCEL_TEST_RUN(test_init_cleanup); 118 | ACCEL_TEST_RUN(test_buffer_management); 119 | ACCEL_TEST_RUN(test_operation_submission); 120 | ACCEL_TEST_RUN(test_error_handling); 121 | 122 | ACCEL_TEST_END(); 123 | } -------------------------------------------------------------------------------- /04_software/driver/test/test_config.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file test_config.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Unit tests for accelerator configuration 6 | * @version 1.0.0 7 | * @date 2020-03-30 8 | */ 9 | #include "accel.h" 10 | #include "accel_config.h" 11 | #include "accel_test.h" 12 | 13 | static void test_basic_config(void) { 14 | // Initialize 15 | accel_status_t status = accel_init("/dev/accelerator0"); 16 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 17 | 18 | // Test configuration 19 | accel_config_t config = { 20 | .flags = ACCEL_CONFIG_ENABLE_DMA | ACCEL_CONFIG_SYNC_MODE, 21 | .num_channels = 2, 22 | .max_transfer = 0x1000000, 23 | .timeout_ms = 5000}; 24 | 25 | status = accel_configure(&config); 26 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 27 | 28 | // Read back configuration 29 | accel_config_t read_config = {0}; 30 | status = accel_get_config(&read_config); 31 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 32 | 33 | // Verify configuration 34 | 
ACCEL_TEST_ASSERT_EQUAL(config.flags, read_config.flags); 35 | ACCEL_TEST_ASSERT_EQUAL(config.num_channels, read_config.num_channels); 36 | ACCEL_TEST_ASSERT_EQUAL(config.max_transfer, read_config.max_transfer); 37 | ACCEL_TEST_ASSERT_EQUAL(config.timeout_ms, read_config.timeout_ms); 38 | 39 | accel_cleanup(); 40 | } 41 | 42 | static void test_config_reset(void) { 43 | // Initialize 44 | accel_status_t status = accel_init("/dev/accelerator0"); 45 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 46 | 47 | // Set custom configuration 48 | accel_config_t config = {.flags = ACCEL_CONFIG_HIGH_PRIORITY, 49 | .num_channels = 4, 50 | .max_transfer = 0x2000000, 51 | .timeout_ms = 10000}; 52 | 53 | status = accel_configure(&config); 54 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 55 | 56 | // Reset configuration 57 | status = accel_reset_config(); 58 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 59 | 60 | // Read back configuration 61 | accel_config_t read_config = {0}; 62 | status = accel_get_config(&read_config); 63 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 64 | 65 | // Verify default values 66 | ACCEL_TEST_ASSERT_EQUAL(ACCEL_CONFIG_ENABLE_DMA, read_config.flags); 67 | ACCEL_TEST_ASSERT_EQUAL(1, read_config.num_channels); 68 | ACCEL_TEST_ASSERT_EQUAL(0x1000000, read_config.max_transfer); 69 | ACCEL_TEST_ASSERT_EQUAL(1000, read_config.timeout_ms); 70 | 71 | accel_cleanup(); 72 | } 73 | 74 | static void test_invalid_config(void) { 75 | // Test before initialization 76 | accel_config_t config = {0}; 77 | accel_status_t status = accel_configure(&config); 78 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_NOT_INITIALIZED); 79 | 80 | // Initialize 81 | status = accel_init("/dev/accelerator0"); 82 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_OK); 83 | 84 | // Test NULL config 85 | status = accel_configure(NULL); 86 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_INVALID_PARAM); 87 | 88 | status = accel_get_config(NULL); 89 | ACCEL_TEST_ASSERT(status == ACCEL_STATUS_INVALID_PARAM); 90 | 91 | accel_cleanup(); 92 | } 93 | 94 | int main(void) { 95 | ACCEL_TEST_BEGIN(); 96 | 97 | ACCEL_TEST_RUN(test_basic_config); 98 | ACCEL_TEST_RUN(test_config_reset); 99 | ACCEL_TEST_RUN(test_invalid_config); 100 | 101 | ACCEL_TEST_END(); 102 | } -------------------------------------------------------------------------------- /04_software/hal/Makefile: -------------------------------------------------------------------------------- 1 | # Compiler and flags 2 | CC = gcc 3 | CFLAGS = -Wall -Wextra -Werror -O2 4 | INCLUDE = -Iinclude -Itest 5 | 6 | # Directories 7 | SRC_DIR = src 8 | INC_DIR = include 9 | TEST_DIR = test 10 | OBJ_DIR = obj 11 | LIB_DIR = lib 12 | 13 | # Library name and path 14 | LIB_NAME = hal_accelerator 15 | LIB = $(LIB_DIR)/lib$(LIB_NAME).a 16 | 17 | # Source files 18 | SRCS = $(wildcard $(SRC_DIR)/*.c) 19 | OBJS = $(SRCS:$(SRC_DIR)/%.c=$(OBJ_DIR)/%.o) 20 | 21 | # Test files 22 | TEST_SRCS = $(wildcard $(TEST_DIR)/*.c) 23 | TEST_BINS = $(TEST_SRCS:$(TEST_DIR)/%.c=$(TEST_DIR)/bin/%) 24 | 25 | # Dependencies 26 | HAL_DEPS = hal_base.o hal_config.o hal_io.o hal_mem.o 27 | 28 | .PHONY: all clean test dirs 29 | 30 | all: dirs $(LIB) 31 | 32 | # Create necessary directories 33 | dirs: 34 | @mkdir -p $(OBJ_DIR) 35 | @mkdir -p $(LIB_DIR) 36 | @mkdir -p $(TEST_DIR)/bin 37 | 38 | # Build object files 39 | $(OBJ_DIR)/%.o: $(SRC_DIR)/%.c 40 | $(CC) $(CFLAGS) $(INCLUDE) -c -o $@ $< 41 | 42 | # Build static library 43 | $(LIB): $(addprefix $(OBJ_DIR)/,$(HAL_DEPS)) 44 | ar rcs $@ $^ 45 | 46 | # Individual test targets 47 | 
$(TEST_DIR)/bin/%: $(TEST_DIR)/%.c $(LIB) 48 | $(CC) $(CFLAGS) $(INCLUDE) -o $@ $< $(LIB) 49 | 50 | # Build and run tests 51 | test: dirs $(TEST_DIR)/bin/test_hal_mem $(TEST_DIR)/bin/test_hal_io $(TEST_DIR)/bin/test_hal_init 52 | @echo "Running tests..." 53 | @for test in $(TEST_DIR)/bin/*; do \ 54 | if [ -x $$test ]; then \ 55 | ./$$test; \ 56 | fi \ 57 | done 58 | 59 | # Clean build artifacts 60 | clean: 61 | rm -rf $(OBJ_DIR) 62 | rm -rf $(LIB_DIR) 63 | rm -rf $(TEST_DIR)/bin 64 | 65 | # Show help 66 | help: 67 | @echo "Available targets:" 68 | @echo " all - Build the HAL library (default)" 69 | @echo " test - Build and run tests" 70 | @echo " clean - Remove build artifacts" 71 | @echo " help - Show this help message" 72 | -------------------------------------------------------------------------------- /04_software/hal/include/hal.h: -------------------------------------------------------------------------------- 1 | #ifndef HAL_ACCELERATOR_H 2 | #define HAL_ACCELERATOR_H 3 | 4 | #include "hal_base.h" 5 | #include "hal_config.h" 6 | #include "hal_io.h" 7 | 8 | #endif /* HAL_ACCELERATOR_H */ -------------------------------------------------------------------------------- /04_software/hal/include/hal_base.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file hal_base.h 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Hardware abstraction layer base definitions 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #ifndef HAL_BASE_H 11 | #define HAL_BASE_H 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | // Status definitions 18 | #define HAL_STATUS_READY 0x1 19 | #define HAL_STATUS_BUSY 0x2 20 | #define HAL_STATUS_COMPLETE 0x4 21 | #define HAL_STATUS_ERROR 0x8 22 | 23 | // Memory region definitions 24 | #define HAL_ACCEL_MEM_BASE 0x30000000 25 | #define HAL_ACCEL_MEM_SIZE (256 * 1024 * 1024) // 256MB 26 | 27 | /** 28 | * @brief HAL context structure 29 | */ 30 | struct hal_context { 31 | int fd; /**< File descriptor for device */ 32 | uint32_t status; /**< Current hardware status */ 33 | void* mapped_memory; /**< Memory mapped region for registers */ 34 | void* accel_memory_base; /**< Base of mapped accelerator memory */ 35 | size_t accel_memory_size; /**< Size of mapped accelerator memory */ 36 | void* mem_ctx; /**< Memory management context */ 37 | }; 38 | 39 | typedef struct hal_context hal_context_t; 40 | 41 | // Base HAL operations 42 | hal_context_t* hal_init(const char* device_path); 43 | void hal_cleanup(hal_context_t* ctx); 44 | 45 | #endif /* HAL_BASE_H */ -------------------------------------------------------------------------------- /04_software/hal/include/hal_config.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file hal_config.h 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Hardware configuration structures and functions 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #ifndef HAL_CONFIG_H 11 | #define HAL_CONFIG_H 12 | 13 | #include 14 | 15 | #include "hal_base.h" 16 | 17 | /** 18 | * @brief LSU (Load-Store Unit) configuration structure 19 | */ 20 | typedef struct __attribute__((packed)) { 21 | uint32_t opcode; /**< Operation code */ 22 | uint64_t src_addr; /**< Source memory address */ 23 | uint64_t dst_addr; /**< Destination memory address */ 24 | uint32_t length; /**< Data transfer length */ 25 | uint32_t control; /**< Control flags */ 26 | uint32_t status; /**< Operation status */ 27 | } hal_lsu_config_t; 28 | 29 | /** 30 
| * @brief Systolic array configuration structure 31 | */ 32 | typedef struct __attribute__((packed)) { 33 | uint32_t opcode; /**< Operation code (conv, matmul, etc.) */ 34 | uint32_t in_height; /**< Input height */ 35 | uint32_t in_width; /**< Input width */ 36 | uint32_t in_channels; /**< Input channels */ 37 | uint32_t out_height; /**< Output height */ 38 | uint32_t out_width; /**< Output width */ 39 | uint32_t out_channels; /**< Output channels */ 40 | uint32_t stride; /**< Stride value */ 41 | uint32_t control; /**< Control flags (ReLU, quantization, etc.) */ 42 | uint32_t status; /**< Operation status */ 43 | } hal_systolic_config_t; 44 | 45 | /** 46 | * @brief IMG2COL configuration structure 47 | */ 48 | typedef struct __attribute__((packed)) { 49 | uint32_t opcode; /**< Operation code */ 50 | uint32_t in_height; /**< Input image height */ 51 | uint32_t in_width; /**< Input image width */ 52 | uint32_t in_channels; /**< Input image channels */ 53 | uint32_t kernel_size; /**< Convolution kernel size */ 54 | uint32_t stride; /**< Stride value */ 55 | uint32_t pad; /**< Padding size */ 56 | uint32_t control; /**< Control flags */ 57 | uint32_t status; /**< Operation status */ 58 | } hal_img2col_config_t; 59 | 60 | /** 61 | * @brief Controller instruction register structure 62 | * 63 | * This structure represents the hardware control registers. The actual address 64 | * mapping is handled by the device driver through /dev/accelerator. 65 | */ 66 | typedef struct __attribute__((packed)) { 67 | uint32_t opcode; /**< Operation code */ 68 | uint64_t src_addr; /**< Source address */ 69 | uint64_t dst_addr; /**< Destination address */ 70 | uint32_t length; /**< Data length */ 71 | uint32_t control; /**< Control signals */ 72 | uint32_t status; /**< Operation status */ 73 | 74 | union { 75 | hal_lsu_config_t lsu; /**< LSU operation */ 76 | hal_systolic_config_t systolic_array; /**< Convolution operation */ 77 | hal_img2col_config_t img2col; /**< Image conversion operation */ 78 | } ir_data; 79 | } hal_controller_ir_t; 80 | 81 | /** 82 | * @brief Configure the LSU unit 83 | * @param ctx HAL context 84 | * @param config LSU configuration 85 | * @return true if successful, false on error 86 | */ 87 | bool hal_configure_lsu(hal_context_t* ctx, const hal_lsu_config_t* config); 88 | 89 | /** 90 | * @brief Configure the systolic array 91 | * @param ctx HAL context 92 | * @param config Systolic array configuration 93 | * @return true if successful, false on error 94 | */ 95 | bool hal_configure_systolic(hal_context_t* ctx, 96 | const hal_systolic_config_t* config); 97 | 98 | /** 99 | * @brief Configure the IMG2COL unit 100 | * @param ctx HAL context 101 | * @param config IMG2COL configuration 102 | * @return true if successful, false on error 103 | */ 104 | bool hal_configure_img2col(hal_context_t* ctx, 105 | const hal_img2col_config_t* config); 106 | 107 | #endif /* HAL_CONFIG_H */ -------------------------------------------------------------------------------- /04_software/hal/include/hal_io.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file hal_io.h 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief I/O operations interface for hardware accelerator 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #ifndef HAL_IO_H 11 | #define HAL_IO_H 12 | 13 | #include 14 | #include 15 | 16 | #include "hal_base.h" 17 | 18 | /** 19 | * @brief Wait for hardware to be ready 20 | * @param ctx HAL context 21 | * @return true if ready, false on 
timeout or error 22 | */ 23 | bool hal_wait_for_ready(hal_context_t* ctx); 24 | 25 | /** 26 | * @brief Check if hardware is ready 27 | * @param ctx HAL context 28 | * @return true if ready, false otherwise 29 | */ 30 | bool hal_is_ready(hal_context_t* ctx); 31 | 32 | /** 33 | * @brief Check if hardware is busy 34 | * @param ctx HAL context 35 | * @return true if busy, false otherwise 36 | */ 37 | bool hal_is_busy(hal_context_t* ctx); 38 | 39 | /** 40 | * @brief Check if hardware is in error state 41 | * @param ctx HAL context 42 | * @return true if error, false otherwise 43 | */ 44 | bool hal_is_error(hal_context_t* ctx); 45 | 46 | /** 47 | * @brief Get current hardware status 48 | * @param ctx HAL context 49 | * @return Current status value 50 | */ 51 | uint32_t hal_get_status(hal_context_t* ctx); 52 | 53 | /** 54 | * @brief Set hardware status 55 | * @param ctx HAL context 56 | * @param status New status value 57 | */ 58 | void hal_set_status(hal_context_t* ctx, uint32_t status); 59 | 60 | #endif /* HAL_IO_H */ -------------------------------------------------------------------------------- /04_software/hal/include/hal_mem.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file hal_mem.h 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Memory management interface for hardware accelerator 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #ifndef HAL_MEM_H 11 | #define HAL_MEM_H 12 | 13 | #include 14 | #include 15 | 16 | #include "hal_base.h" 17 | 18 | // Memory alignment requirement 19 | #define HAL_MEM_ALIGN 64 // 64-byte alignment 20 | 21 | /** 22 | * @brief Initialize memory management subsystem 23 | * @param ctx HAL context 24 | * @param base Base address of memory region 25 | * @param size Size of memory region 26 | * @return true if successful, false on error 27 | */ 28 | bool hal_mem_init(hal_context_t* ctx, void* base, size_t size); 29 | 30 | /** 31 | * @brief Clean up memory management subsystem 32 | * @param ctx HAL context 33 | */ 34 | void hal_mem_cleanup(hal_context_t* ctx); 35 | 36 | /** 37 | * @brief Allocate memory from accelerator memory region 38 | * @param ctx HAL context 39 | * @param size Size of memory to allocate 40 | * @return Virtual address of allocated memory, or NULL on failure 41 | */ 42 | void* hal_mem_alloc(hal_context_t* ctx, size_t size); 43 | 44 | /** 45 | * @brief Free previously allocated accelerator memory 46 | * @param ctx HAL context 47 | * @param ptr Virtual address to free 48 | */ 49 | void hal_mem_free(hal_context_t* ctx, void* ptr); 50 | 51 | /** 52 | * @brief Convert virtual address to physical address 53 | * @param ctx HAL context 54 | * @param vaddr Virtual address to convert 55 | * @return Physical address, or 0 on error 56 | */ 57 | uint64_t hal_virt_to_phys(hal_context_t* ctx, void* vaddr); 58 | 59 | /** 60 | * @brief Get available memory size 61 | * @param ctx HAL context 62 | * @return Total size of free memory blocks 63 | */ 64 | size_t hal_mem_available(hal_context_t* ctx); 65 | 66 | #endif /* HAL_MEM_H */ -------------------------------------------------------------------------------- /04_software/hal/src/hal_base.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file hal_base.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Implementation of base HAL functionality for hardware accelerator 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #include "hal_base.h" 11 | 12 | #include 13 | #include 14 | 
#include 15 | #include 16 | 17 | #include "hal_mem.h" 18 | 19 | hal_context_t* hal_init(const char* device_path) { 20 | hal_context_t* ctx = (hal_context_t*)malloc(sizeof(hal_context_t)); 21 | if (!ctx) { 22 | return NULL; 23 | } 24 | 25 | // Open device file 26 | ctx->fd = open(device_path, O_RDWR); 27 | if (ctx->fd < 0) { 28 | free(ctx); 29 | return NULL; 30 | } 31 | 32 | // Map register space 33 | ctx->mapped_memory = 34 | mmap(NULL, getpagesize(), PROT_READ | PROT_WRITE, MAP_SHARED, ctx->fd, 0); 35 | if (ctx->mapped_memory == MAP_FAILED) { 36 | close(ctx->fd); 37 | free(ctx); 38 | return NULL; 39 | } 40 | 41 | // Map accelerator memory region 42 | ctx->accel_memory_base = 43 | mmap(NULL, HAL_ACCEL_MEM_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, 44 | ctx->fd, HAL_ACCEL_MEM_BASE); 45 | if (ctx->accel_memory_base == MAP_FAILED) { 46 | munmap(ctx->mapped_memory, getpagesize()); 47 | close(ctx->fd); 48 | free(ctx); 49 | return NULL; 50 | } 51 | 52 | ctx->accel_memory_size = HAL_ACCEL_MEM_SIZE; 53 | ctx->status = HAL_STATUS_READY; 54 | 55 | // Initialize memory management 56 | if (!hal_mem_init(ctx, ctx->accel_memory_base, ctx->accel_memory_size)) { 57 | munmap(ctx->accel_memory_base, ctx->accel_memory_size); 58 | munmap(ctx->mapped_memory, getpagesize()); 59 | close(ctx->fd); 60 | free(ctx); 61 | return NULL; 62 | } 63 | 64 | return ctx; 65 | } 66 | 67 | void hal_cleanup(hal_context_t* ctx) { 68 | if (ctx) { 69 | hal_mem_cleanup(ctx); 70 | if (ctx->mapped_memory) { 71 | munmap(ctx->mapped_memory, getpagesize()); 72 | } 73 | if (ctx->accel_memory_base) { 74 | munmap(ctx->accel_memory_base, ctx->accel_memory_size); 75 | } 76 | if (ctx->fd >= 0) { 77 | close(ctx->fd); 78 | } 79 | free(ctx); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /04_software/hal/src/hal_config.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file hal_config.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Implementation of configuration operations for hardware accelerator 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #include "hal_config.h" 11 | 12 | #include 13 | #include 14 | 15 | #include "hal_base.h" 16 | #include "hal_io.h" // For hal_wait_for_ready 17 | 18 | /** 19 | * @brief Write configuration to hardware registers 20 | * @param ctx Pointer to HAL context 21 | * @param config Pointer to configuration data 22 | * @return true if successful, false on error 23 | */ 24 | static bool write_config(hal_context_t* ctx, const void* config) { 25 | if (!ctx || !config) { 26 | return false; 27 | } 28 | 29 | /* Wait for hardware to be ready */ 30 | if (!hal_wait_for_ready(ctx)) { 31 | return false; 32 | } 33 | 34 | /* Get the mapped register structure */ 35 | hal_controller_ir_t* ir = (hal_controller_ir_t*)ctx->mapped_memory; 36 | if (!ir) { 37 | return false; 38 | } 39 | 40 | /* Copy configuration data to registers */ 41 | memcpy(ir, config, sizeof(hal_controller_ir_t)); 42 | 43 | return true; 44 | } 45 | 46 | /** 47 | * @brief Configure the LSU unit 48 | * @param ctx Pointer to HAL context 49 | * @param config Pointer to LSU configuration 50 | * @return true if successful, false on error 51 | */ 52 | bool hal_configure_lsu(hal_context_t* ctx, const hal_lsu_config_t* config) { 53 | hal_controller_ir_t ir = {0}; 54 | memcpy(&ir.ir_data.lsu, config, sizeof(hal_lsu_config_t)); 55 | return write_config(ctx, &ir); 56 | } 57 | 58 | /** 59 | * @brief Configure the systolic array 60 | * @param ctx Pointer 
to HAL context 61 | * @param config Pointer to systolic array configuration 62 | * @return true if successful, false on error 63 | */ 64 | bool hal_configure_systolic(hal_context_t* ctx, 65 | const hal_systolic_config_t* config) { 66 | hal_controller_ir_t ir = {0}; 67 | memcpy(&ir.ir_data.systolic_array, config, sizeof(hal_systolic_config_t)); 68 | return write_config(ctx, &ir); 69 | } 70 | 71 | /** 72 | * @brief Configure the img2col unit 73 | * @param ctx Pointer to HAL context 74 | * @param config Pointer to img2col configuration 75 | * @return true if successful, false on error 76 | */ 77 | bool hal_configure_img2col(hal_context_t* ctx, 78 | const hal_img2col_config_t* config) { 79 | hal_controller_ir_t ir = {0}; 80 | memcpy(&ir.ir_data.img2col, config, sizeof(hal_img2col_config_t)); 81 | return write_config(ctx, &ir); 82 | } 83 | -------------------------------------------------------------------------------- /04_software/hal/src/hal_io.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file hal_io.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Implementation of I/O operations for hardware accelerator 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #include "hal_io.h" 11 | 12 | #include 13 | 14 | bool hal_wait_for_ready(hal_context_t* ctx) { 15 | if (!ctx) { 16 | return false; 17 | } 18 | 19 | // Simple polling implementation 20 | int retries = 100; // Maximum retries 21 | while (retries-- > 0) { 22 | if (hal_is_ready(ctx)) { 23 | return true; 24 | } 25 | usleep(1000); // Wait 1ms between checks 26 | } 27 | return false; 28 | } 29 | 30 | bool hal_is_ready(hal_context_t* ctx) { 31 | if (!ctx) { 32 | return false; 33 | } 34 | return (ctx->status & HAL_STATUS_READY) != 0; 35 | } 36 | 37 | bool hal_is_busy(hal_context_t* ctx) { 38 | if (!ctx) { 39 | return false; 40 | } 41 | return (ctx->status & HAL_STATUS_BUSY) != 0; 42 | } 43 | 44 | bool hal_is_error(hal_context_t* ctx) { 45 | if (!ctx) { 46 | return false; 47 | } 48 | return (ctx->status & HAL_STATUS_ERROR) != 0; 49 | } 50 | 51 | uint32_t hal_get_status(hal_context_t* ctx) { 52 | if (!ctx) { 53 | return 0; 54 | } 55 | return ctx->status; 56 | } 57 | 58 | void hal_set_status(hal_context_t* ctx, uint32_t status) { 59 | if (ctx) { 60 | ctx->status = status; 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /04_software/hal/src/hal_mem.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file hal_mem.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Implementation of memory management for hardware accelerator 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #include "hal_mem.h" 11 | 12 | #include 13 | #include 14 | 15 | /** 16 | * @brief Memory block structure for tracking allocations 17 | */ 18 | struct mem_block { 19 | void* addr; /**< Virtual address of block */ 20 | size_t size; /**< Size of block */ 21 | bool used; /**< Whether block is in use */ 22 | struct mem_block* next; /**< Next block in list */ 23 | }; 24 | 25 | /** 26 | * @brief Memory management context 27 | */ 28 | struct hal_mem_context { 29 | void* base_addr; /**< Base address of memory region */ 30 | size_t total_size; /**< Total size of memory region */ 31 | struct mem_block* blocks; /**< List of memory blocks */ 32 | }; 33 | 34 | /** 35 | * @brief Round up size to alignment boundary 36 | * @param size Size to align 37 | * @return Aligned size 38 | */ 39 | static size_t 
align_size(size_t size) { 40 | return (size + HAL_MEM_ALIGN - 1) & ~(HAL_MEM_ALIGN - 1); 41 | } 42 | 43 | bool hal_mem_init(hal_context_t* ctx, void* base, size_t size) { 44 | struct hal_mem_context* mem_ctx = malloc(sizeof(struct hal_mem_context)); 45 | if (!mem_ctx) { 46 | return false; 47 | } 48 | 49 | struct mem_block* block = malloc(sizeof(struct mem_block)); 50 | if (!block) { 51 | free(mem_ctx); 52 | return false; 53 | } 54 | 55 | block->addr = base; 56 | block->size = size; 57 | block->used = false; 58 | block->next = NULL; 59 | 60 | mem_ctx->base_addr = base; 61 | mem_ctx->total_size = size; 62 | mem_ctx->blocks = block; 63 | 64 | ctx->mem_ctx = mem_ctx; 65 | return true; 66 | } 67 | 68 | void hal_mem_cleanup(hal_context_t* ctx) { 69 | if (!ctx || !ctx->mem_ctx) { 70 | return; 71 | } 72 | 73 | struct hal_mem_context* mem_ctx = ctx->mem_ctx; 74 | struct mem_block* block = mem_ctx->blocks; 75 | 76 | while (block) { 77 | struct mem_block* next = block->next; 78 | free(block); 79 | block = next; 80 | } 81 | 82 | free(mem_ctx); 83 | ctx->mem_ctx = NULL; 84 | } 85 | 86 | void* hal_mem_alloc(hal_context_t* ctx, size_t size) { 87 | if (!ctx || !ctx->mem_ctx || size == 0) { 88 | return NULL; 89 | } 90 | 91 | struct hal_mem_context* mem_ctx = ctx->mem_ctx; 92 | size = align_size(size); 93 | 94 | struct mem_block* block = mem_ctx->blocks; 95 | struct mem_block* best_fit = NULL; 96 | size_t best_size = (size_t)-1; 97 | 98 | // Find best fit block 99 | while (block) { 100 | if (!block->used && block->size >= size) { 101 | if (block->size < best_size) { 102 | best_fit = block; 103 | best_size = block->size; 104 | } 105 | } 106 | block = block->next; 107 | } 108 | 109 | if (!best_fit) { 110 | return NULL; 111 | } 112 | 113 | // Split block if it's significantly larger 114 | if (best_fit->size > size + sizeof(struct mem_block) + HAL_MEM_ALIGN) { 115 | struct mem_block* new_block = malloc(sizeof(struct mem_block)); 116 | if (!new_block) { 117 | return NULL; 118 | } 119 | 120 | new_block->addr = (char*)best_fit->addr + size; 121 | new_block->size = best_fit->size - size; 122 | new_block->used = false; 123 | new_block->next = best_fit->next; 124 | 125 | best_fit->size = size; 126 | best_fit->next = new_block; 127 | } 128 | 129 | best_fit->used = true; 130 | return best_fit->addr; 131 | } 132 | 133 | void hal_mem_free(hal_context_t* ctx, void* ptr) { 134 | if (!ctx || !ctx->mem_ctx || !ptr) { 135 | return; 136 | } 137 | 138 | struct hal_mem_context* mem_ctx = ctx->mem_ctx; 139 | struct mem_block* block = mem_ctx->blocks; 140 | struct mem_block* prev = NULL; 141 | 142 | // Find the block 143 | while (block && block->addr != ptr) { 144 | prev = block; 145 | block = block->next; 146 | } 147 | 148 | if (!block || !block->used) { 149 | return; 150 | } 151 | 152 | block->used = false; 153 | 154 | // Merge with next block if it's free 155 | while (block->next && !block->next->used) { 156 | struct mem_block* next = block->next; 157 | block->size += next->size; 158 | block->next = next->next; 159 | free(next); 160 | } 161 | 162 | // Merge with previous block if it's free 163 | if (prev && !prev->used) { 164 | prev->size += block->size; 165 | prev->next = block->next; 166 | free(block); 167 | } 168 | } 169 | 170 | uint64_t hal_virt_to_phys(hal_context_t* ctx, void* vaddr) { 171 | if (!ctx || !ctx->mem_ctx) { 172 | return 0; 173 | } 174 | 175 | struct hal_mem_context* mem_ctx = ctx->mem_ctx; 176 | if (vaddr < mem_ctx->base_addr || 177 | vaddr >= (mem_ctx->base_addr + mem_ctx->total_size)) { 178 | return 0; 179 
| } 180 | 181 | return HAL_ACCEL_MEM_BASE + ((uint8_t*)vaddr - (uint8_t*)mem_ctx->base_addr); 182 | } 183 | 184 | size_t hal_mem_available(hal_context_t* ctx) { 185 | if (!ctx || !ctx->mem_ctx) { 186 | return 0; 187 | } 188 | 189 | struct hal_mem_context* mem_ctx = ctx->mem_ctx; 190 | size_t available = 0; 191 | struct mem_block* block = mem_ctx->blocks; 192 | 193 | while (block) { 194 | if (!block->used) { 195 | available += block->size; 196 | } 197 | block = block->next; 198 | } 199 | 200 | return available; 201 | } -------------------------------------------------------------------------------- /04_software/hal/test/hal_test.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file hal_test.h 3 | * @brief Test framework for HAL unit testing 4 | */ 5 | 6 | #ifndef HAL_TEST_H 7 | #define HAL_TEST_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | // Test statistics 14 | static int total_tests = 0; 15 | static int passed_tests = 0; 16 | static int failed_tests = 0; 17 | 18 | // Basic assertions 19 | #define HAL_TEST_ASSERT(condition) \ 20 | do { \ 21 | total_tests++; \ 22 | if (condition) { \ 23 | passed_tests++; \ 24 | printf("PASS: %s:%d\n", __FILE__, __LINE__); \ 25 | } else { \ 26 | failed_tests++; \ 27 | printf("FAIL: %s:%d\n", __FILE__, __LINE__); \ 28 | } \ 29 | } while (0) 30 | 31 | // Equality assertions 32 | #define HAL_TEST_ASSERT_EQUAL(expected, actual) \ 33 | HAL_TEST_ASSERT((expected) == (actual)) 34 | 35 | #define HAL_TEST_ASSERT_EQUAL_UINT64(expected, actual) \ 36 | HAL_TEST_ASSERT((uint64_t)(expected) == (uint64_t)(actual)) 37 | 38 | // Pointer assertions 39 | #define HAL_TEST_ASSERT_NOT_NULL(ptr) HAL_TEST_ASSERT((ptr) != NULL) 40 | 41 | #define HAL_TEST_ASSERT_NULL(ptr) HAL_TEST_ASSERT((ptr) == NULL) 42 | 43 | // Comparison assertions 44 | #define HAL_TEST_ASSERT_NOT_EQUAL(expected, actual) \ 45 | HAL_TEST_ASSERT((expected) != (actual)) 46 | 47 | #define HAL_TEST_ASSERT_GREATER_OR_EQUAL_UINT64(expected, actual) \ 48 | HAL_TEST_ASSERT((uint64_t)(actual) >= (uint64_t)(expected)) 49 | 50 | #define HAL_TEST_ASSERT_LESS_THAN_UINT64(expected, actual) \ 51 | HAL_TEST_ASSERT((uint64_t)(actual) < (uint64_t)(expected)) 52 | 53 | #define HAL_TEST_ASSERT_LESS_THAN(expected, actual) \ 54 | HAL_TEST_ASSERT((actual) < (expected)) 55 | 56 | // Test control 57 | #define HAL_TEST_BEGIN() \ 58 | do { \ 59 | total_tests = 0; \ 60 | passed_tests = 0; \ 61 | failed_tests = 0; \ 62 | printf("\nStarting tests...\n"); \ 63 | } while (0) 64 | 65 | #define HAL_TEST_END() \ 66 | do { \ 67 | printf("\nTest Summary:\n"); \ 68 | printf("Total: %d\n", total_tests); \ 69 | printf("Passed: %d\n", passed_tests); \ 70 | printf("Failed: %d\n", failed_tests); \ 71 | return failed_tests; \ 72 | } while (0) 73 | 74 | #define HAL_TEST_RUN(test_func) \ 75 | do { \ 76 | printf("\nRunning %s...\n", #test_func); \ 77 | test_func(); \ 78 | } while (0) 79 | 80 | #endif /* HAL_TEST_H */ -------------------------------------------------------------------------------- /04_software/hal/test/test_hal_init.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file test_hal_init.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Unit tests for HAL initialization 6 | * 7 | * @version 1.0.0 8 | * @date 2020-03-28 9 | */ 10 | 11 | #include "hal_base.h" 12 | #include "hal_test.h" 13 | 14 | /** 15 | * @brief Test basic initialization and cleanup 16 | */ 17 | static void test_hal_init_basic(void) { 18 | hal_context_t* ctx = 
hal_init("/dev/accelerator0"); 19 | HAL_TEST_ASSERT_NOT_NULL(ctx); 20 | HAL_TEST_ASSERT_NOT_NULL(ctx->mapped_memory); 21 | HAL_TEST_ASSERT_NOT_NULL(ctx->accel_memory_base); 22 | HAL_TEST_ASSERT_EQUAL(HAL_ACCEL_MEM_SIZE, ctx->accel_memory_size); 23 | hal_cleanup(ctx); 24 | } 25 | 26 | /** 27 | * @brief Test initialization with invalid parameters 28 | */ 29 | static void test_hal_init_invalid_params(void) { 30 | // Test NULL device path 31 | hal_context_t* ctx = hal_init(NULL); 32 | HAL_TEST_ASSERT_NULL(ctx); 33 | 34 | // Test invalid device path 35 | ctx = hal_init("/dev/nonexistent"); 36 | HAL_TEST_ASSERT_NULL(ctx); 37 | } 38 | 39 | /** 40 | * @brief Test cleanup with invalid parameters 41 | */ 42 | static void test_hal_init_cleanup_null(void) { 43 | // Should not crash with NULL context 44 | hal_cleanup(NULL); 45 | } 46 | 47 | /** 48 | * @brief Test multiple init/cleanup cycles 49 | */ 50 | static void test_hal_init_multiple(void) { 51 | hal_context_t* ctx1 = hal_init("/dev/accelerator0"); 52 | HAL_TEST_ASSERT_NOT_NULL(ctx1); 53 | 54 | hal_context_t* ctx2 = hal_init("/dev/accelerator0"); 55 | HAL_TEST_ASSERT_NOT_NULL(ctx2); 56 | 57 | // Contexts should be different 58 | HAL_TEST_ASSERT(ctx1 != ctx2); 59 | 60 | hal_cleanup(ctx1); 61 | hal_cleanup(ctx2); 62 | } 63 | 64 | int main(void) { 65 | HAL_TEST_BEGIN(); 66 | 67 | HAL_TEST_RUN(test_hal_init_basic); 68 | HAL_TEST_RUN(test_hal_init_invalid_params); 69 | HAL_TEST_RUN(test_hal_init_cleanup_null); 70 | HAL_TEST_RUN(test_hal_init_multiple); 71 | 72 | HAL_TEST_END(); 73 | } -------------------------------------------------------------------------------- /04_software/hal/test/test_hal_io.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file test_hal_io.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Unit tests for HAL I/O operations 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #include "hal_base.h" 11 | #include "hal_io.h" 12 | #include "hal_test.h" 13 | 14 | static hal_context_t* ctx; 15 | 16 | /** 17 | * @brief Set up test environment before each test 18 | */ 19 | static void set_up(void) { 20 | ctx = hal_init("/dev/accelerator0"); 21 | HAL_TEST_ASSERT_NOT_NULL(ctx); 22 | } 23 | 24 | /** 25 | * @brief Clean up test environment after each test 26 | */ 27 | static void tear_down(void) { hal_cleanup(ctx); } 28 | 29 | /** 30 | * @brief Test basic I/O operations 31 | */ 32 | static void test_hal_io_basic(void) { 33 | set_up(); 34 | 35 | // Test wait for ready 36 | HAL_TEST_ASSERT(hal_wait_for_ready(ctx)); 37 | 38 | // Test status checks 39 | HAL_TEST_ASSERT(hal_is_ready(ctx)); 40 | HAL_TEST_ASSERT(!hal_is_busy(ctx)); 41 | HAL_TEST_ASSERT(!hal_is_error(ctx)); 42 | 43 | tear_down(); 44 | } 45 | 46 | /** 47 | * @brief Test status transitions 48 | */ 49 | static void test_hal_io_status(void) { 50 | set_up(); 51 | 52 | // Test initial status 53 | HAL_TEST_ASSERT_EQUAL(HAL_STATUS_READY, hal_get_status(ctx)); 54 | 55 | // Test setting status 56 | hal_set_status(ctx, HAL_STATUS_BUSY); 57 | HAL_TEST_ASSERT_EQUAL(HAL_STATUS_BUSY, hal_get_status(ctx)); 58 | HAL_TEST_ASSERT(hal_is_busy(ctx)); 59 | 60 | hal_set_status(ctx, HAL_STATUS_COMPLETE); 61 | HAL_TEST_ASSERT_EQUAL(HAL_STATUS_COMPLETE, hal_get_status(ctx)); 62 | HAL_TEST_ASSERT(!hal_is_busy(ctx)); 63 | 64 | tear_down(); 65 | } 66 | 67 | /** 68 | * @brief Test error handling 69 | */ 70 | static void test_hal_io_error(void) { 71 | set_up(); 72 | 73 | // Test error status 74 | hal_set_status(ctx, HAL_STATUS_ERROR); 75 | 
HAL_TEST_ASSERT(hal_is_error(ctx)); 76 | HAL_TEST_ASSERT(!hal_is_ready(ctx)); 77 | HAL_TEST_ASSERT(!hal_is_busy(ctx)); 78 | 79 | // Test error recovery 80 | hal_set_status(ctx, HAL_STATUS_READY); 81 | HAL_TEST_ASSERT(!hal_is_error(ctx)); 82 | HAL_TEST_ASSERT(hal_is_ready(ctx)); 83 | 84 | tear_down(); 85 | } 86 | 87 | /** 88 | * @brief Test invalid parameters 89 | */ 90 | static void test_hal_io_invalid_params(void) { 91 | // Test null context 92 | HAL_TEST_ASSERT(!hal_wait_for_ready(NULL)); 93 | HAL_TEST_ASSERT(!hal_is_ready(NULL)); 94 | HAL_TEST_ASSERT(!hal_is_busy(NULL)); 95 | HAL_TEST_ASSERT(!hal_is_error(NULL)); 96 | HAL_TEST_ASSERT_EQUAL(0, hal_get_status(NULL)); 97 | } 98 | 99 | int main(void) { 100 | HAL_TEST_BEGIN(); 101 | 102 | HAL_TEST_RUN(test_hal_io_basic); 103 | HAL_TEST_RUN(test_hal_io_status); 104 | HAL_TEST_RUN(test_hal_io_error); 105 | HAL_TEST_RUN(test_hal_io_invalid_params); 106 | 107 | HAL_TEST_END(); 108 | } -------------------------------------------------------------------------------- /04_software/hal/test/test_hal_mem.c: -------------------------------------------------------------------------------- 1 | /** 2 | * @file test_hal_mem.c 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Unit tests for HAL memory management 6 | * @version 1.0.0 7 | * @date 2020-03-28 8 | */ 9 | 10 | #include "hal_base.h" 11 | #include "hal_mem.h" 12 | #include "hal_test.h" 13 | 14 | static hal_context_t* ctx; 15 | static const size_t TEST_SIZE = 1024; // 1KB for testing 16 | 17 | /** 18 | * @brief Set up test environment before each test 19 | */ 20 | static void set_up(void) { 21 | ctx = hal_init("/dev/accelerator0"); 22 | HAL_TEST_ASSERT_NOT_NULL(ctx); 23 | } 24 | 25 | /** 26 | * @brief Clean up test environment after each test 27 | */ 28 | static void tear_down(void) { hal_cleanup(ctx); } 29 | 30 | /** 31 | * @brief Test basic memory allocation and deallocation 32 | */ 33 | static void test_hal_mem_basic_alloc_free(void) { 34 | set_up(); 35 | 36 | void* ptr = hal_mem_alloc(ctx, TEST_SIZE); 37 | HAL_TEST_ASSERT_NOT_NULL(ptr); 38 | 39 | // Verify alignment 40 | HAL_TEST_ASSERT((((uintptr_t)ptr & (HAL_MEM_ALIGN - 1)) == 0)); 41 | 42 | hal_mem_free(ctx, ptr); 43 | 44 | // Verify available memory is restored 45 | HAL_TEST_ASSERT_EQUAL_UINT64(HAL_ACCEL_MEM_SIZE, hal_mem_available(ctx)); 46 | 47 | tear_down(); 48 | } 49 | 50 | /** 51 | * @brief Test multiple allocations and deallocations 52 | */ 53 | static void test_hal_mem_multiple_allocs(void) { 54 | set_up(); 55 | 56 | void* ptr1 = hal_mem_alloc(ctx, TEST_SIZE); 57 | void* ptr2 = hal_mem_alloc(ctx, TEST_SIZE); 58 | void* ptr3 = hal_mem_alloc(ctx, TEST_SIZE); 59 | 60 | HAL_TEST_ASSERT_NOT_NULL(ptr1); 61 | HAL_TEST_ASSERT_NOT_NULL(ptr2); 62 | HAL_TEST_ASSERT_NOT_NULL(ptr3); 63 | 64 | // Verify pointers are different 65 | HAL_TEST_ASSERT(ptr1 != ptr2); 66 | HAL_TEST_ASSERT(ptr2 != ptr3); 67 | HAL_TEST_ASSERT(ptr1 != ptr3); 68 | 69 | hal_mem_free(ctx, ptr2); // Free middle block 70 | 71 | // Allocate a new block that should fit in the freed space 72 | void* ptr4 = hal_mem_alloc(ctx, TEST_SIZE); 73 | HAL_TEST_ASSERT_NOT_NULL(ptr4); 74 | 75 | hal_mem_free(ctx, ptr1); 76 | hal_mem_free(ctx, ptr3); 77 | hal_mem_free(ctx, ptr4); 78 | 79 | tear_down(); 80 | } 81 | 82 | /** 83 | * @brief Test virtual to physical address conversion 84 | */ 85 | static void test_hal_virt_to_phys(void) { 86 | set_up(); 87 | 88 | void* ptr = hal_mem_alloc(ctx, TEST_SIZE); 89 | HAL_TEST_ASSERT_NOT_NULL(ptr); 90 | 91 | uint64_t phys_addr = hal_virt_to_phys(ctx, 
ptr); 92 | HAL_TEST_ASSERT_NOT_EQUAL(0, phys_addr); 93 | HAL_TEST_ASSERT_GREATER_OR_EQUAL_UINT64(HAL_ACCEL_MEM_BASE, phys_addr); 94 | HAL_TEST_ASSERT_LESS_THAN_UINT64(HAL_ACCEL_MEM_BASE + HAL_ACCEL_MEM_SIZE, 95 | phys_addr); 96 | 97 | hal_mem_free(ctx, ptr); 98 | 99 | tear_down(); 100 | } 101 | 102 | /** 103 | * @brief Test allocation with invalid parameters 104 | */ 105 | static void test_hal_mem_invalid_params(void) { 106 | set_up(); 107 | 108 | // Test null context 109 | HAL_TEST_ASSERT_NULL(hal_mem_alloc(NULL, TEST_SIZE)); 110 | 111 | // Test zero size 112 | HAL_TEST_ASSERT_NULL(hal_mem_alloc(ctx, 0)); 113 | 114 | // Test too large size 115 | HAL_TEST_ASSERT_NULL(hal_mem_alloc(ctx, HAL_ACCEL_MEM_SIZE + 1)); 116 | 117 | // Test invalid virtual address 118 | HAL_TEST_ASSERT_EQUAL_UINT64(0, hal_virt_to_phys(ctx, (void*)0xDEADBEEF)); 119 | 120 | tear_down(); 121 | } 122 | 123 | /** 124 | * @brief Test memory fragmentation and coalescing 125 | */ 126 | static void test_hal_mem_fragmentation(void) { 127 | set_up(); 128 | 129 | void* ptrs[5]; 130 | const size_t small_size = 256; // 256 bytes 131 | 132 | // Allocate 5 small blocks 133 | for (int i = 0; i < 5; i++) { 134 | ptrs[i] = hal_mem_alloc(ctx, small_size); 135 | HAL_TEST_ASSERT_NOT_NULL(ptrs[i]); 136 | } 137 | 138 | // Free alternate blocks 139 | hal_mem_free(ctx, ptrs[1]); 140 | hal_mem_free(ctx, ptrs[3]); 141 | 142 | // Try to allocate a block that should fit in two coalesced free blocks 143 | void* large_ptr = hal_mem_alloc(ctx, small_size * 2); 144 | HAL_TEST_ASSERT_NOT_NULL(large_ptr); 145 | 146 | // Clean up 147 | hal_mem_free(ctx, ptrs[0]); 148 | hal_mem_free(ctx, ptrs[2]); 149 | hal_mem_free(ctx, ptrs[4]); 150 | hal_mem_free(ctx, large_ptr); 151 | 152 | tear_down(); 153 | } 154 | 155 | /** 156 | * @brief Test available memory tracking 157 | */ 158 | static void test_hal_mem_available(void) { 159 | set_up(); 160 | 161 | size_t initial_available = hal_mem_available(ctx); 162 | HAL_TEST_ASSERT_EQUAL_UINT64(HAL_ACCEL_MEM_SIZE, initial_available); 163 | 164 | void* ptr = hal_mem_alloc(ctx, TEST_SIZE); 165 | HAL_TEST_ASSERT_NOT_NULL(ptr); 166 | 167 | size_t after_alloc = hal_mem_available(ctx); 168 | HAL_TEST_ASSERT_LESS_THAN(initial_available, after_alloc); 169 | 170 | hal_mem_free(ctx, ptr); 171 | 172 | size_t after_free = hal_mem_available(ctx); 173 | HAL_TEST_ASSERT_EQUAL_UINT64(initial_available, after_free); 174 | 175 | tear_down(); 176 | } 177 | 178 | int main(void) { 179 | HAL_TEST_BEGIN(); 180 | 181 | HAL_TEST_RUN(test_hal_mem_basic_alloc_free); 182 | HAL_TEST_RUN(test_hal_mem_multiple_allocs); 183 | HAL_TEST_RUN(test_hal_virt_to_phys); 184 | HAL_TEST_RUN(test_hal_mem_invalid_params); 185 | HAL_TEST_RUN(test_hal_mem_fragmentation); 186 | HAL_TEST_RUN(test_hal_mem_available); 187 | 188 | HAL_TEST_END(); 189 | } 190 | -------------------------------------------------------------------------------- /04_software/runtime/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | project(accel_runtime VERSION 1.0 LANGUAGES CXX) 3 | 4 | set(CMAKE_CXX_STANDARD 17) 5 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 6 | set(CMAKE_CXX_EXTENSIONS ON) 7 | 8 | if(CMAKE_BUILD_TYPE MATCHES "Debug") 9 | add_compile_definitions(BUILD_DEBUG) 10 | elseif(CMAKE_BUILD_TYPE MATCHES "Release") 11 | add_compile_definitions(BUILD_RELEASE) 12 | endif() 13 | 14 | list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) 15 | include(accel_driver) 16 | 17 | 
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/bin) 18 | 19 | add_subdirectory(include) 20 | add_subdirectory(tutorials) -------------------------------------------------------------------------------- /04_software/runtime/cmake/accel_driver.cmake: -------------------------------------------------------------------------------- 1 | # Set paths for pre-built driver 2 | set(DRIVER_ROOT ${PROJECT_SOURCE_DIR}/../driver) 3 | set(DRIVER_INCLUDE_DIR ${DRIVER_ROOT}/include) 4 | set(DRIVER_LIB_DIR ${DRIVER_ROOT}/lib) 5 | 6 | # Create imported target for the pre-built driver library 7 | add_library(accel_driver SHARED IMPORTED) 8 | set_target_properties(accel_driver PROPERTIES 9 | IMPORTED_LOCATION ${DRIVER_LIB_DIR}/libaccel_driver.so 10 | INTERFACE_INCLUDE_DIRECTORIES ${DRIVER_INCLUDE_DIR} 11 | IMPORTED_LINK_INTERFACE_LANGUAGES "C" 12 | ) 13 | 14 | # Add compile definitions to ensure proper C/C++ interop 15 | target_compile_definitions(accel_driver INTERFACE 16 | EXTERN_C_LINKAGE 17 | ) 18 | 19 | # Verify that the library exists 20 | if(NOT EXISTS ${DRIVER_LIB_DIR}/libaccel_driver.so) 21 | message(FATAL_ERROR "Driver library not found at ${DRIVER_LIB_DIR}/libaccel_driver.so. Please build the driver first.") 22 | endif() -------------------------------------------------------------------------------- /04_software/runtime/include/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(accel_runtime INTERFACE) 2 | 3 | target_include_directories(accel_runtime 4 | INTERFACE 5 | ${CMAKE_CURRENT_SOURCE_DIR} 6 | ${DRIVER_INCLUDE_DIR} 7 | ) 8 | 9 | # Link against driver library 10 | target_link_libraries(accel_runtime 11 | INTERFACE 12 | accel_driver 13 | ) -------------------------------------------------------------------------------- /04_software/runtime/include/accel.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "accel/buffer.hpp" 4 | #include "accel/runtime.hpp" 5 | #include "accel/types.hpp" -------------------------------------------------------------------------------- /04_software/runtime/include/accel/buffer.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file buffer.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief RAII wrapper for accelerator buffer 6 | * @version 1.0.0 7 | * @date 2020-04-08 8 | */ 9 | 10 | #pragma once 11 | 12 | #include <cstddef> 13 | #include <stdexcept> 14 | #include <string> 15 | 16 | #include "accel.h" 17 | 18 | namespace accel { 19 | 20 | /** 21 | * @brief RAII wrapper for accelerator buffer 22 | */ 23 | class Buffer { 24 | public: 25 | /** 26 | * @brief Creates a buffer of specified size 27 | * @param size Buffer size in bytes 28 | * @throws std::runtime_error if allocation fails 29 | */ 30 | explicit Buffer(size_t size) { 31 | buffer_ = accel_alloc_buffer(size); 32 | if (!buffer_) { 33 | throw std::runtime_error("Failed to allocate buffer: " + 34 | std::string(accel_get_error())); 35 | } 36 | } 37 | 38 | ~Buffer() { 39 | if (buffer_) { 40 | accel_free_buffer(buffer_); 41 | } 42 | } 43 | 44 | // Disable copying 45 | Buffer(const Buffer&) = delete; 46 | Buffer& operator=(const Buffer&) = delete; 47 | 48 | // Enable moving 49 | Buffer(Buffer&& other) noexcept : buffer_(other.buffer_) { 50 | other.buffer_ = nullptr; 51 | } 52 | 53 | Buffer& operator=(Buffer&& other) noexcept { 54 | if (this != &other) { 55 | if (buffer_) { 56 | accel_free_buffer(buffer_); 57 | } 58 | buffer_ = other.buffer_; 59 | 
other.buffer_ = nullptr; 60 | } 61 | return *this; 62 | } 63 | 64 | /** 65 | * @brief Get raw pointer to host memory 66 | * @return Pointer to host memory 67 | */ 68 | void* data() const { return buffer_->host_addr; } 69 | 70 | /** 71 | * @brief Get buffer size 72 | * @return Size in bytes 73 | */ 74 | size_t size() const { return buffer_->size; } 75 | 76 | private: 77 | accel_buffer_t* buffer_; 78 | friend class Runtime; 79 | }; 80 | 81 | } // namespace accel -------------------------------------------------------------------------------- /04_software/runtime/include/accel/runtime.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file runtime.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Runtime interface for hardware accelerator 6 | * @version 1.0.0 7 | * @date 2020-04-08 8 | */ 9 | 10 | #pragma once 11 | 12 | #include <string> 13 | 14 | #include "buffer.hpp" 15 | #include "types.hpp" 16 | 17 | namespace accel { 18 | 19 | /** 20 | * @brief Runtime interface for accelerator operations 21 | */ 22 | class Runtime { 23 | public: 24 | /** 25 | * @brief Initialize runtime with device path 26 | * @param device_path Path to accelerator device 27 | * @throws std::runtime_error if initialization fails 28 | */ 29 | explicit Runtime(const std::string& device_path) { 30 | if (accel_init(device_path.c_str()) != ACCEL_STATUS_OK) { 31 | throw std::runtime_error("Failed to initialize runtime: " + 32 | std::string(accel_get_error())); 33 | } 34 | 35 | // Set default configuration 36 | if (accel_reset_config() != ACCEL_STATUS_OK) { 37 | throw std::runtime_error("Failed to reset configuration: " + 38 | std::string(accel_get_error())); 39 | } 40 | } 41 | 42 | ~Runtime() { accel_cleanup(); } 43 | 44 | // Disable copying 45 | Runtime(const Runtime&) = delete; 46 | Runtime& operator=(const Runtime&) = delete; 47 | 48 | /** 49 | * @brief Configure runtime parameters 50 | * @param flags Configuration flags 51 | * @param num_channels Number of DMA channels 52 | * @param max_transfer Maximum transfer size 53 | * @param timeout_ms Operation timeout in milliseconds 54 | * @throws std::runtime_error if configuration fails 55 | */ 56 | void Configure(uint32_t flags, uint32_t num_channels = 1, 57 | uint32_t max_transfer = 0x1000000, 58 | uint32_t timeout_ms = 1000) { 59 | accel_config_t config{}; 60 | config.flags = flags; 61 | config.num_channels = num_channels; 62 | config.max_transfer = max_transfer; 63 | config.timeout_ms = timeout_ms; 64 | 65 | if (accel_configure(&config) != ACCEL_STATUS_OK) { 66 | throw std::runtime_error("Failed to configure runtime: " + 67 | std::string(accel_get_error())); 68 | } 69 | } 70 | 71 | /** 72 | * @brief Execute matrix multiplication 73 | * @param input Input buffer 74 | * @param weights Weight buffer 75 | * @param output Output buffer 76 | * @throws std::runtime_error if operation fails 77 | */ 78 | void MatrixMultiply(const Buffer& input, const Buffer& weights, 79 | Buffer& output) { 80 | accel_op_params_t params{}; 81 | params.op_type = ACCEL_OP_MATMUL; 82 | params.input = *input.buffer_; 83 | params.weights = *weights.buffer_; 84 | params.output = *output.buffer_; 85 | 86 | SubmitAndWait(params); 87 | } 88 | 89 | /** 90 | * @brief Execute 2D convolution 91 | * @param input Input buffer 92 | * @param weights Weight buffer 93 | * @param output Output buffer 94 | * @throws std::runtime_error if operation fails 95 | */ 96 | void Convolution2D(const Buffer& input, const Buffer& weights, 97 | Buffer& output) { 98 | 
accel_op_params_t params{}; 99 | params.op_type = ACCEL_OP_CONV2D; 100 | params.input = *input.buffer_; 101 | params.weights = *weights.buffer_; 102 | params.output = *output.buffer_; 103 | 104 | SubmitAndWait(params); 105 | } 106 | 107 | private: 108 | /** 109 | * @brief Submit operation and wait for completion 110 | * @param params Operation parameters 111 | * @throws std::runtime_error if operation fails 112 | */ 113 | void SubmitAndWait(const accel_op_params_t& params) { 114 | accel_status_t status = accel_submit_op(&params); 115 | if (status != ACCEL_STATUS_OK) { 116 | throw std::runtime_error("Failed to submit operation: " + 117 | std::string(accel_get_error())); 118 | } 119 | 120 | status = accel_wait_complete(0); // Wait indefinitely 121 | if (status != ACCEL_STATUS_OK) { 122 | throw std::runtime_error("Operation failed: " + 123 | std::string(accel_get_error())); 124 | } 125 | } 126 | }; 127 | 128 | } // namespace accel -------------------------------------------------------------------------------- /04_software/runtime/include/accel/types.hpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file types.hpp 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Common type definitions for runtime 6 | * @version 1.0.0 7 | * @date 2020-04-08 8 | */ 9 | 10 | #pragma once 11 | 12 | #include <cstdint> 13 | 14 | namespace accel { 15 | 16 | /** 17 | * @brief Runtime configuration flags 18 | */ 19 | enum ConfigFlags : uint32_t { 20 | kEnableDma = ACCEL_CONFIG_ENABLE_DMA, 21 | kSyncMode = ACCEL_CONFIG_SYNC_MODE, 22 | kHighPriority = ACCEL_CONFIG_HIGH_PRIORITY 23 | }; 24 | 25 | } // namespace accel -------------------------------------------------------------------------------- /04_software/runtime/tutorials/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Add executable 2 | add_executable(inference demo.cc) 3 | 4 | # Add include directories 5 | target_include_directories(inference 6 | PRIVATE 7 | ${PROJECT_SOURCE_DIR}/include 8 | ) 9 | 10 | # Link libraries 11 | target_link_libraries(inference 12 | PRIVATE 13 | accel_driver 14 | ) -------------------------------------------------------------------------------- /04_software/runtime/tutorials/demo.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @file demo.cc 3 | * @author Leo (zhsleo@outlook.com) 4 | * 5 | * @brief Demo for runtime 6 | * @version 1.0.0 7 | * @date 2020-04-08 8 | */ 9 | 10 | #include <accel.hpp> 11 | 12 | int main() { 13 | accel::Runtime runtime("/dev/accelerator"); 14 | runtime.Configure(accel::kEnableDma); 15 | 16 | accel::Buffer input(1024); 17 | accel::Buffer weights(1024); 18 | accel::Buffer output(1024); 19 | 20 | runtime.MatrixMultiply(input, weights, output); 21 | return 0; 22 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 User-xLeo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice 
and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README-CN.md: -------------------------------------------------------------------------------- 1 | ![LOGO](./imgs/logo.svg) 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | 5 | 简体中文 | [English](./README.md) 6 | 7 | ## DNN Accelerator 8 | 本项目使用Verilog实现了一些DNN中的常用算子 9 | 10 | ## 交流与反馈 11 | 如果遇到问题，可以通过邮件或仓库中的Issues反馈。 12 | 13 | ## 文档 14 | 相关文档放置在[Wiki](https://github.com/User-xLeo/DNN-Accelerator/wiki)，请阅读该文档以了解更多信息。 15 | 16 | ## 贡献者 17 | Leo (zhsleo@outlook.com) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![LOGO](./imgs/logo.svg) 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | 5 | English | [简体中文](./README-CN.md) 6 | 7 | ## DNN Accelerator 8 | This repository uses Verilog to implement common operators found in deep neural networks (DNNs). 9 | 10 | ## Documentation 11 | The documentation is available in the [Wiki](https://github.com/User-xLeo/DNN-Accelerator/wiki). Please read it for more information. 12 | 13 | ## How to Get Help 14 | If you run into a problem, you can email me or open an Issue in the repository. I will reply as promptly as possible. 
15 | 16 | ## Contributor 17 | Leo (zhsleo@outlook.com) -------------------------------------------------------------------------------- /imgs/dma_read_block_design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/user-xleo/DNN-Accelerator/966a7252615f35cad78c5ae84035297636b0d87f/imgs/dma_read_block_design.png -------------------------------------------------------------------------------- /imgs/dma_read_simulation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/user-xleo/DNN-Accelerator/966a7252615f35cad78c5ae84035297636b0d87f/imgs/dma_read_simulation.png -------------------------------------------------------------------------------- /imgs/dma_write_block_design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/user-xleo/DNN-Accelerator/966a7252615f35cad78c5ae84035297636b0d87f/imgs/dma_write_block_design.png -------------------------------------------------------------------------------- /imgs/dma_write_simulation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/user-xleo/DNN-Accelerator/966a7252615f35cad78c5ae84035297636b0d87f/imgs/dma_write_simulation.png -------------------------------------------------------------------------------- /imgs/logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/user-xleo/DNN-Accelerator/966a7252615f35cad78c5ae84035297636b0d87f/imgs/overview.png --------------------------------------------------------------------------------