├── utils ├── __init__.py └── custom_device_mode.py ├── README.md ├── open_registration_example.py └── cpp_extensions └── open_registration_extension.cpp /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_open_registration_example 2 | Example of using pytorch's open device registration API. It covers: 3 | 4 | (1) Writing custom kernels in C++, and registering them to the PyTorch dispatcher 5 | 6 | (2) Providing a user API for your custom device, so users can invoke the custom code using `torch.foo(..., device="custom_device")` 7 | 8 | (3) Registering a custom memory allocator 9 | 10 | (4) Registering a custom device guard 11 | 12 | This repo should be run using the latest pytorch nightly. If it fails on the latest nightly, please post an issue and I'll take a look! 13 | -------------------------------------------------------------------------------- /utils/custom_device_mode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.cpp_extension 3 | from torch.overrides import TorchFunctionMode 4 | 5 | # Load the C++ extension containing your custom kernels. 6 | foo_module = torch.utils.cpp_extension.load( 7 | name="custom_device_extension", 8 | sources=[ 9 | "cpp_extensions/open_registration_extension.cpp", 10 | ], 11 | extra_include_paths=["cpp_extensions"], 12 | extra_cflags=["-g"], 13 | verbose=True, 14 | ) 15 | 16 | print('Loaded custom extension.') 17 | 18 | # The user will globally enable the below mode when calling this API 19 | def enable_foo_device(): 20 | m = FooDeviceMode() 21 | m.__enter__() 22 | # If you want the mode to never be disabled, then this function shouldn't return anything. 
23 | return m 24 | 25 | # This is a simple TorchFunctionMode class that: 26 | # (a) Intercepts all torch.* calls 27 | # (b) Checks for kwargs of the form `device="foo:i"` 28 | # (c) Turns those into custom device objects: `device=foo_module.custom_device(i)` 29 | # (d) Forwards the call along into pytorch. 30 | class FooDeviceMode(TorchFunctionMode): 31 | def __torch_function__(self, func, types, args=(), kwargs=None): 32 | if kwargs is None: 33 | kwargs = {} 34 | if 'device' in kwargs and 'foo' in kwargs['device']: 35 | device_and_idx = kwargs['device'].split(':') 36 | if len(device_and_idx) == 1: 37 | # Case 1: No index specified 38 | kwargs['device'] = foo_module.custom_device() 39 | else: 40 | # Case 2: The user specified a device index. 41 | device_idx = int(device_and_idx[1]) 42 | kwargs['device'] = foo_module.custom_device(device_idx) 43 | with torch._C.DisableTorchFunction(): 44 | return func(*args, **kwargs) 45 | -------------------------------------------------------------------------------- /open_registration_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.custom_device_mode import foo_module, enable_foo_device 3 | 4 | # This file contains an example of how to create a custom device extension 5 | # in PyTorch, through the dispatcher. 6 | # It also shows what two possible user API's for custom devices look like. Either: 7 | # (1) Expose your custom device as an object, device=my_device_obj 8 | # (2) Allow users to directly use device strings: device="my_device" 9 | 10 | # Running this file prints the following: 11 | 12 | # (Correctly) unable to create tensor on device='bar' 13 | # (Correctly) unable to create tensor on device='foo:2' 14 | # Creating x on device 'foo:0' 15 | # Custom aten::empty.memory_format() called! 16 | # Custom allocator's allocate() called! 17 | # Creating y on device 'foo:0' 18 | # Custom aten::empty.memory_format() called! 
19 | # Custom allocator's allocate() called! 20 | 21 | # Test START 22 | # x.device=foo:0, x.is_cpu=False 23 | # y.device=foo:0, y.is_cpu=False 24 | # Calling z = x + y 25 | # Custom aten::add.Tensor() called! 26 | # Custom aten::empty.memory_format() called! 27 | # Custom allocator's allocate() called! 28 | # z.device=foo:0, z.is_cpu=False 29 | # Calling z = z.to(device="cpu") 30 | # Custom aten::_copy_from() called! 31 | # z_cpu.device=cpu, z_cpu.is_cpu=True 32 | # Calling z2 = z_cpu + z_cpu 33 | # Test END 34 | 35 | # Custom allocator's delete() called! 36 | # Creating x on device 'foo:1' 37 | # Custom aten::empty.memory_format() called! 38 | # Custom allocator's allocate() called! 39 | # Creating y on device 'foo:1' 40 | # Custom aten::empty.memory_format() called! 41 | # Custom allocator's allocate() called! 42 | 43 | # Test START 44 | # x.device=foo:0, x.is_cpu=False 45 | # y.device=foo:0, y.is_cpu=False 46 | # Calling z = x + y 47 | # Custom aten::add.Tensor() called! 48 | # Custom aten::empty.memory_format() called! 49 | # Custom allocator's allocate() called! 50 | # z.device=foo:0, z.is_cpu=False 51 | # Calling z = z.to(device="cpu") 52 | # Custom aten::_copy_from() called! 53 | # z_cpu.device=cpu, z_cpu.is_cpu=True 54 | # Calling z2 = z_cpu + z_cpu 55 | # Test END 56 | 57 | # Custom allocator's delete() called! 58 | # Custom allocator's delete() called! 59 | # Custom allocator's delete() called! 60 | # Custom allocator's delete() called! 61 | # Custom allocator's delete() called! 62 | 63 | def test(x, y): 64 | print() 65 | print("Test START") 66 | # Check that our device is correct. 
67 | print(f'x.device={x.device}, x.is_cpu={x.is_cpu}') 68 | print(f'y.device={y.device}, y.is_cpu={y.is_cpu}') 69 | 70 | # calls our custom add kernel, registered to the dispatcher 71 | print('Calling z = x + y') 72 | z = x + y 73 | print(f'z.device={z.device}, z.is_cpu={z.is_cpu}') 74 | 75 | print('Calling z = z.to(device="cpu")') 76 | z_cpu = z.to(device='cpu') 77 | 78 | # Check that our cross-device copy correctly copied the data to cpu 79 | print(f'z_cpu.device={z_cpu.device}, z_cpu.is_cpu={z_cpu.is_cpu}') 80 | 81 | # Confirm that calling the add kernel no longer invokes our custom kernel, 82 | # since we're using CPU tensors. 83 | print('Calling z2 = z_cpu + z_cpu') 84 | z2 = z_cpu + z_cpu 85 | print("Test END") 86 | print() 87 | 88 | # Option 1: Use torch.register_privateuse1_backend("foo"), which will allow 89 | # "foo" as a device string to work seamlessly with pytorch's APIs. 90 | # You may need a more recent nightly of PyTorch for this. 91 | torch.register_privateuse1_backend('foo') 92 | 93 | # Show that in general, passing in an arbitrary (unregistered) device string will fail. 94 | try: 95 | x = torch.ones(4, 4, device='bar') 96 | exit("Error: you should not be able to make a tensor on an arbitrary 'bar' device.") 97 | except RuntimeError as e: 98 | print("(Correctly) unable to create tensor on device='bar'") 99 | 100 | # Show that an out-of-range index on the registered 'foo' device also fails. 101 | try: 102 | x = torch.ones(4, 4, device='foo:2') 103 | exit("Error: the foo device only has two valid indices: foo:0 and foo:1") 104 | except RuntimeError as e: 105 | print("(Correctly) unable to create tensor on device='foo:2'") 106 | 107 | print("Creating x on device 'foo:0'") 108 | x1 = torch.ones(4, 4, device='foo:0') 109 | print("Creating y on device 'foo:0'") 110 | y1 = torch.ones(4, 4, device='foo:0') 111 | 112 | test(x1, y1) 113 | 114 | 115 | # Option 2: Directly expose a custom device object 116 | # You can pass an optional index arg, specifying which device index to use. 
117 | foo_device1 = foo_module.custom_device(1) 118 | 119 | print("Creating x on device 'foo:1'") 120 | x2 = torch.ones(4, 4, device=foo_device1) 121 | print("Creating y on device 'foo:1'") 122 | y2 = torch.ones(4, 4, device=foo_device1) 123 | 124 | # Option 3: Enable a TorchFunctionMode object in user land, 125 | # that will convert `device="foo"` calls into our custom device objects automatically. 126 | # Option 1 is strictly better here (in particular, printing a.device() will still 127 | # print "privateuseone" instead of your custom device name). Mostly showing this option because: 128 | # (a) Torch Function Modes have been around for longer, and the API in Option 1 129 | # is only available on a more recent nightly. 130 | # (b) This is a cool example of how powerful torch_function and torch_dispatch modes can be! 131 | # holder = enable_foo_device() 132 | # del holder 133 | 134 | test(x2, y2) 135 | -------------------------------------------------------------------------------- /cpp_extensions/open_registration_extension.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | // NOTE(review): the #include directives above have no header names (apparently lost in extraction) - restore them before building. 15 | // This file contains the heavy lifting to add a new C++ backend 16 | // and integrate it directly into the PyTorch backend. 
It mainly involves: 17 | // 18 | // (1) Writing a custom allocator and registering it to pytorch 19 | // (see DummyCustomAllocator) 20 | // (2) Writing a custom device guard, registering it to pytorch, 21 | // and using the device guard in kernels 22 | // (see DummyDeviceGuard) 23 | // (3) Writing a custom aten::empty.memory_format function 24 | 25 | 26 | // basic dummy add function 27 | at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { 28 | const at::OptionalDeviceGuard device_guard(at::device_of(self)); 29 | std::cout << "Custom aten::add.Tensor() called!" << std::endl; 30 | // Since this custom device is just for testing, not bothering to implement kernels. 31 | return at::empty(self.sizes(), self.options()); 32 | } 33 | 34 | // ===================================== 35 | // ========= Custom Allocators ========= 36 | // ===================================== 37 | 38 | // PyTorch provides an API for registering custom allocators for your device. 39 | // You can create one by inheriting from the at::Allocator class, 40 | // and registering your allocator for the particular device type 41 | // (PrivateUse1 for open registration devices) 42 | 43 | // A dummy allocator for our custom device, that secretly uses the CPU 44 | struct DummyCustomAllocator final : at::Allocator { 45 | DummyCustomAllocator() = default; 46 | at::DataPtr allocate(size_t nbytes) const override { 47 | std::cout << "Custom allocator's allocate() called!" << std::endl; 48 | void* data = c10::alloc_cpu(nbytes); 49 | return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)}; 50 | } 51 | 52 | static void ReportAndDelete(void* ptr) { 53 | if (!ptr) { 54 | return; 55 | } 56 | std::cout << "Custom allocator's delete() called!" 
<< std::endl; 57 | c10::free_cpu(ptr); 58 | } 59 | 60 | at::DeleterFnPtr raw_deleter() const override { 61 | return &ReportAndDelete; 62 | } 63 | }; 64 | 65 | // Register our dummy allocator 66 | static DummyCustomAllocator global_custom_alloc; 67 | REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc); 68 | 69 | // ===================================== 70 | // ============= Device Guards ========= 71 | // ===================================== 72 | 73 | // PyTorch has an API for registering device guards. 74 | // Device guards can be used to set the current "active" device, 75 | // and e.g. error if the user provides an invalid device index. 76 | // 77 | // If your device doesn't support indices (e.g. foo:0 vs. foo:1), 78 | // then the guards probably aren't needed. 79 | // 80 | // You can use it by creating a DeviceGuard class, registering it 81 | // in PyTorch, and invoking the device guard before any kernels are called. 82 | // For a more full-featured example of a device guard, 83 | // check out the code at c10/cuda/CUDAGuard.h 84 | 85 | // Represents the current "active" device. 86 | // The dummy device guard registered below is meant to show how a backend 87 | // can integrate custom device guard with pytorch. 88 | // For something like cuda this represents the current active cuda device, 89 | // which is directly set using the cuda API calls cudaGetDevice/cudaSetDevice. 90 | static uint16_t CURR_DEVICE = -1; 91 | 92 | // Create and register a dummy device guard. 
93 | struct DummyDeviceGuardImpl final : public c10::impl::DeviceGuardImplInterface { 94 | static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; 95 | DummyDeviceGuardImpl() {} 96 | explicit DummyDeviceGuardImpl(c10::DeviceType t) { 97 | TORCH_INTERNAL_ASSERT(t == c10::DeviceType::PrivateUse1); 98 | } 99 | at::DeviceType type() const override { 100 | return at::DeviceType::PrivateUse1; 101 | } 102 | at::Device exchangeDevice(at::Device d) const override { 103 | TORCH_INTERNAL_ASSERT(d.type() == at::DeviceType::PrivateUse1); 104 | TORCH_INTERNAL_ASSERT(d.index() < deviceCount(), "Error: device index ", d.index(), " does not exist."); 105 | at::Device old_device = getDevice(); 106 | if (old_device.index() != d.index()) { 107 | // "set the active device" 108 | CURR_DEVICE = d.index(); 109 | } 110 | return old_device; 111 | } 112 | at::Device getDevice() const override { 113 | return at::Device(at::DeviceType::PrivateUse1, CURR_DEVICE); 114 | } 115 | void setDevice(at::Device d) const override { 116 | TORCH_INTERNAL_ASSERT(d.type() == at::DeviceType::PrivateUse1); 117 | TORCH_INTERNAL_ASSERT(d.index() < deviceCount(), "Error: device index ", d.index(), " does not exist."); 118 | at::Device current_device = getDevice(); 119 | if (current_device != d) { 120 | CURR_DEVICE = d.index(); 121 | } 122 | } 123 | void uncheckedSetDevice(at::Device d) const noexcept override { 124 | auto current_device = getDevice(); 125 | if (current_device != d) { 126 | CURR_DEVICE = d.index(); 127 | } 128 | } 129 | at::Stream getStream(at::Device d) const noexcept override { 130 | // no-op 131 | return at::Stream(at::Stream::DEFAULT, d); 132 | } 133 | // NB: These do NOT set the current device 134 | at::Stream exchangeStream(at::Stream) const noexcept override { 135 | // no-op 136 | return at::Stream(at::Stream::DEFAULT, at::Device(at::DeviceType::PrivateUse1, CURR_DEVICE)); 137 | } 138 | at::DeviceIndex deviceCount() const noexcept override { 139 | // Hardcoding the 
number of "valid" devices here at 2. 140 | return 2; 141 | } 142 | 143 | // Event-related functions 144 | void record( 145 | void** /*event*/, 146 | const at::Stream& /*stream*/, 147 | const at::DeviceIndex /*device_index*/, 148 | const c10::EventFlag /*flag*/) const override { 149 | TORCH_CHECK(false, at::DeviceType::PrivateUse1, " backend doesn't support events."); 150 | } 151 | void block(void* /*event*/, const at::Stream& /*stream*/) const override { 152 | TORCH_CHECK(false, at::DeviceType::PrivateUse1, " backend doesn't support events."); 153 | } 154 | bool queryEvent(void* /*event*/) const override { 155 | TORCH_CHECK(false, at::DeviceType::PrivateUse1, " backend doesn't support events."); 156 | } 157 | void destroyEvent(void* /*event*/, const at::DeviceIndex /*device_index*/) 158 | const noexcept override {} 159 | 160 | // Stream-related functions 161 | bool queryStream(const at::Stream& /*stream*/) const override { 162 | return true; 163 | } 164 | void synchronizeStream(const at::Stream& /*stream*/) const override { 165 | // Don't wait for anything. 
166 | } 167 | }; 168 | // Thin wrapper that forwards every call to an InlineDeviceGuard specialised on DummyDeviceGuardImpl; it is neither copyable nor movable. 169 | struct DummyGuard { 170 | explicit DummyGuard() = delete; 171 | explicit DummyGuard(at::DeviceIndex device_index) : guard_(device_index) {} 172 | explicit DummyGuard(at::Device device) : guard_(device) {} 173 | DummyGuard(const DummyGuard&) = delete; 174 | DummyGuard& operator=(const DummyGuard&) = delete; 175 | DummyGuard(DummyGuard&& other) = delete; 176 | DummyGuard& operator=(DummyGuard&& other) = delete; 177 | 178 | void set_device(at::Device device) { 179 | guard_.set_device(device); 180 | } 181 | 182 | void reset_device(at::Device device) { 183 | guard_.reset_device(device); 184 | } 185 | 186 | void set_index(at::DeviceIndex device_index) { 187 | guard_.set_index(device_index); 188 | } 189 | 190 | at::Device original_device() const { 191 | return guard_.original_device(); 192 | } 193 | 194 | at::Device current_device() const { 195 | return guard_.current_device(); 196 | } 197 | 198 | private: 199 | c10::impl::InlineDeviceGuard<DummyDeviceGuardImpl> guard_; 200 | }; 201 | 202 | C10_REGISTER_GUARD_IMPL(PrivateUse1, DummyDeviceGuardImpl); 203 | 204 | 205 | // ===================================== 206 | // ============= KERNELS =============== 207 | // ===================================== 208 | 209 | // basic dummy empty function, so we can directly construct tensors on the custom device 210 | // This dummy test device will just use the CPU allocator, and ignores pinned memory. 211 | // 212 | // Note: this kernel is very simple because our "custom device" just uses the normal TensorImpl object 213 | // to store data under the hood. 214 | // In PyTorch core today, both cpu and cuda are implemented with an ordinary TensorImpl class. 215 | // Sometimes, backends prefer to subclass TensorImpl in order to store extra information. 216 | // If this is the case, then this kernel is where you'll be responsible for creating and returning 217 | // a fresh at::Tensor object, that properly stores a TensorImpl of your subclass. 
218 | at::Tensor custom_empty_memory_format(at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) { 219 | const at::OptionalDeviceGuard device_guard(device); 220 | std::cout << "Custom aten::empty.memory_format() called!" << std::endl; 221 | constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); 222 | return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format); 223 | } 224 | 225 | at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) { 226 | const at::OptionalDeviceGuard device_guard(at::device_of(self)); 227 | // Not bothering to implement. 228 | // Should fill the tensor's data with "value". 229 | return self; 230 | } 231 | 232 | // basic dummy copy_() function, so we can copy from the custom device to/from CPU 233 | at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { 234 | const at::OptionalDeviceGuard device_guard(at::device_of(self)); 235 | std::cout << "Custom aten::_copy_from() called!" << std::endl; 236 | TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); 237 | TORCH_CHECK(dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); 238 | 239 | // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous. 240 | TORCH_CHECK(self.sizes() == dst.sizes()); 241 | TORCH_CHECK(self.scalar_type() == dst.scalar_type()); 242 | TORCH_CHECK(self.is_contiguous() && dst.is_contiguous()); 243 | // NOTE(review): this copies the whole source storage and ignores storage offsets - presumably fine for freshly allocated tensors; confirm before using it with views. 244 | std::memcpy(dst.storage().data_ptr().get(), self.storage().data_ptr().get(), self.storage().nbytes()); 245 | return dst; 246 | } 247 | 248 | 249 | // This macro does the heavy lifting. 
250 | // With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend. 251 | // For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key. 252 | // Later in this file, we map a custom device to the PrivateUse1 device type, 253 | // which allows user code that puts a tensor on your custom_device to eventually get plumbed 254 | // into the kernels registered here. 255 | // 256 | // This macro registers your kernels to the PyTorch Dispatcher. 257 | // More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/. 258 | TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { 259 | m.impl("add.Tensor", &custom_add_Tensor); 260 | m.impl("empty.memory_format", &custom_empty_memory_format); 261 | m.impl("fill_.Scalar", &custom_fill__scalar); 262 | m.impl("_copy_from", &custom__copy_from); 263 | } 264 | 265 | // Builds a device object for our backend with the given device index 266 | // (e.g. custom_device:0 vs. custom_device:1). 267 | // The pybind registration below gives `idx` a default of 0, so Python callers may omit it. 268 | // Note that supporting device indices also requires registering a device guard to core. 269 | // See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`. 270 | c10::Device get_custom_device(int idx) { 271 | return c10::Device(c10::DeviceType::PrivateUse1, idx); 272 | } 273 | 274 | // Here, we're exposing a custom device object that corresponds to our custom backend. 275 | // We do this using pybind: exposing an "extension_name.custom_device()" function in python, 276 | // that's implemented in C++. 277 | // The implementation in this file maps directly to the `PrivateUse1` device type. 278 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 279 | m.def("custom_device", &get_custom_device, "get custom device object", pybind11::arg("idx") = 0); 280 | } 281 | --------------------------------------------------------------------------------