├── tests └── .gitignore ├── .gitignore ├── CNAME ├── requirements.txt ├── .flake8 ├── pyproject.toml ├── .github ├── FUNDING.yml └── workflows │ └── ci.yml ├── README.md.in ├── LICENSE ├── generate_index_html.py ├── run_tests.py ├── generate_readme_md.py ├── index.html.in ├── conversions.yaml ├── README.md └── index.html /tests/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.pytest_cache/ 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /CNAME: -------------------------------------------------------------------------------- 1 | pytorch-for-numpy-users.wkentaro.com 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | PyYAML 3 | tabulate 4 | torch 5 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .anaconda3/, tests/ 3 | ignore = 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | exclude = ''' 4 | ( 5 | ^/\..* 6 | | ^/docs/ 7 | | ^/github2pypi/ 8 | ) 9 | ''' 10 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [wkentaro] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | 
patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /README.md.in: -------------------------------------------------------------------------------- 1 | 2 | # PyTorch for Numpy users. 3 | 4 | ![ci](https://github.com/wkentaro/pytorch-for-numpy-users/workflows/ci/badge.svg) 5 | ![gh-pages](https://img.shields.io/github/deployments/wkentaro/pytorch-for-numpy-users/github-pages?label=gh-pages) 6 | 7 | 8 | [PyTorch](https://github.com/pytorch/pytorch.git) version of [_Torch for Numpy users_](https://github.com/torch/torch7/wiki/Torch-for-Numpy-users). 9 | We assume you use the latest PyTorch and Numpy. 10 | 11 | 12 | ## How to contribute? 
13 | 14 | ```bash 15 | git clone https://github.com/wkentaro/pytorch-for-numpy-users.git 16 | cd pytorch-for-numpy-users 17 | vim conversions.yaml 18 | ./run_tests.py 19 | 20 | git commit -am "Update conversions.yaml" 21 | ``` 22 | 23 | 24 | $CONTENTS 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kentaro Wada 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [3.8] 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install requirements 24 | run: | 25 | pip install -r requirements.txt 26 | - name: Lint 27 | run: | 28 | pip install black hacking 29 | black --check . 30 | flake8 . 31 | - name: Test 32 | run: | 33 | pip install pytest 34 | ./run_tests.py 35 | - name: Update README.md 36 | run: | 37 | if [ "$GITHUB_EVENT_NAME" = "push" -a "${GITHUB_REF:11}" = "main" ]; then 38 | ./generate_readme_md.py > README.md 39 | ./generate_index_html.py > index.html 40 | 41 | git config --global user.email "www.kentaro.wada@gmail.com" 42 | git config --global user.name "Kentaro Wada" 43 | git add README.md index.html 44 | git diff-index --cached --quiet HEAD || git commit -m "Update README.md and index.html" 45 | git push origin main 46 | fi 47 | -------------------------------------------------------------------------------- /generate_index_html.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import collections 4 | import os.path as osp 5 | import string 6 | 7 | import yaml 8 | 9 | from generate_readme_md import get_section 10 | 11 | 12 | def get_contents(): 13 | # keep order in yaml file 14 | yaml.add_constructor( 15 | yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, 16 | lambda loader, node: collections.OrderedDict( 17 | loader.construct_pairs(node) 18 | ), 19 | ) 20 | 21 | yaml_file = osp.join(here, "conversions.yaml") 22 | with 
open(yaml_file) as f: 23 | data = yaml.safe_load(f) 24 | contents = [] 25 | contents.append("
") 26 | for title, data in data.items(): 27 | section = get_section(title, data) 28 | section = section.replace( 29 | "", 30 | "
" 31 | "
", 32 | ) 33 | section = section.replace("
", "
") 34 | section = section.replace("", "") 35 | contents.append("
") 36 | contents.append(section) 37 | contents.append("
") 38 | contents.append("") 39 | return "\n".join(contents) 40 | 41 | 42 | here = osp.dirname(osp.abspath(__file__)) 43 | 44 | 45 | def main(): 46 | with open(osp.join(here, "index.html.in")) as f: 47 | template = f.read() 48 | template = string.Template(template) 49 | readme = template.substitute(CONTENTS=get_contents()) 50 | print(readme) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /run_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import glob 6 | import os 7 | import os.path as osp 8 | import subprocess 9 | 10 | import yaml 11 | 12 | 13 | def parse(data): 14 | for _, datum in data.items(): 15 | if isinstance(datum, dict): 16 | for key_content in parse(datum): 17 | yield key_content 18 | else: 19 | for item in datum: 20 | for key in item: 21 | item_k = item[key] 22 | if isinstance(item_k, dict): 23 | content = item_k["content"] 24 | is_code = item_k.get("is_code", True) 25 | if item_k.get("skip_test", False): 26 | continue 27 | else: 28 | content = item_k 29 | is_code = True 30 | if is_code and content is not None: 31 | yield key, content 32 | 33 | 34 | def main(): 35 | here = osp.dirname(osp.abspath(__file__)) 36 | 37 | with open(osp.join(here, "conversions.yaml")) as f: 38 | data = yaml.safe_load(f) 39 | 40 | for fname in glob.glob(osp.join(here, "tests/*.py")): 41 | os.remove(fname) 42 | 43 | for i, (key, content) in enumerate(parse(data)): 44 | if key == "numpy": 45 | code = """\ 46 | import numpy as np 47 | 48 | 49 | def test_{key}_{id:04d}(): 50 | x = np.array([[1, 2, 3], [4, 5, 6]]) 51 | {content} 52 | """ 53 | elif key == "pytorch": 54 | code = """\ 55 | import torch 56 | 57 | 58 | def test_{key}_{id:04d}(): 59 | x = torch.tensor([[1, 2, 3], [4, 5, 6]]) 60 | {content} 61 | """ 62 | else: 63 | raise ValueError 64 | 65 | content = 
"\n".join(" " * 4 + line for line in content.splitlines()) 66 | code = code.format(key=key, id=i, content=content) 67 | 68 | test_file = osp.join( 69 | here, "tests/test_{key}_{id:04d}.py".format(key=key, id=i) 70 | ) 71 | with open(test_file, "w") as f: 72 | f.write(code) 73 | 74 | cmd = "pytest -vs tests" 75 | print("+ %s" % cmd) 76 | subprocess.check_call(cmd, shell=True) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /generate_readme_md.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import collections 4 | import os.path as osp 5 | import string 6 | 7 | import tabulate 8 | import yaml 9 | 10 | 11 | def get_section(title, data, h=2): 12 | if not isinstance(data, list): 13 | content = "{1:s}\n\n".format(h, title.capitalize()) 14 | for sub_title, sub_data in data.items(): 15 | content += get_section(sub_title, sub_data, h=h + 1) 16 | return content 17 | 18 | headers = ["Numpy", "PyTorch"] 19 | keys = ["numpy", "pytorch"] 20 | rows = [] 21 | for d in data: 22 | row = [] 23 | for key in keys: 24 | if isinstance(d[key], dict): 25 | content = d[key]["content"] 26 | is_code = d[key].get("is_code", True) 27 | elif d[key] is None: 28 | content = "" 29 | is_code = False 30 | else: 31 | content = d[key] 32 | is_code = True 33 | if is_code and content: 34 | content = "
\n{:s}
".format(content) 35 | row.append(content) 36 | rows.append(row) 37 | 38 | contents = [] 39 | contents.append("{1:s}".format(h, title.capitalize())) 40 | contents.append( 41 | tabulate.tabulate(rows, headers=headers, tablefmt="unsafehtml") 42 | ) 43 | return "\n".join(contents) 44 | 45 | 46 | def get_contents(): 47 | # keep order in yaml file 48 | yaml.add_constructor( 49 | yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, 50 | lambda loader, node: collections.OrderedDict( 51 | loader.construct_pairs(node) 52 | ), 53 | ) 54 | 55 | yaml_file = osp.join(here, "conversions.yaml") 56 | with open(yaml_file) as f: 57 | data = yaml.safe_load(f) 58 | contents = [] 59 | for title, data in data.items(): 60 | section = get_section(title, data) 61 | contents.append(section) 62 | return "\n".join(contents) 63 | 64 | 65 | here = osp.dirname(osp.abspath(__file__)) 66 | 67 | 68 | def main(): 69 | with open(osp.join(here, "README.md.in")) as f: 70 | template = f.read() 71 | template = string.Template(template) 72 | readme = template.substitute(CONTENTS=get_contents()) 73 | print(readme) 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /index.html.in: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 15 | 49 | 50 |
51 |

PyTorch for Numpy Users.

52 | Tweet 53 | 54 |

55 | A cheatsheet for Numpy users to use PyTorch. To add/edit an entry, see it on 56 | GitHub. 57 |

58 | 59 | $CONTENTS 60 |
61 | 62 | 63 | -------------------------------------------------------------------------------- /conversions.yaml: -------------------------------------------------------------------------------- 1 | types: 2 | - numpy: np.ndarray 3 | pytorch: torch.Tensor 4 | - numpy: np.float32 5 | pytorch: torch.float32; torch.float 6 | - numpy: np.float64 7 | pytorch: torch.float64; torch.double 8 | - numpy: np.float16 9 | pytorch: torch.float16; torch.half 10 | - numpy: np.int8 11 | pytorch: torch.int8 12 | - numpy: np.uint8 13 | pytorch: torch.uint8 14 | - numpy: np.int16 15 | pytorch: torch.int16; torch.short 16 | - numpy: np.int32 17 | pytorch: torch.int32; torch.int 18 | - numpy: np.int64 19 | pytorch: torch.int64; torch.long 20 | ones and zeros: 21 | - numpy: np.empty((2, 3)) 22 | pytorch: torch.empty(2, 3) 23 | - numpy: np.empty_like(x) 24 | pytorch: torch.empty_like(x) 25 | - numpy: np.eye 26 | pytorch: torch.eye 27 | - numpy: np.identity 28 | pytorch: torch.eye 29 | - numpy: np.ones 30 | pytorch: torch.ones 31 | - numpy: np.ones_like 32 | pytorch: torch.ones_like 33 | - numpy: np.zeros 34 | pytorch: torch.zeros 35 | - numpy: np.zeros_like 36 | pytorch: torch.zeros_like 37 | from existing data: 38 | - numpy: np.array([[1, 2], [3, 4]]) 39 | pytorch: torch.tensor([[1, 2], [3, 4]]) 40 | - numpy: | 41 | np.array([3.2, 4.3], dtype=np.float16) 42 | np.float16([3.2, 4.3]) 43 | pytorch: torch.tensor([3.2, 4.3], dtype=torch.float16) 44 | - numpy: x.copy() 45 | pytorch: x.clone() 46 | - numpy: x.astype(np.float32) 47 | pytorch: x.type(torch.float32); x.float() 48 | - numpy: 49 | content: np.fromfile(file) 50 | skip_test: true 51 | pytorch: 52 | content: torch.tensor(torch.Storage(file)) 53 | skip_test: true 54 | - numpy: np.frombuffer 55 | pytorch: 56 | - numpy: np.fromfunction 57 | pytorch: 58 | - numpy: np.fromiter 59 | pytorch: 60 | - numpy: np.fromstring 61 | pytorch: 62 | - numpy: np.load 63 | pytorch: torch.load 64 | - numpy: np.loadtxt 65 | pytorch: 66 | - numpy: 
np.concatenate 67 | pytorch: torch.cat 68 | numerical ranges: 69 | - numpy: np.arange(10) 70 | pytorch: torch.arange(10) 71 | - numpy: np.arange(2, 3, 0.1) 72 | pytorch: torch.arange(2, 3, 0.1) 73 | - numpy: np.linspace 74 | pytorch: torch.linspace 75 | - numpy: np.logspace 76 | pytorch: torch.logspace 77 | linear algebra: 78 | - numpy: np.dot 79 | pytorch: | 80 | torch.dot # 1D arrays only 81 | torch.mm # 2D arrays only 82 | torch.mv # matrix-vector (2D x 1D) 83 | - numpy: np.matmul 84 | pytorch: torch.matmul 85 | - numpy: np.tensordot 86 | pytorch: torch.tensordot 87 | - numpy: np.einsum 88 | pytorch: torch.einsum 89 | 90 | building matrices: 91 | - numpy: np.diag 92 | pytorch: torch.diag 93 | - numpy: np.tril 94 | pytorch: torch.tril 95 | - numpy: np.triu 96 | pytorch: torch.triu 97 | attributes: 98 | - numpy: x.shape 99 | pytorch: x.shape; x.size() 100 | - numpy: x.strides 101 | pytorch: x.stride() 102 | - numpy: x.ndim 103 | pytorch: x.dim() 104 | - numpy: x.data 105 | pytorch: x.data 106 | - numpy: x.size 107 | pytorch: x.nelement() 108 | - numpy: x.dtype 109 | pytorch: x.dtype 110 | indexing: 111 | - numpy: x[0] 112 | pytorch: x[0] 113 | - numpy: x[:, 0] 114 | pytorch: x[:, 0] 115 | - numpy: 116 | content: x[indices] 117 | skip_test: true 118 | pytorch: 119 | content: x[indices] 120 | skip_test: true 121 | - numpy: 122 | content: np.take(x, indices) 123 | skip_test: true 124 | pytorch: 125 | content: torch.take(x, torch.LongTensor(indices)) 126 | skip_test: true 127 | - numpy: x[x != 0] 128 | pytorch: x[x != 0] 129 | shape manipulation: 130 | - numpy: x.reshape 131 | pytorch: x.reshape; x.view 132 | - numpy: x.resize() 133 | pytorch: x.resize_ 134 | - numpy: 135 | pytorch: x.resize_as_ 136 | - numpy: | 137 | x = np.arange(6).reshape(3, 2, 1) 138 | x.transpose(2, 0, 1) # 012 -> 201 139 | pytorch: | 140 | x = torch.arange(6).reshape(3, 2, 1) 141 | x.permute(2, 0, 1); x.transpose(1, 2).transpose(0, 1) # 012 -> 021 -> 201 142 | - numpy: x.flatten 143 | pytorch: 
x.flatten; x.view(-1) 144 | - numpy: x.squeeze() 145 | pytorch: x.squeeze() 146 | - numpy: x[:, None]; np.expand_dims(x, 1) 147 | pytorch: x[:, None]; x.unsqueeze(1) 148 | item selection and manipulation: 149 | - numpy: np.put 150 | pytorch: 151 | - numpy: x.put 152 | pytorch: x.put_ 153 | - numpy: | 154 | x = np.array([1, 2, 3]) 155 | x.repeat(2) # [1, 1, 2, 2, 3, 3] 156 | pytorch: | 157 | x = torch.tensor([1, 2, 3]) 158 | x.repeat_interleave(2) # [1, 1, 2, 2, 3, 3] 159 | x.repeat(2) # [1, 2, 3, 1, 2, 3] 160 | x.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1) 161 | # [1, 1, 2, 2, 3, 3] 162 | - numpy: np.tile(x, (3, 2)) 163 | pytorch: x.repeat(3, 2) 164 | - numpy: | 165 | x = np.array([[0, 1], [2, 3], [4, 5]]) 166 | idxs = np.array([0, 2]) 167 | np.choose(idxs, x) # [0, 5] 168 | pytorch: | 169 | x = torch.tensor([[0, 1], [2, 3], [4, 5]]) 170 | idxs = torch.tensor([0, 2]) 171 | x[idxs, torch.arange(x.shape[1])] # [0, 5] 172 | torch.gather(x, 0, idxs[None, :])[0] # [0, 5] 173 | - numpy: np.sort 174 | pytorch: 175 | content: sorted, indices = torch.sort(x, [dim]) 176 | skip_test: true 177 | - numpy: np.argsort 178 | pytorch: 179 | content: torch.argsort(x) 180 | skip_test: false 181 | - numpy: np.nonzero 182 | pytorch: torch.nonzero 183 | - numpy: np.where 184 | pytorch: torch.where 185 | - numpy: x[::-1] 186 | pytorch: torch.flip(x, [0]) 187 | - numpy: np.unique(x) 188 | pytorch: torch.unique(x) 189 | calculation: 190 | - numpy: x.min 191 | pytorch: x.min 192 | - numpy: x.argmin 193 | pytorch: x.argmin 194 | - numpy: x.max 195 | pytorch: x.max 196 | - numpy: x.argmax 197 | pytorch: x.argmax 198 | - numpy: x.clip 199 | pytorch: x.clamp 200 | - numpy: x.round 201 | pytorch: x.round 202 | - numpy: np.floor(x) 203 | pytorch: 204 | content: torch.floor(x); x.floor() 205 | skip_test: true 206 | - numpy: np.ceil(x) 207 | pytorch: 208 | content: torch.ceil(x); x.ceil() 209 | skip_test: true 210 | - numpy: x.trace 211 | pytorch: x.trace 212 | - numpy: x.sum 213
| pytorch: x.sum 214 | - numpy: x.sum(axis=0) 215 | pytorch: x.sum(0) 216 | - numpy: x.cumsum 217 | pytorch: x.cumsum 218 | - numpy: x.mean 219 | pytorch: x.mean 220 | - numpy: x.std 221 | pytorch: x.std 222 | - numpy: x.prod 223 | pytorch: x.prod 224 | - numpy: x.cumprod 225 | pytorch: x.cumprod 226 | - numpy: x.all 227 | pytorch: x.all 228 | - numpy: x.any 229 | pytorch: x.any 230 | arithmetic and comparison operations: 231 | - numpy: np.less 232 | pytorch: x.lt 233 | - numpy: np.less_equal 234 | pytorch: x.le 235 | - numpy: np.greater 236 | pytorch: x.gt 237 | - numpy: np.greater_equal 238 | pytorch: x.ge 239 | - numpy: np.equal 240 | pytorch: x.eq 241 | - numpy: np.not_equal 242 | pytorch: x.ne 243 | random numbers: 244 | - numpy: np.random.seed 245 | pytorch: torch.manual_seed 246 | - numpy: np.random.permutation(5) 247 | pytorch: torch.randperm(5) 248 | numerical operations: 249 | - numpy: np.sign 250 | pytorch: torch.sign 251 | - numpy: np.sqrt 252 | pytorch: torch.sqrt 253 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # PyTorch for Numpy users. 3 | 4 | ![ci](https://github.com/wkentaro/pytorch-for-numpy-users/workflows/ci/badge.svg) 5 | ![gh-pages](https://img.shields.io/github/deployments/wkentaro/pytorch-for-numpy-users/github-pages?label=gh-pages) 6 | 7 | 8 | [PyTorch](https://github.com/pytorch/pytorch.git) version of [_Torch for Numpy users_](https://github.com/torch/torch7/wiki/Torch-for-Numpy-users). 9 | We assume you use the latest PyTorch and Numpy. 10 | 11 | 12 | ## How to contribute? 13 | 14 | ```bash 15 | git clone https://github.com/wkentaro/pytorch-for-numpy-users.git 16 | cd pytorch-for-numpy-users 17 | vim conversions.yaml 18 | ./run_tests.py 19 | 20 | git commit -am "Update conversions.yaml" 21 | ``` 22 | 23 | 24 |

Types

25 | 26 | 27 | 28 | 29 | 30 | 33 | 36 | 39 | 42 | 45 | 48 | 51 | 54 | 57 | 58 |
Numpy PyTorch
 31 | np.ndarray
 32 | torch.Tensor
 34 | np.float32
 35 | torch.float32; torch.float
 37 | np.float64
 38 | torch.float64; torch.double
 40 | np.float16
 41 | torch.float16; torch.half
 43 | np.int8
 44 | torch.int8
 46 | np.uint8
 47 | torch.uint8
 49 | np.int16
 50 | torch.int16; torch.short
 52 | np.int32
 53 | torch.int32; torch.int
 55 | np.int64
 56 | torch.int64; torch.long
59 |
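The dtype rows above can be checked directly. This is a small sketch that assumes recent NumPy and PyTorch running on a CPU. `torch.from_numpy` carries each NumPy dtype over unchanged. The two libraries do differ on the default floating-point dtype: NumPy uses `float64`, and PyTorch uses `float32` unless the default has been changed.

```python
import numpy as np
import torch

# Each NumPy dtype in the table and its PyTorch counterpart
pairs = [
    (np.float32, torch.float32),
    (np.float64, torch.float64),
    (np.float16, torch.float16),
    (np.int8, torch.int8),
    (np.uint8, torch.uint8),
    (np.int16, torch.int16),
    (np.int32, torch.int32),
    (np.int64, torch.int64),
]
for np_dtype, torch_dtype in pairs:
    # torch.from_numpy keeps the dtype of the source array
    t = torch.from_numpy(np.zeros(3, dtype=np_dtype))
    assert t.dtype == torch_dtype

# Default floating-point dtypes differ between the two libraries
np_default = np.zeros(3).dtype
torch_default = torch.zeros(3).dtype
print(np_default, torch_default)  # float64 torch.float32
```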

Ones and zeros

60 | 61 | 62 | 63 | 64 | 65 | 68 | 71 | 74 | 77 | 80 | 83 | 86 | 89 | 90 |
Numpy PyTorch
 66 | np.empty((2, 3))
 67 | torch.empty(2, 3)
 69 | np.empty_like(x)
 70 | torch.empty_like(x)
 72 | np.eye
 73 | torch.eye
 75 | np.identity
 76 | torch.eye
 78 | np.ones
 79 | torch.ones
 81 | np.ones_like
 82 | torch.ones_like
 84 | np.zeros
 85 | torch.zeros
 87 | np.zeros_like
 88 | torch.zeros_like
91 |
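The main difference between the constructors above is how the shape is passed. The sketch below assumes recent NumPy and PyTorch, and its variable names are illustrative. NumPy expects a single shape tuple. PyTorch accepts either separate integers or a tuple. The `*_like` functions keep the input's dtype.

```python
import numpy as np
import torch

# NumPy takes the shape as one tuple; PyTorch accepts ints or a tuple
a = np.zeros((2, 3))
b = torch.zeros(2, 3)
c = torch.zeros((2, 3))
assert a.shape == tuple(b.shape) == tuple(c.shape) == (2, 3)

# np.identity(n) and np.eye(n) both correspond to torch.eye(n)
assert np.array_equal(np.identity(3), torch.eye(3).numpy())

# *_like functions keep the shape and dtype of their input
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
ones = torch.ones_like(x)
assert ones.dtype == torch.int64 and ones.shape == x.shape
```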

From existing data

92 | 93 | 94 | 95 | 96 | 97 | 100 | 105 | 108 | 111 | 114 | 116 | 118 | 120 | 122 | 125 | 127 | 130 | 131 |
Numpy PyTorch
 98 | np.array([[1, 2], [3, 4]])
 99 | torch.tensor([[1, 2], [3, 4]])
101 | np.array([3.2, 4.3], dtype=np.float16)
102 | np.float16([3.2, 4.3])
103 | 
104 | torch.tensor([3.2, 4.3], dtype=torch.float16)
106 | x.copy()
107 | x.clone()
109 | x.astype(np.float32)
110 | x.type(torch.float32); x.float()
112 | np.fromfile(file)
113 | torch.tensor(torch.Storage(file))
115 | np.frombuffer
117 | np.fromfunction
119 | np.fromiter
121 | np.fromstring
123 | np.load
124 | torch.load
126 | np.loadtxt
128 | np.concatenate
129 | torch.cat
132 |
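When converting existing data, it matters whether the result copies or shares memory. This sketch assumes CPU arrays and recent library versions. `torch.tensor` always copies. `torch.from_numpy` and `Tensor.numpy()` share the underlying buffer, so an in-place change on one side is visible on the other.

```python
import numpy as np
import torch

arr = np.array([1.0, 2.0, 3.0])

copied = torch.tensor(arr)      # always copies the data
shared = torch.from_numpy(arr)  # shares memory with arr (CPU only)

arr[0] = 100.0
print(copied[0].item(), shared[0].item())  # 1.0 100.0

# x.astype(np.float32) corresponds to x.float() or x.to(torch.float32)
y = shared.float()
assert y.dtype == torch.float32

# .numpy() on a CPU tensor also shares memory instead of copying
back = shared.numpy()
assert back[0] == 100.0
```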

Numerical ranges

133 | 134 | 135 | 136 | 137 | 138 | 141 | 144 | 147 | 150 | 151 |
Numpy PyTorch
139 | np.arange(10)
140 | torch.arange(10)
142 | np.arange(2, 3, 0.1)
143 | torch.arange(2, 3, 0.1)
145 | np.linspace
146 | torch.linspace
148 | np.logspace
149 | torch.logspace
152 |
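The range functions behave alike in both libraries, as this sketch (assuming recent NumPy and PyTorch) shows. `arange` excludes the end value. `linspace` and `logspace` take the number of points as the third positional argument and include the endpoint.

```python
import numpy as np
import torch

# linspace: (start, end, steps), endpoint included
a = np.linspace(0, 1, 5)
b = torch.linspace(0, 1, 5)
assert np.allclose(a, b.numpy())

# arange: the end value is excluded in both libraries
r_np = np.arange(2, 3, 0.5)
r_t = torch.arange(2, 3, 0.5)
assert r_np.tolist() == r_t.tolist() == [2.0, 2.5]

# logspace: base 10 by default in both
assert np.allclose(np.logspace(0, 2, 3), torch.logspace(0, 2, 3).numpy())
```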

Linear algebra

153 | 154 | 155 | 156 | 157 | 158 | 164 | 167 | 170 | 173 | 174 |
Numpy PyTorch
159 | np.dot
160 | torch.dot   # 1D arrays only
161 | torch.mm    # 2D arrays only
162 | torch.mv    # matrix-vector (2D x 1D)
163 | 
165 | np.matmul
166 | torch.matmul
168 | np.tensordot
169 | torch.tensordot
171 | np.einsum
172 | torch.einsum
175 |
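`np.dot` chooses its behaviour based on the input dimensions, whereas `torch.dot`, `torch.mm` and `torch.mv` each handle only one case. The sketch below assumes float64 inputs for exact comparison. It pairs each case with its PyTorch function and shows that `torch.matmul` and `torch.einsum` carry over directly.

```python
import numpy as np
import torch

a = np.array([[1.0, 2.0], [3.0, 4.0]])
v = np.array([1.0, 1.0])
ta, tv = torch.from_numpy(a), torch.from_numpy(v)

# 1D . 1D -> scalar: torch.dot
assert np.dot(v, v) == torch.dot(tv, tv).item()
# 2D . 2D -> matrix product: torch.mm
assert np.allclose(np.dot(a, a), torch.mm(ta, ta).numpy())
# 2D . 1D -> matrix-vector product: torch.mv
assert np.allclose(np.dot(a, v), torch.mv(ta, tv).numpy())
# torch.matmul dispatches on dimensions, like np.matmul
assert np.allclose(np.matmul(a, v), torch.matmul(ta, tv).numpy())
# einsum subscripts carry over unchanged
assert np.allclose(
    np.einsum("ij,j->i", a, v), torch.einsum("ij,j->i", ta, tv).numpy()
)
```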

Building matrices

176 | 177 | 178 | 179 | 180 | 181 | 184 | 187 | 190 | 191 |
Numpy PyTorch
182 | np.diag
183 | torch.diag
185 | np.tril
186 | torch.tril
188 | np.triu
189 | torch.triu
192 |
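The matrix builders above match in behaviour, but one keyword is spelled differently. The sketch below assumes recent NumPy and PyTorch. It shows that a 1D input to `diag` builds a matrix and a 2D input extracts the diagonal. It also shows that the diagonal offset is `k` in NumPy and `diagonal` in PyTorch.

```python
import numpy as np
import torch

v_np = np.array([1, 2, 3])
v_t = torch.tensor([1, 2, 3])

# 1D input -> diagonal matrix; 2D input -> its diagonal
assert np.array_equal(np.diag(v_np), torch.diag(v_t).numpy())
assert torch.diag(torch.diag(v_t)).tolist() == [1, 2, 3]

# The offset keyword is k in NumPy and diagonal in PyTorch
m_np = np.arange(9).reshape(3, 3)
m_t = torch.arange(9).reshape(3, 3)
assert np.array_equal(np.tril(m_np, k=-1), torch.tril(m_t, diagonal=-1).numpy())
assert np.array_equal(np.triu(m_np, k=1), torch.triu(m_t, diagonal=1).numpy())
```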

Attributes

193 | 194 | 195 | 196 | 197 | 198 | 201 | 204 | 207 | 210 | 213 | 216 | 217 |
Numpy PyTorch
199 | x.shape
200 | x.shape; x.size()
202 | x.strides
203 | x.stride()
205 | x.ndim
206 | x.dim()
208 | x.data
209 | x.data
211 | x.size
212 | x.nelement()
214 | x.dtype
215 | x.dtype
218 |
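Some attribute pairs above look alike but use different units or return types. This sketch assumes a C-contiguous int64 array. NumPy's `strides` are measured in bytes, while PyTorch's `stride()` counts elements. `size` is an element count in NumPy, but in PyTorch `size()` returns the shape.

```python
import numpy as np
import torch

x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
x_t = torch.tensor([[1, 2, 3], [4, 5, 6]])  # int64 by default

# NumPy strides are in bytes; PyTorch strides are in elements
print(x_np.strides, x_t.stride())  # (24, 8) (3, 1)
assert tuple(s // x_np.itemsize for s in x_np.strides) == x_t.stride()

# Element count: x.size in NumPy, x.nelement() / x.numel() in PyTorch
assert x_np.size == x_t.nelement() == x_t.numel() == 6
# In PyTorch, x.size() returns the shape
assert x_t.size() == x_t.shape == (2, 3)
assert x_np.ndim == x_t.dim() == x_t.ndim == 2
```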

Indexing

219 | 220 | 221 | 222 | 223 | 224 | 227 | 230 | 233 | 236 | 239 | 240 |
Numpy PyTorch
225 | x[0]
226 | x[0]
228 | x[:, 0]
229 | x[:, 0]
231 | x[indices]
232 | x[indices]
234 | np.take(x, indices)
235 | torch.take(x, torch.LongTensor(indices))
237 | x[x != 0]
238 | x[x != 0]
241 |
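For the `take` row, the sketch below shows both functions indexing into the flattened input. The `indices` list is a hypothetical example. It uses `torch.tensor(indices)`, which also yields an int64 index tensor, in place of the legacy `torch.LongTensor(indices)` constructor. Boolean masks behave the same way in both libraries.

```python
import numpy as np
import torch

x_np = np.array([[1, 2, 3], [4, 5, 6]])
x_t = torch.tensor([[1, 2, 3], [4, 5, 6]])
indices = [0, 4, 5]  # hypothetical flat indices

# np.take without axis and torch.take both index the flattened input
taken_np = np.take(x_np, indices)
taken_t = torch.take(x_t, torch.tensor(indices))
assert taken_np.tolist() == taken_t.tolist() == [1, 5, 6]

# Boolean masks return a 1D result in both libraries
assert x_np[x_np > 3].tolist() == x_t[x_t > 3].tolist() == [4, 5, 6]
```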

Shape manipulation

242 | 243 | 244 | 245 | 246 | 247 | 250 | 253 | 255 | 262 | 265 | 268 | 271 | 272 |
Numpy PyTorch
248 | x.reshape
249 | x.reshape; x.view
251 | x.resize()
252 | x.resize_
254 | x.resize_as_
256 | x = np.arange(6).reshape(3, 2, 1)
257 | x.transpose(2, 0, 1)  # 012 -> 201
258 | 
259 | x = torch.arange(6).reshape(3, 2, 1)
260 | x.permute(2, 0, 1); x.transpose(1, 2).transpose(0, 1)  # 012 -> 021 -> 201
261 | 
263 | x.flatten
264 | x.view(-1)
266 | x.squeeze()
267 | x.squeeze()
269 | x[:, None]; np.expand_dims(x, 1)
270 | x[:, None]; x.unsqueeze(1)
273 |
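The shape rows hide one common pitfall, demonstrated in the sketch below (assuming recent NumPy and PyTorch). NumPy's `x.transpose(axes)` maps to `x.permute(dims)`. However, transposed tensors are usually non-contiguous views. On those, `view(-1)` raises an error, while `reshape(-1)` and `flatten()` fall back to a copy.

```python
import numpy as np
import torch

x_np = np.arange(6).reshape(3, 2, 1)
x_t = torch.arange(6).reshape(3, 2, 1)

# NumPy's x.transpose(axes) corresponds to PyTorch's x.permute(dims)
p = x_t.permute(2, 0, 1)
assert p.shape == (1, 3, 2)
assert np.array_equal(x_np.transpose(2, 0, 1), p.numpy())

# A 2D transpose is a non-contiguous view
m = torch.arange(6).reshape(2, 3).t()
assert not m.is_contiguous()
try:
    m.view(-1)  # view cannot flatten this memory layout
    view_failed = False
except RuntimeError:
    view_failed = True
assert view_failed
# reshape and flatten copy when a view is impossible
assert m.reshape(-1).tolist() == m.flatten().tolist() == [0, 3, 1, 4, 2, 5]
```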

Item selection and manipulation

274 | 275 | 276 | 277 | 278 | 279 | 281 | 284 | 294 | 297 | 307 | 310 | 313 | 316 | 319 | 322 | 325 | 326 |
Numpy PyTorch
280 | np.put
282 | x.put
283 | x.put_
285 | x = np.array([1, 2, 3])
286 | x.repeat(2)  # [1, 1, 2, 2, 3, 3]
287 | 
288 | x = torch.tensor([1, 2, 3])
289 | x.repeat_interleave(2)  # [1, 1, 2, 2, 3, 3]
290 | x.repeat(2)  # [1, 2, 3, 1, 2, 3]
291 | x.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1)
292 | # [1, 1, 2, 2, 3, 3]
293 | 
295 | np.tile(x, (3, 2))
296 | x.repeat(3, 2)
298 | x = np.array([[0, 1], [2, 3], [4, 5]])
299 | idxs = np.array([0, 2])
300 | np.choose(idxs, x) # [0, 5]
301 | 
302 | x = torch.tensor([[0, 1], [2, 3], [4, 5]])
303 | idxs = torch.tensor([0, 2])
304 | x[idxs, torch.arange(x.shape[1])] # [0, 5]
305 | torch.gather(x, 0, idxs[None, :])[0] # [0, 5]
306 | 
308 | np.sort
309 | sorted, indices = torch.sort(x, [dim])
311 | np.argsort
312 | sorted, indices = torch.sort(x, [dim])
314 | np.nonzero
315 | torch.nonzero
317 | np.where
318 | torch.where
320 | x[::-1]
321 | torch.flip(x, [0])
323 | np.unique(x)
324 | torch.unique(x)
327 |
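Several rows in this table have names that mean different things in each library, which the sketch below demonstrates (assuming recent NumPy and PyTorch). `np.repeat` corresponds to `repeat_interleave`, while PyTorch's `repeat` behaves like `np.tile`. `argsort` sorts along the last axis by default in both. PyTorch does not support negative slice steps, so `x[::-1]` becomes `torch.flip`.

```python
import numpy as np
import torch

a = np.array([1, 2, 3])
t = torch.tensor([1, 2, 3])

# np.repeat repeats each element; the PyTorch match is repeat_interleave
assert np.repeat(a, 2).tolist() == t.repeat_interleave(2).tolist()
# Tensor.repeat tiles the whole tensor, like np.tile
assert np.tile(a, 2).tolist() == t.repeat(2).tolist() == [1, 2, 3, 1, 2, 3]
assert np.tile(a, (3, 2)).shape == tuple(t.repeat(3, 2).shape) == (3, 6)

# argsort along the last axis by default in both libraries
b = np.array([3, 1, 2])
assert np.argsort(b).tolist() == torch.argsort(torch.from_numpy(b)).tolist() == [1, 2, 0]

# x[::-1] in NumPy; torch.flip in PyTorch
assert a[::-1].tolist() == torch.flip(t, [0]).tolist() == [3, 2, 1]
```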

Calculation

328 | 329 | 330 | 331 | 332 | 333 | 336 | 339 | 342 | 345 | 348 | 351 | 354 | 357 | 360 | 363 | 366 | 369 | 372 | 375 | 378 | 381 | 384 | 387 | 388 |
Numpy PyTorch
334 | x.min
335 | x.min
337 | x.argmin
338 | x.argmin
340 | x.max
341 | x.max
343 | x.argmax
344 | x.argmax
346 | x.clip
347 | x.clamp
349 | x.round
350 | x.round
352 | np.floor(x)
353 | torch.floor(x); x.floor()
355 | np.ceil(x)
356 | torch.ceil(x); x.ceil()
358 | x.trace
359 | x.trace
361 | x.sum
362 | x.sum
364 | x.sum(axis=0)
365 | x.sum(0)
367 | x.cumsum
368 | x.cumsum
370 | x.mean
371 | x.mean
373 | x.std
374 | x.std
376 | x.prod
377 | x.prod
379 | x.cumprod
380 | x.cumprod
382 | x.all
383 | x.all
385 | x.any
386 | x.any
389 |
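Two calculation rows have different defaults in each library, shown in the sketch below with float64 data. First, `np.std` divides by N (`ddof=0`), while `torch.std` divides by N - 1 by default. The sketch passes `unbiased=False` to match NumPy, and newer PyTorch releases also accept `correction=0`. Second, given a dimension, PyTorch's `max`/`min` return both values and indices.

```python
import numpy as np
import torch

x_np = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
x_t = torch.from_numpy(x_np)

# np.std divides by N; torch.std divides by N - 1 unless told otherwise
assert np.isclose(x_np.std(), x_t.std(unbiased=False).item())
assert np.isclose(x_np.std(ddof=1), x_t.std().item())

# With a dim argument, torch max/min return (values, indices)
values, indices = x_t.max(dim=0)
assert values.tolist() == x_np.max(axis=0).tolist() == [4.0, 5.0, 6.0]
assert indices.tolist() == x_np.argmax(axis=0).tolist() == [1, 1, 1]

# Reductions along an axis, and clip vs clamp
assert x_np.sum(axis=0).tolist() == x_t.sum(dim=0).tolist()
assert np.array_equal(x_np.clip(2, 5), x_t.clamp(2, 5).numpy())
```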

Arithmetic and comparison operations

390 | 391 | 392 | 393 | 394 | 395 | 398 | 401 | 404 | 407 | 410 | 413 | 414 |
Numpy PyTorch
396 | np.less
397 | x.lt
399 | np.less_equal
400 | x.le
402 | np.greater
403 | x.gt
405 | np.greater_equal
406 | x.ge
408 | np.equal
409 | x.eq
411 | np.not_equal
412 | x.ne
415 |
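Each comparison in this table is available in method, function and operator form, all returning a bool tensor. The sketch below assumes recent NumPy and PyTorch. Note one trap: `torch.equal` is not the analogue of `np.equal`. It compares whole tensors and returns a single Python bool.

```python
import numpy as np
import torch

a_np, b_np = np.array([1, 2, 3]), np.array([3, 2, 1])
a_t, b_t = torch.tensor([1, 2, 3]), torch.tensor([3, 2, 1])

# Method, function and operator forms give the same bool tensor
assert torch.equal(a_t.lt(b_t), torch.lt(a_t, b_t))
assert torch.equal(a_t.lt(b_t), a_t < b_t)
assert np.less(a_np, b_np).tolist() == a_t.lt(b_t).tolist() == [True, False, False]

# Elementwise equality is x.eq; torch.equal returns one Python bool
assert a_t.eq(b_t).tolist() == [False, True, False]
assert torch.equal(a_t, a_t) is True
```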

Random numbers

416 | 417 | 418 | 419 | 420 | 421 | 424 | 427 | 428 |
Numpy PyTorch
422 | np.random.seed
423 | torch.manual_seed
425 | np.random.permutation(5)
426 | torch.randperm(5)
429 |
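The two libraries keep separate random states. `torch.manual_seed` does not affect NumPy, and `np.random.seed` does not affect PyTorch, so code using both must seed both. The sketch below assumes recent versions. It also shows the generator-object style: `np.random.default_rng` in NumPy and `torch.Generator` in PyTorch.

```python
import numpy as np
import torch

# Re-seeding PyTorch reproduces the same permutation
torch.manual_seed(0)
p1 = torch.randperm(5)
torch.manual_seed(0)
p2 = torch.randperm(5)
assert torch.equal(p1, p2)
assert sorted(p1.tolist()) == list(range(5))

# NumPy's global state is seeded separately
np.random.seed(0)
q1 = np.random.permutation(5)
np.random.seed(0)
q2 = np.random.permutation(5)
assert np.array_equal(q1, q2)

# Explicit generator objects avoid global state in both libraries
g = torch.Generator().manual_seed(0)
r = torch.randperm(5, generator=g)
rng = np.random.default_rng(0)
s = rng.permutation(5)
```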

Numerical operations

430 | 431 | 432 | 433 | 434 | 435 | 438 | 441 | 442 |
Numpy PyTorch
436 | np.sign
437 | torch.sign
439 | np.sqrt
440 | torch.sqrt
443 | 444 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 15 | 49 | 50 |
51 |

PyTorch for Numpy Users.

52 | Tweet 53 | 54 |

55 | A cheatsheet for Numpy users to use PyTorch. To add/edit an entry, see it on 56 | GitHub. 57 |

58 | 59 |
60 |
61 |

Types

62 |
63 | 64 | 65 | 66 | 67 | 70 | 73 | 76 | 79 | 82 | 85 | 88 | 91 | 94 | 95 |
Numpy PyTorch
 68 | np.ndarray
 69 | torch.Tensor
 71 | np.float32
 72 | torch.float32; torch.float
 74 | np.float64
 75 | torch.float64; torch.double
 77 | np.float16
 78 | torch.float16; torch.half
 80 | np.int8
 81 | torch.int8
 83 | np.uint8
 84 | torch.uint8
 86 | np.int16
 87 | torch.int16; torch.short
 89 | np.int32
 90 | torch.int32; torch.int
 92 | np.int64
 93 | torch.int64; torch.long
96 |
97 |
98 |

Ones and zeros

99 |
100 | 101 | 102 | 103 | 104 | 107 | 110 | 113 | 116 | 119 | 122 | 125 | 128 | 129 |
Numpy PyTorch
105 | np.empty((2, 3))
106 | torch.empty(2, 3)
108 | np.empty_like(x)
109 | torch.empty_like(x)
111 | np.eye
112 | torch.eye
114 | np.identity
115 | torch.eye
117 | np.ones
118 | torch.ones
120 | np.ones_like
121 | torch.ones_like
123 | np.zeros
124 | torch.zeros
126 | np.zeros_like
127 | torch.zeros_like
130 |
131 |
132 |

From existing data

133 |
134 | 135 | 136 | 137 | 138 | 141 | 146 | 149 | 152 | 155 | 157 | 159 | 161 | 163 | 166 | 168 | 171 | 172 |
Numpy PyTorch
139 | np.array([[1, 2], [3, 4]])
140 | torch.tensor([[1, 2], [3, 4]])
142 | np.array([3.2, 4.3], dtype=np.float16)
143 | np.float16([3.2, 4.3])
144 | 
145 | torch.tensor([3.2, 4.3], dtype=torch.float16)
147 | x.copy()
148 | x.clone()
150 | x.astype(np.float32)
151 | x.type(torch.float32); x.float()
153 | np.fromfile(file)
154 | torch.tensor(torch.Storage(file))
156 | np.frombuffer
158 | np.fromfunction
160 | np.fromiter
162 | np.fromstring
164 | np.load
165 | torch.load
167 | np.loadtxt
169 | np.concatenate
170 | torch.cat
173 |
174 |
175 |

Numerical ranges

176 |
177 | 178 | 179 | 180 | 181 | 184 | 187 | 190 | 193 | 194 |
Numpy PyTorch
182 | np.arange(10)
183 | torch.arange(10)
185 | np.arange(2, 3, 0.1)
186 | torch.arange(2, 3, 0.1)
188 | np.linspace
189 | torch.linspace
191 | np.logspace
192 | torch.logspace
195 |
196 |
197 |

Linear algebra

198 |
199 | 200 | 201 | 202 | 203 | 209 | 212 | 215 | 218 | 219 |
Numpy PyTorch
204 | np.dot
205 | torch.dot   # 1D arrays only
206 | torch.mm    # 2D arrays only
207 | torch.mv    # matrix-vector (2D x 1D)
208 | 
210 | np.matmul
211 | torch.matmul
213 | np.tensordot
214 | torch.tensordot
216 | np.einsum
217 | torch.einsum
220 |
221 |
222 |

Building matrices

223 |
224 | 225 | 226 | 227 | 228 | 231 | 234 | 237 | 238 |
Numpy PyTorch
229 | np.diag
230 | torch.diag
232 | np.tril
233 | torch.tril
235 | np.triu
236 | torch.triu
239 |
240 |
241 |

Attributes

242 |
243 | 244 | 245 | 246 | 247 | 250 | 253 | 256 | 259 | 262 | 265 | 266 |
Numpy PyTorch
248 | x.shape
249 | x.shape; x.size()
251 | x.strides
252 | x.stride()
254 | x.ndim
255 | x.dim()
257 | x.data
258 | x.data
260 | x.size
261 | x.nelement()
263 | x.dtype
264 | x.dtype
267 |
268 |
269 |

Indexing

270 |
271 | 272 | 273 | 274 | 275 | 278 | 281 | 284 | 287 | 290 | 291 |
Numpy PyTorch
276 | x[0]
277 | x[0]
279 | x[:, 0]
280 | x[:, 0]
282 | x[indices]
283 | x[indices]
285 | np.take(x, indices)
286 | torch.take(x, torch.LongTensor(indices))
288 | x[x != 0]
289 | x[x != 0]
292 |
293 |
294 |

Shape manipulation

295 |
296 | 297 | 298 | 299 | 300 | 303 | 306 | 308 | 315 | 318 | 321 | 324 | 325 |
Numpy PyTorch
301 | x.reshape
302 | x.reshape; x.view
304 | x.resize()
305 | x.resize_
307 | x.resize_as_
309 | x = np.arange(6).reshape(3, 2, 1)
310 | x.transpose(2, 0, 1)  # 012 -> 201
311 | 
312 | x = torch.arange(6).reshape(3, 2, 1)
313 | x.permute(2, 0, 1); x.transpose(1, 2).transpose(0, 1)  # 012 -> 021 -> 201
314 | 
316 | x.flatten
317 | x.view(-1)
319 | x.squeeze()
320 | x.squeeze()
322 | x[:, None]; np.expand_dims(x, 1)
323 | x[:, None]; x.unsqueeze(1)
326 |
327 |
328 |

Item selection and manipulation

329 |
330 | 331 | 332 | 333 | 334 | 336 | 339 | 349 | 352 | 362 | 365 | 368 | 371 | 374 | 377 | 380 | 381 |
Numpy PyTorch
335 | np.put
337 | x.put
338 | x.put_
340 | x = np.array([1, 2, 3])
341 | x.repeat(2)  # [1, 1, 2, 2, 3, 3]
342 | 
343 | x = torch.tensor([1, 2, 3])
344 | x.repeat_interleave(2)  # [1, 1, 2, 2, 3, 3]
345 | x.repeat(2)  # [1, 2, 3, 1, 2, 3]
346 | x.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1)
347 | # [1, 1, 2, 2, 3, 3]
348 | 
350 | np.tile(x, (3, 2))
351 | x.repeat(3, 2)
353 | x = np.array([[0, 1], [2, 3], [4, 5]])
354 | idxs = np.array([0, 2])
355 | np.choose(idxs, x) # [0, 5]
356 | 
357 | x = torch.tensor([[0, 1], [2, 3], [4, 5]])
358 | idxs = torch.tensor([0, 2])
359 | x[idxs, torch.arange(x.shape[1])] # [0, 5]
360 | torch.gather(x, 0, idxs[None, :])[0] # [0, 5]
361 | 
363 | np.sort
364 | sorted, indices = torch.sort(x, [dim])
366 | np.argsort
367 | sorted, indices = torch.sort(x, [dim])
369 | np.nonzero
370 | torch.nonzero
372 | np.where
373 | torch.where
375 | x[::-1]
376 | torch.flip(x, [0])
378 | np.unique(x)
379 | torch.unique(x)
382 |
383 |
384 |

Calculation

385 |
386 | 387 | 388 | 389 | 390 | 393 | 396 | 399 | 402 | 405 | 408 | 411 | 414 | 417 | 420 | 423 | 426 | 429 | 432 | 435 | 438 | 441 | 444 | 445 |
Numpy PyTorch
391 | x.min
392 | x.min
394 | x.argmin
395 | x.argmin
397 | x.max
398 | x.max
400 | x.argmax
401 | x.argmax
403 | x.clip
404 | x.clamp
406 | x.round
407 | x.round
409 | np.floor(x)
410 | torch.floor(x); x.floor()
412 | np.ceil(x)
413 | torch.ceil(x); x.ceil()
415 | x.trace
416 | x.trace
418 | x.sum
419 | x.sum
421 | x.sum(axis=0)
422 | x.sum(0)
424 | x.cumsum
425 | x.cumsum
427 | x.mean
428 | x.mean
430 | x.std
431 | x.std
433 | x.prod
434 | x.prod
436 | x.cumprod
437 | x.cumprod
439 | x.all
440 | x.all
442 | x.any
443 | x.any
446 |
447 |
448 |

Arithmetic and comparison operations

449 |
450 | 451 | 452 | 453 | 454 | 457 | 460 | 463 | 466 | 469 | 472 | 473 |
Numpy PyTorch
455 | np.less
456 | x.lt
458 | np.less_equal
459 | x.le
461 | np.greater
462 | x.gt
464 | np.greater_equal
465 | x.ge
467 | np.equal
468 | x.eq
470 | np.not_equal
471 | x.ne
474 |
475 |
476 |

Random numbers

477 |
478 | 479 | 480 | 481 | 482 | 485 | 488 | 489 |
Numpy PyTorch
483 | np.random.seed
484 | torch.manual_seed
486 | np.random.permutation(5)
487 | torch.randperm(5)
490 |
491 |
492 |

Numerical operations

493 |
494 | 495 | 496 | 497 | 498 | 501 | 504 | 505 |
Numpy PyTorch
499 | np.sign
500 | torch.sign
502 | np.sqrt
503 | torch.sqrt
506 |
507 |
508 |
509 | 510 | 511 | 512 | --------------------------------------------------------------------------------