├── CHANGELOG.md ├── README.md ├── examples ├── agent.py ├── llama3.py ├── llama3_70b.py ├── mistral_7b.py ├── qwen1_5.py └── qwen1_5_int8.py └── minit ├── __init__.py ├── chat ├── agent.py ├── sample.py ├── server.py ├── template.py └── tokenizer.py ├── collective ├── dispatch.py ├── group.py ├── operator.py ├── spec.py └── tensor.py ├── compiler ├── cache.py ├── cxx.py ├── gcc.py ├── python.py └── template.py ├── core ├── __init__.py ├── array.py ├── cache.py ├── device_operator.py ├── dispatch.py ├── dtype.py ├── meta.py ├── object.py ├── operator.py ├── scalar.py ├── shape.py ├── tensor.py └── torch.py ├── cuda ├── __init__.py ├── allocator.py ├── compiler.py ├── dispatch.py ├── kernel │ ├── cublas.py │ ├── cuda │ │ ├── broadcast.py │ │ ├── cast.py │ │ ├── elemwise.py │ │ ├── fill.py │ │ ├── generate.py │ │ ├── index.py │ │ ├── reduce.py │ │ ├── rms_norm.py │ │ ├── rope.py │ │ ├── slice.py │ │ ├── softmax.py │ │ ├── transpose.py │ │ └── triangle.py │ └── utils.py ├── lib │ └── cuda_runtime.py ├── tensor.py └── toolkit.py ├── distributed ├── communicator.py ├── group.py └── operator.py ├── functional ├── __init__.py ├── arith.py ├── control_flow.py ├── einops.py ├── generate.py ├── index.py ├── linalg.py ├── reduce.py ├── shape.py ├── special.py └── utils.py ├── graph ├── __init__.py └── optimize.py ├── lazy ├── __init__.py ├── dispatch.py └── tensor.py ├── module ├── __init__.py ├── checkpoint.py ├── list.py └── module.py ├── nccl ├── dispatch.py ├── kernel.py ├── library.py ├── server.py └── tensor.py ├── operator ├── __init__.py ├── arith.py ├── control_flow.py ├── generate.py ├── index.py ├── linalg.py ├── memory.py ├── random.py ├── reduce.py ├── shape.py └── special.py ├── quantize ├── dispatch.py ├── functional.py ├── operator.py └── tensor.py ├── remote ├── __init__.py ├── actor.py ├── channel.py ├── controller.py ├── function.py ├── object.py ├── registry.py ├── utils.py └── value.py ├── trace ├── dispatch.py ├── executor.py ├── function.py └── tensor.py └── triton └── __init__.py /CHANGELOG.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GetUpEarlier/minit/48d227e638c0316cf998295f7638f909fbc1b9f6/CHANGELOG.md -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Project MiniT 2 | 3 | 一个轻量的推理框架 4 | 5 | ## Examples 6 | 目前验证过Mistral-7B,Qwen-1.5及llama3-70B等类llama模型,examples里有完整代码 7 | 8 | ## Features 9 | - Tensor定义及相关接口 10 | - 基础CUDA算子 11 | - GPTQ推理支持 12 | - 基于nccl的分布式支持 13 | - function trace及graph pattern match功能 14 | - DTensor接口 15 | 16 | ## Dev Plan 17 | 持续开发中 18 | 19 | 计划支持大部分常见特性,包括分布式,量化,LoDTensor,trace等 20 | 21 | 随缘更新 22 | -------------------------------------------------------------------------------- /examples/agent.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Generator, Generic, List, Optional, TypeVar 3 | 4 | from minit.core.tensor import Tensor 5 | 6 | 7 | _Session = TypeVar("_Session") 8 | 9 | 10 | class Server(Generic[_Session]): 11 | @abstractmethod 12 | def decode(self, session: _Session, input_ids: List[int]) -> Tensor: 13 | ... 14 | 15 | @abstractmethod 16 | def create_session(self) -> _Session: 17 | ... 18 | 19 | 20 | class Tokenizer: 21 | @abstractmethod 22 | def tokenize(self, text: str) -> List[int]: 23 | ... 24 | 25 | @abstractmethod 26 | def detokenize(self, ids: List[int]) -> str: 27 | ... 28 | 29 | 30 | class Template: 31 | @abstractmethod 32 | def generate_prompt(self, prompt: str) -> str: 33 | ... 34 | 35 | @abstractmethod 36 | def eos(self) -> str: 37 | ... 38 | 39 | @abstractmethod 40 | def first_chat(self, chat: str) -> str: 41 | ... 42 | 43 | @abstractmethod 44 | def next_chat(self, chat: str) -> str: 45 | ... 46 | 47 | 48 | class Sampler: 49 | @abstractmethod 50 | def sample(self, probs: Tensor) -> int: 51 | ... 52 | 53 | 54 | class Agent: 55 | server: Server 56 | tokenizer: Tokenizer 57 | template: Template 58 | sampler: Sampler 59 | 60 | def __init__(self, server: Server, tokenizer: Tokenizer, template: Template, sampler: Sampler) -> None: 61 | self.server = server 62 | self.tokenizer = tokenizer 63 | self.template = template 64 | self.sampler = sampler 65 | 66 | def chat(self, prompt: str) -> Generator[Optional[str], str, None]: 67 | session = self.server.create_session() 68 | text = self.template.generate_prompt(prompt) 69 | question = yield None 70 | text += self.template.first_chat(question) 71 | end_chat = False 72 | while True: 73 | output_id = None 74 | response = "" 75 | while True: 76 | if output_id is not None: 77 | input_ids = [output_id] 78 | else: 79 | input_ids = self.tokenizer.tokenize(text) 80 | output_probs = self.server.decode(session, input_ids) 81 | output_id = self.sampler.sample(output_probs) 82 | text = self.tokenizer.detokenize([output_id]) 83 | response += text 84 | if text == self.template.eos(): 85 | end_chat = True 86 | if end_chat: 87 | end_chat = False 88 | break 89 | pong = yield text 90 | assert pong is None 91 | question = yield None 92 | assert question is not None 93 | text = self.template.next_chat(question) 94 | 95 | 96 | def chat_cmdline(agent: Agent, prompt: str): 97 | chat = agent.chat(prompt) 98 | assert chat.send(None) is None 99 | while True: 100 | print("User:", end="\t", flush=True) 101 | request = input() 102 | print("Assistent:", end="\t", flush=True) 103 | response = chat.send(request) 104 | while response is not None: 105 | print(response, end="", flush=True) 106 | response = chat.send(None) 107 | print() 108 | -------------------------------------------------------------------------------- /minit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GetUpEarlier/minit/48d227e638c0316cf998295f7638f909fbc1b9f6/minit/__init__.py -------------------------------------------------------------------------------- /minit/chat/agent.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional 2 | 3 | from ..core.tensor import Tensor 4 | from .server import Server 5 | from .tokenizer import Tokenizer 6 | from .template import Template 7 | from .sample import Sampler 8 | 9 | 10 | class Agent: 11 | server: Server 12 | tokenizer: Tokenizer 13 | template: Template 14 | sampler: Sampler 15 | 16 | class Session: 17 | decoder: Callable[[List[int]], Tensor] 18 | tokenizer: Tokenizer 19 | template: Template 20 | sampler: Sampler 21 | 22 | def __init__(self, server: Server, tokenizer: Tokenizer, template: Template, sampler: Sampler, prompt: Optional[str]) -> None: 23 | session = server.create_session() 24 | def decoder(input_ids: List[int]): 25 | return server.decode(session, input_ids) 26 | self.decoder = decoder 27 | header = template.generate_header() 28 | if prompt is not None: 29 | header += template.generate_prompt(prompt) 30 | if len(header) > 0: 31 | self.decoder(tokenizer.tokenize(header)) 32 | self.template = template 33 | self.sampler = sampler 34 | self.tokenizer = tokenizer 35 | 36 | def interact(self, user: str) -> str: 37 | prefix, postfix = self.template.generate(user) 38 | assistant = "" 39 | eos = self.template.eos() 40 | input_ids = self.tokenizer.tokenize(prefix) 41 | while True: 42 | output_probs = self.decoder(input_ids) 43 | output_id = self.sampler.sample(output_probs) 44 | output = self.tokenizer.detokenize([output_id]) 45 | if output == eos: 46 | break 47 | assistant += output 48 | input_ids = [output_id] 49 | self.decoder(self.tokenizer.tokenize(postfix)) 50 | return assistant 51 | 52 | def __init__(self, server: Server, tokenizer: Tokenizer, template: Template, sampler: Sampler) -> None: 53 | self.server = server 54 | self.tokenizer = tokenizer 55 | self.template = template 56 | self.sampler = sampler 57 | 58 | def create_session(self, prompt: str) -> Session: 59 | return Agent.Session(self.server, self.tokenizer, self.template, self.sampler, prompt) 60 | -------------------------------------------------------------------------------- /minit/chat/sample.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from ..core.tensor import Tensor 3 | 4 | 5 | class Sampler: 6 | @abstractmethod 7 | def sample(self, probs: Tensor) -> int: 8 | ... 9 | 10 | 11 | class Top1Sampler(Sampler): 12 | def sample(self, probs: Tensor) -> int: 13 | output_id = probs.numpy().argmax(axis=-1).item() 14 | return output_id 15 | -------------------------------------------------------------------------------- /minit/chat/server.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass, field 3 | from typing import Generic, List, Tuple, TypeVar 4 | 5 | from ..core.tensor import Tensor 6 | from ..functional.arith import constant 7 | 8 | 9 | _Session = TypeVar("_Session") 10 | 11 | 12 | class Server(Generic[_Session]): 13 | @abstractmethod 14 | def decode(self, session: _Session, input_ids: List[int]) -> Tensor: 15 | ... 16 | 17 | @abstractmethod 18 | def create_session(self) -> _Session: 19 | ... 20 | 21 | 22 | @dataclass 23 | class CacheInputSession: 24 | input_ids: List[int] = field(default_factory=list) 25 | 26 | 27 | @dataclass 28 | class CacheKeyValueSession: 29 | kv_cache_list: List[Tuple[Tensor, Tensor]] 30 | offset: int 31 | 32 | def __init__(self, kv_cache_list: List[Tuple[Tensor, Tensor]]): 33 | self.kv_cache_list = kv_cache_list 34 | self.offset = 0 35 | -------------------------------------------------------------------------------- /minit/chat/template.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Tuple 3 | 4 | 5 | class Template: 6 | @abstractmethod 7 | def generate_header(self) -> str: 8 | raise NotImplementedError() 9 | 10 | @abstractmethod 11 | def generate_prompt(self, prompt: str) -> str: 12 | raise NotImplementedError() 13 | 14 | @abstractmethod 15 | def generate(self, user: str) -> Tuple[str, str]: 16 | raise NotImplementedError() 17 | 18 | @abstractmethod 19 | def eos(self) -> str: 20 | raise NotImplementedError() 21 | -------------------------------------------------------------------------------- /minit/chat/tokenizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List 3 | 4 | 5 | class Tokenizer: 6 | @abstractmethod 7 | def tokenize(self, text: str) -> List[int]: 8 | ... 9 | 10 | @abstractmethod 11 | def detokenize(self, ids: List[int]) -> str: 12 | ... 13 | 14 | 15 | class SentencePieceTokenizer(Tokenizer): 16 | def __init__(self, path: str) -> None: 17 | super().__init__() 18 | import sentencepiece 19 | self.sentence_piece = sentencepiece.SentencePieceProcessor(path) 20 | 21 | def tokenize(self, text: str) -> List[int]: 22 | [input_ids,] = self.sentence_piece.tokenize([ 23 | text 24 | ]) 25 | return input_ids 26 | 27 | def detokenize(self, ids: List[int]) -> str: 28 | text = "" 29 | for id in ids: 30 | text += self.sentence_piece.id_to_piece(id) 31 | return text 32 | 33 | 34 | class HuggingFaceTokenizer(Tokenizer): 35 | def __init__(self, path: str) -> None: 36 | super().__init__() 37 | import transformers 38 | self.hugging_face = transformers.PreTrainedTokenizerFast(tokenizer_file=path) 39 | 40 | def tokenize(self, text: str) -> List[int]: 41 | input_ids = self.hugging_face.encode(text, add_special_tokens=False) 42 | return input_ids 43 | 44 | def detokenize(self, ids: List[int]) -> str: 45 | text = "" 46 | for id in ids: 47 | text += self.hugging_face.decode(id) 48 | return text 49 | -------------------------------------------------------------------------------- /minit/collective/dispatch.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Union 2 | 3 | from ..core.object import match_pattern 4 | from ..core.tensor import Tensor 5 | from ..core.dispatch import register_dispatch, dispatch 6 | from ..functional.linalg import matrix_multiply 7 | from .spec import CollectiveSpec, CollectiveSpecBroadcast, CollectiveSpecSplit 8 | from ..operator.linalg import MatrixMultiply 9 | from .tensor import CollectiveTensor 10 | from ..core.operator import Operator 11 | 12 | 13 | @register_dispatch(predicate=lambda *args: any(match_pattern(CollectiveTensor[CollectiveSpec], arg) is not None for arg in args), priority=1) 14 | def dispatch_collective(op: Operator, *args: Union[CollectiveTensor[CollectiveSpec], Tensor]): 15 | communicator = None 16 | for arg in args: 17 | if isinstance(arg, CollectiveTensor): 18 | communicator = arg._communicator 19 | local_args = [arg.to_broadcast()._local if isinstance(arg, CollectiveTensor) else arg for arg in args] 20 | for local_arg in local_args: 21 | assert not isinstance(local_arg, CollectiveTensor) 22 | local_outputs = dispatch(op, *local_args) 23 | return tuple([CollectiveTensor.from_broadcast(communicator, local_output) for local_output in local_outputs]) 24 | 25 | 26 | @register_dispatch(priority=2) 27 | def dispatch_collective(op: MatrixMultiply, x: CollectiveTensor[CollectiveSpec], y: CollectiveTensor[CollectiveSpecSplit[Literal[0]]]): 28 | communicator = x._communicator 29 | x = x.to_broadcast() 30 | return tuple([CollectiveTensor.from_split(communicator, matrix_multiply(x._local, y._local), 1)]) 31 | 32 | 33 | @register_dispatch(priority=2) 34 | def dispatch_collective(op: MatrixMultiply, x: CollectiveTensor[CollectiveSpec], y: CollectiveTensor[CollectiveSpecSplit[Literal[1]]]): 35 | communicator = x._communicator 36 | x = x.to_split(1) 37 | return tuple([CollectiveTensor.from_partial(communicator, matrix_multiply(x._local, y._local))]) 38 | 39 | 40 | @register_dispatch(priority=2) 41 | def dispatch_collective(op: MatrixMultiply, x: CollectiveTensor[CollectiveSpec], y: CollectiveTensor[CollectiveSpecBroadcast]): 42 | communicator = x._communicator 43 | x = x.to_broadcast() 44 | return tuple([CollectiveTensor.from_broadcast(communicator, matrix_multiply(x._local, y._local))]) 45 | -------------------------------------------------------------------------------- /minit/collective/group.py: -------------------------------------------------------------------------------- 1 | from ..distributed.group import DistributedGroup as CollectiveGroup, get_world 2 | -------------------------------------------------------------------------------- /minit/collective/operator.py: -------------------------------------------------------------------------------- 1 | class CollectiveSendRecv: 2 | ... 3 | 4 | class CollectiveBroadcast: 5 | ... 6 | 7 | class CollectiveSplit: 8 | ... 9 | 10 | class CollectiveAllGather: 11 | ... 12 | 13 | class CollectiveUnique: 14 | ... 15 | 16 | class CollectiveReduceScatter: 17 | ... 18 | 19 | class CollectiveScatter: 20 | ... 21 | 22 | class CollectiveAllToAll: 23 | ... 24 | -------------------------------------------------------------------------------- /minit/collective/spec.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import Generic, Literal, TypeVar, Union 4 | 5 | from ..core.object import Object 6 | 7 | 8 | _Axis = TypeVar("_Axis") 9 | _Rank = TypeVar("_Rank") 10 | 11 | 12 | @dataclass(frozen=True) 13 | class CollectiveSpecSplit(Object, Generic[_Axis]): 14 | axis: int 15 | 16 | def type(self): 17 | return CollectiveSpecSplit[Literal[self.axis]] # type: ignore 18 | 19 | @dataclass(frozen=True) 20 | class CollectiveSpecPartial(Object): 21 | def type(self): 22 | return CollectiveSpecPartial 23 | 24 | @dataclass(frozen=True) 25 | class CollectiveSpecBroadcast(Object): 26 | def type(self): 27 | return CollectiveSpecBroadcast 28 | 29 | @dataclass(frozen=True) 30 | class CollectiveSpecUnique(Object, Generic[_Rank]): 31 | rank: int 32 | 33 | def type(self): 34 | return CollectiveSpecUnique[Literal[self.rank]] # type: ignore 35 | 36 | CollectiveSpec = Union[ 37 | CollectiveSpecSplit[int], 38 | CollectiveSpecPartial, 39 | CollectiveSpecBroadcast, 40 | CollectiveSpecUnique[int], 41 | ] 42 | -------------------------------------------------------------------------------- /minit/collective/tensor.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, Optional, Tuple, TypeVar 2 | 3 | import numpy 4 | 5 | from ..core.meta import MetaTensor 6 | 7 | from ..distributed.communicator import DistributedCommunicator 8 | from .group import get_world 9 | from ..core.tensor import Tensor 10 | from .spec import CollectiveSpecBroadcast, CollectiveSpec, CollectiveSpecSplit, CollectiveSpecUnique, CollectiveSpecPartial 11 | 12 | 13 | _Spec = TypeVar("_Spec", bound=CollectiveSpec) 14 | 15 | 16 | class CollectiveTensor(Tensor, Generic[_Spec]): 17 | __slots__ = [ 18 | "_communicator", 19 | "_shape", 20 | "_local", 21 | "_spec", 22 | ] 23 | 24 | _communicator: DistributedCommunicator 25 | _shape: Tuple[Tensor, ...] 26 | _local: Optional[Tensor] 27 | _spec: _Spec 28 | 29 | def __init__(self, communicator: DistributedCommunicator, local: Optional[Tensor], spec: _Spec, shape: Tuple[Tensor, ...]) -> None: 30 | super().__init__() 31 | self._communicator = communicator 32 | self._local = local 33 | self._spec = spec 34 | self._shape = shape 35 | assert not isinstance(self._local, CollectiveTensor) 36 | 37 | @property 38 | def shape(self): 39 | return tuple([CollectiveTensor.from_broadcast(self._communicator, dim) for dim in self._shape]) 40 | 41 | @property 42 | def dtype(self): 43 | return self._local.dtype 44 | 45 | @property 46 | def spec(self): 47 | return self._spec 48 | 49 | @staticmethod 50 | def from_broadcast(communicator: DistributedCommunicator, local: Tensor) -> "CollectiveTensor": 51 | return CollectiveTensor(communicator, local, CollectiveSpecBroadcast(), local.shape) 52 | 53 | @staticmethod 54 | def from_split(communicator: DistributedCommunicator, local: Tensor, axis: int) -> "CollectiveTensor": 55 | shape = local.shape 56 | shape = shape[:axis] + ((shape[axis] * get_world().size),) + shape[axis+1:] 57 | return CollectiveTensor(communicator, local, CollectiveSpecSplit(axis), shape) 58 | 59 | @staticmethod 60 | def from_unique(communicator: DistributedCommunicator, local: Tensor, rank: int) -> "CollectiveTensor": 61 | return CollectiveTensor(communicator, local, CollectiveSpecUnique(rank), local.shape) 62 | 63 | @staticmethod 64 | def from_partial(communicator: DistributedCommunicator, local: Tensor) -> "CollectiveTensor": 65 | return CollectiveTensor(communicator, local, CollectiveSpecPartial(), local.shape) 66 | 67 | def to_broadcast(self) -> "CollectiveTensor": 68 | if isinstance(self.spec, CollectiveSpecBroadcast): 69 | return self 70 | elif isinstance(self.spec, CollectiveSpecSplit): 71 | return CollectiveTensor.from_broadcast(self._communicator, self._communicator.all_gather(self._local, self.spec.axis)) 72 | elif isinstance(self.spec, CollectiveSpecPartial): 73 | return CollectiveTensor.from_broadcast(self._communicator, self._communicator.all_reduce(self._local)) 74 | elif isinstance(self.spec, CollectiveSpecUnique): 75 | return CollectiveTensor.from_broadcast(self._communicator, self._communicator.broadcast(self._local, self.spec.rank)) 76 | else: 77 | assert False 78 | 79 | def numpy(self) -> numpy.array: 80 | return self.to_broadcast()._local.numpy() 81 | 82 | def to_partial(self): 83 | if isinstance(self.spec, CollectiveSpecPartial): 84 | return self 85 | self = self.to_broadcast() 86 | return CollectiveTensor.from_partial(self._communicator, self._local / get_world().size) 87 | 88 | def to_split(self, axis: int): 89 | if isinstance(self.spec, CollectiveSpecSplit) and self.spec.axis == axis: 90 | return self 91 | self = self.to_broadcast() 92 | rank = get_world().rank 93 | size = self._shape[axis] / get_world().size 94 | return CollectiveTensor.from_split(self._communicator, self._local.slice(size*rank, size*(rank+1), axis), axis) 95 | 96 | def to_unique(self, rank: int): 97 | if isinstance(self.spec, CollectiveSpecUnique) and self.spec.rank == rank: 98 | return self 99 | self = self.to_broadcast() 100 | return CollectiveTensor.from_unique(self._communicator, self._local if get_world().rank == rank else MetaTensor(self._local.shape, self._local.dtype), rank) 101 | 102 | def type(self): 103 | return CollectiveTensor[self._spec.type()] 104 | 105 | def item(self): 106 | assert self.spec == CollectiveSpecBroadcast() 107 | return self._local.item() 108 | 109 | @property 110 | def device(self): 111 | return self._local.device 112 | -------------------------------------------------------------------------------- /minit/compiler/cache.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import subprocess 4 | from typing import Dict, List 5 | import filelock 6 | 7 | 8 | def cached_execute(commands: List[str], files: Dict[str, str]) -> str: 9 | hash = hashlib.md5(str(commands).encode('utf-8')) 10 | hash.update(str(files).encode('utf-8')) 11 | md5_truncate = 8 12 | md5 = hash.hexdigest()[:md5_truncate] 13 | cached_path = os.path.join(os.path.expanduser("~"), ".minit", "cached", md5) 14 | lock_path = f"{cached_path}.filelock" 15 | with filelock.FileLock(lock_path): 16 | os.makedirs(cached_path, exist_ok=True) 17 | if os.path.exists(os.path.join(cached_path, ".success")): 18 | return cached_path 19 | print(f"compiling on {cached_path}") 20 | os.makedirs(cached_path, exist_ok=True) 21 | for name, content in files.items(): 22 | with open(os.path.join(cached_path, name), "w+") as f: 23 | f.write(content) 24 | if len(commands) > 0: 25 | subprocess.check_call(commands, cwd=cached_path) 26 | open(os.path.join(cached_path, ".success"), 'w+').close() 27 | print(f"compiling completed") 28 | return cached_path 29 | -------------------------------------------------------------------------------- /minit/compiler/cxx.py: -------------------------------------------------------------------------------- 1 | from ctypes import CDLL 2 | import ctypes 3 | from dataclasses import dataclass, field 4 | import inspect 5 | from types import FunctionType 6 | from typing import Any, Dict, List, Optional 7 | import nvtx 8 | 9 | 10 | def find_library(name: str) -> str: 11 | ... 12 | 13 | def find_include(name: str) -> str: 14 | ... 15 | 16 | 17 | @dataclass 18 | class CXXUnit: 19 | source: str 20 | includes: List[str] = field(default_factory=list) 21 | libraries: List[str] = field(default_factory=list) 22 | defines: List[str] = field(default_factory=list) 23 | 24 | 25 | class CXXLibrary: 26 | __slots__ = [ 27 | "library", 28 | "symbols", 29 | ] 30 | 31 | library: CDLL 32 | symbols: Dict[str, ctypes._CFuncPtr] 33 | 34 | def __init__(self, library: CDLL): 35 | self.library = library 36 | self.symbols = {} 37 | 38 | 39 | def import_symbol(cdll: CDLL, name: str): 40 | def decorator(function: FunctionType): 41 | arg_types = [] 42 | signature = inspect.signature(function) 43 | parameters = list(signature.parameters.values()) 44 | for param in parameters: 45 | assert param.kind in [param.POSITIONAL_ONLY , param.POSITIONAL_OR_KEYWORD] 46 | arg_types.append(param.annotation) 47 | if signature.return_annotation is signature.empty: 48 | res_type = None 49 | else: 50 | res_type = signature.return_annotation 51 | symbol = getattr(cdll, name) 52 | symbol.restype = res_type 53 | symbol.argtypes = arg_types 54 | # @nvtx.annotate(name) 55 | # def decorated(*args): 56 | # return symbol(*args) 57 | # return decorated 58 | return symbol 59 | return decorator 60 | 61 | 62 | class CXXCompiler: 63 | def compile(self, unit: CXXUnit) -> CDLL: 64 | ... 65 | -------------------------------------------------------------------------------- /minit/compiler/gcc.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | 4 | from .cache import cached_execute 5 | from .cxx import CXXCompiler, CXXLibrary, CXXUnit 6 | 7 | 8 | class GCC(CXXCompiler): 9 | def compile(self, unit: CXXUnit) -> CXXLibrary: 10 | commands = ["g++"] 11 | for include in unit.includes: 12 | commands += ["-I", include] 13 | for define in unit.defines: 14 | commands += ["-D", define] 15 | commands += ["-shared"] 16 | commands += ["-fPIC"] 17 | commands += ["-O0"] 18 | commands += ["-g"] 19 | commands += ["main.cpp"] 20 | commands += ["-o", "library.so"] 21 | for library in unit.libraries: 22 | commands += [library] 23 | result = cached_execute(commands, {"main.cpp": unit.source}) 24 | library = ctypes.CDLL(os.path.join(result, "library.so")) 25 | return library 26 | 27 | 28 | gcc = GCC() 29 | -------------------------------------------------------------------------------- /minit/compiler/python.py: -------------------------------------------------------------------------------- 1 | import os 2 | from types import CodeType 3 | 4 | from .cache import cached_execute 5 | 6 | 7 | class PythonCompiler: 8 | def compile(self, source: str) -> CodeType: 9 | path = cached_execute([], { 10 | "kernel.py": source 11 | }) 12 | kernel_code = compile(source, os.path.join(path, "kernel.py"), "exec") 13 | globals = {} 14 | exec(kernel_code, globals) 15 | return globals 16 | 17 | 18 | pythonc = PythonCompiler() 19 | -------------------------------------------------------------------------------- /minit/compiler/template.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | def substitude(source: str, args: Dict[str, str]): 5 | for k, v in args.items(): 6 | source = source.replace(f"${{{k}}}", v) 7 | return source 8 | -------------------------------------------------------------------------------- /minit/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GetUpEarlier/minit/48d227e638c0316cf998295f7638f909fbc1b9f6/minit/core/__init__.py -------------------------------------------------------------------------------- /minit/core/array.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar 2 | 3 | 4 | _T = TypeVar("_T") 5 | 6 | 7 | def _reversed(binary_operator: Callable[[_T, _T], _T]): 8 | def reversed(x: _T, y: _T) -> _T: 9 | return binary_operator(y, x) 10 | return reversed 11 | 12 | 13 | class Array(Generic[_T]): 14 | def fold(self, start: int, stop: int): 15 | from ..functional.shape import fold 16 | return fold(self, start, stop) 17 | 18 | def expand(self, axis: int, sizes: Tuple[_T, ...]): 19 | from ..functional.shape import expand 20 | return expand(self, axis, sizes) 21 | 22 | def add_axis(self, axis: int, size: Optional[_T] = None): 23 | from ..functional.shape import add_axis 24 | return add_axis(self, axis, size) 25 | 26 | def remove_axis(self, axis: int): 27 | from ..functional.shape import remove_axis 28 | return remove_axis(self, axis) 29 | 30 | def fold(self, start: int, stop: int): 31 | from ..functional.shape import fold 32 | return fold(self, start, stop) 33 | 34 | def broadcast(self, axis: int, size: _T): 35 | from ..functional.shape import broadcast 36 | return broadcast(self, axis, size) 37 | 38 | def transpose(self, axis_a: int, axis_b: int): 39 | from ..functional.shape import transpose 40 | return transpose(self, axis_a, axis_b) 41 | 42 | def repeat(self, axis: int, size: _T): 43 | from ..functional.shape import repeat 44 | return repeat(self, axis, size) 45 | 46 | def repeat_interleaved(self, axis: int, size: _T): 47 | from ..functional.shape import repeat_interleaved 48 | return repeat_interleaved(self, axis, size) 49 | 50 | def sum(self, axis: int): 51 | from ..functional.reduce import sum 52 | return sum(self, axis) 53 | 54 | def mean(self, axis: int): 55 | from ..functional.reduce import mean 56 | return mean(self, axis) 57 | 58 | def max(self, axis: int): 59 | from ..functional.reduce import max 60 | return max(self, axis) 61 | 62 | def slice(self, start: _T, stop: _T, axis: int): 63 | from ..functional.index import slice 64 | return slice(self, start, stop, axis) 65 | 66 | def slice_set(self, start: _T, stop: _T, axis: int, value: _T): 67 | from ..functional.index import slice_set 68 | return slice_set(self, start, stop, axis, value) 69 | 70 | def index(self, index: _T, axis: int): 71 | from ..functional.index import index as index_get 72 | return index_get(self, index, axis) 73 | 74 | def index_set(self, index: _T, axis: int, value: _T): 75 | from ..functional.index import index_set 76 | return index_set(self, index, axis, value) 77 | 78 | def split(self, axis: int, sizes: Tuple[_T, ...]): 79 | from ..functional.index import split 80 | return split(self, axis, sizes) 81 | 82 | def add(self, y: _T): 83 | from ..functional.arith import add 84 | return add(self, y) 85 | 86 | def subtract(self, y: _T): 87 | from ..functional.arith import subtract 88 | return subtract(self, y) 89 | 90 | def multiply(self, y: _T): 91 | from ..functional.arith import multiply 92 | return multiply(self, y) 93 | 94 | def divide(self, y: _T): 95 | from ..functional.arith import divide 96 | return divide(self, y) 97 | 98 | def floor_divide(self, y: _T): 99 | from ..functional.arith import floor_divide 100 | return floor_divide(self, y) 101 | 102 | def modulo(self, y: _T): 103 | from ..functional.arith import modulo 104 | return modulo(self, y) 105 | 106 | def power(self, y: _T): 107 | from ..functional.arith import power 108 | return power(self, y) 109 | 110 | def exponential(self): 111 | from ..functional.arith import exponential 112 | return exponential(self) 113 | 114 | def logarithm(self): 115 | from ..functional.arith import logarithm 116 | return logarithm(self) 117 | 118 | def square(self): 119 | from ..functional.arith import square 120 | return square(self) 121 | 122 | def square_root(self): 123 | from ..functional.arith import square_root 124 | return square_root(self) 125 | 126 | def sine(self): 127 | from ..functional.arith import sine 128 | return sine(self) 129 | 130 | def cosine(self): 131 | from ..functional.arith import cosine 132 | return cosine(self) 133 | 134 | def reinterpret(self, target: str): 135 | from ..functional.shape import reinterpret 136 | return reinterpret(self, target) 137 | 138 | def cast(self, dtype: str): 139 | from ..functional.arith import cast 140 | return cast(self, dtype) 141 | 142 | def rearrange(self, equation: str, variables: Optional[Dict[str, _T]]=None): 143 | from ..functional.einops import rearrange 144 | return rearrange(equation, self, variables) 145 | 146 | def greater_than(self, y: _T): 147 | from ..functional.arith import greater_than 148 | return greater_than(self, y) 149 | 150 | def less_than(self, y: _T): 151 | from ..functional.arith import less_than 152 | return less_than(self, y) 153 | 154 | def greater_equal(self, y: _T): 155 | from ..functional.arith import greater_equal 156 | return greater_equal(self, y) 157 | 158 | def less_equal(self, y: _T): 159 | from ..functional.arith import less_equal 160 | return less_equal(self, y) 161 | 162 | def equal(self, y: _T): 163 | from ..functional.arith import equal 164 | return equal(self, y) 165 | 166 | def not_equal(self, y: _T): 167 | from ..functional.arith import not_equal 168 | return not_equal(self, y) 169 | 170 | def logical_not(self): 171 | from ..functional.arith import logical_not 172 | return logical_not(self) 173 | 174 | def logical_and(self, y: _T): 175 | from ..functional.arith import logical_and 176 | return logical_and(self, y) 177 | 178 | def logical_or(self, y: _T): 179 | from ..functional.arith import logical_or 180 | return logical_or(self, y) 181 | 182 | @property 183 | def size(self): 184 | shape = self.shape 185 | if len(shape) == 0: 186 | from ..functional.arith import constant 187 | return constant(1, "int32") 188 | size = shape[0] 189 | for dim in shape[1:]: 190 | size = size * dim 191 | return size 192 | 193 | __add__ = add 194 | __radd__ = _reversed(add) 195 | __sub__ = subtract 196 | __rsub__ = _reversed(subtract) 197 | __mul__ = multiply 198 | __rmul__ = _reversed(multiply) 199 | __truediv__ = divide 200 | __rtruediv__ = _reversed(divide) 201 | __floordiv__ = floor_divide 202 | __rfloordiv__ = _reversed(floor_divide) 203 | __mod__ = modulo 204 | __rmod__ = _reversed(modulo) 205 | __pow__ = power 206 | __rpow__ = _reversed(power) 207 | __gt__ = greater_than 208 | __lt__ = less_than 209 | __ge__ = greater_equal 210 | __le__ = less_equal 211 | __eq__ = equal 212 | __ne__ = not_equal 213 | __not__ = logical_not 214 | __and__ = logical_and 215 | __or__ = logical_or 216 | 217 | def __getitem__(self, index) -> _T: 218 | assert not isinstance(index, tuple) 219 | if isinstance(index, slice): 220 | assert index.step is None 221 | if index.start is None: 222 | if index.stop is None: 223 | return self 224 | else: 225 | return self.slice(0, index.stop, 0) 226 | else: 227 | if index.stop is None: 228 | return self.slice(index.start, self.shape[0], 0) 229 | else: 230 | return self.slice(index.start, index.stop, 0) 231 | else: 232 | return self.slice(index, index+1, 0).remove_axis(0) 233 | 234 | @property 235 | def shape(self): 236 | raise NotImplementedError() 237 | 238 | @property 239 | def dtype(self): 240 | raise NotImplementedError() 241 | -------------------------------------------------------------------------------- /minit/core/cache.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from types import FunctionType 3 | 4 | 5 | def cached(): 6 | def decorator(function: FunctionType): 7 | memo = {} 8 | @functools.wraps(function) 9 | def decorated(*args): 10 | try: 11 | return memo[args] 12 | except KeyError: 13 | result = function(*args) 14 | memo[args] = result 15 | return result 16 | return decorated 17 | return decorator 18 | -------------------------------------------------------------------------------- /minit/core/device_operator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Generic, Literal, TypeVar 3 | 4 | from .dispatch import dispatch, register_dispatch 5 | from .operator import Operator 6 | from .tensor import Tensor 7 | 8 | _Operator = TypeVar("_Operator", bound=Operator) 9 | _Device = TypeVar("_Device", bound=str) 10 | 11 | 12 | @dataclass 13 | class DeviceOperator(Operator, Generic[_Operator, _Device]): 14 | operator: _Operator 15 | device: _Device 16 | 17 | def type(self): 18 | return DeviceOperator[self.operator.type(), Literal[self.device]] # type: ignore 19 | -------------------------------------------------------------------------------- /minit/core/dispatch.py: -------------------------------------------------------------------------------- 1 | from types import FunctionType 2 | from typing import Callable, List, Optional, Protocol, Sequence, Tuple, Type, overload 3 | from .operator import Operator 4 | from .tensor import Tensor 5 | from .object import FunctionSignature, Object, extract_function_signature, match_function_args 6 | import inspect 7 | import nvtx 8 | 9 | DISPATCH_CACHE = {} 10 | DISPATCH_TABLE: List[Tuple[FunctionSignature, FunctionType, Optional["DispatchPredicate"], int]] = [] 11 | 12 | def lookup_implementation_from_types(tys: Tuple[Type, ...]): 13 | matching = [] 14 | max_priority = None 15 | for signature, func, predicate, priority in DISPATCH_TABLE: 16 | match_result = match_function_args(signature, tys) 17 | if match_result is not None: 18 | predicate_result = predicate is None or predicate(*tys) 19 | if predicate_result: 20 | matching.append((match_result, signature, func, predicate, priority)) 21 | if max_priority is None or priority > max_priority: 22 | max_priority = priority 23 | assert len(matching) > 0, f"no matching function for {tys}" 24 | selected = [] 25 | for match_result, signature, func, predicate, priority in matching: 26 | if priority == max_priority: 27 | selected.append((match_result, signature, func, predicate, priority)) 28 | assert len(selected) == 1, f"more than one function matches for {tys}" 29 | _, _, cache, _, _ = selected[0] 30 | return cache 31 | 32 | @overload 33 | def dispatch(operator: Operator, *args: Tensor) -> Tuple[Tensor, ...]: 34 | ... 35 | 36 | # @nvtx.annotate("dispatch") 37 | def dispatch(*args: Object) -> Tuple[Tensor, ...]: 38 | arg_types = tuple([arg.type() for arg in args]) 39 | try: 40 | func = DISPATCH_CACHE[arg_types] 41 | except KeyError: 42 | func = lookup_implementation_from_types(arg_types) 43 | DISPATCH_CACHE[arg_types] = func 44 | outputs = func(*args) 45 | return outputs 46 | 47 | class DispatchPredicate(Protocol): 48 | def __call__(self, *args: Type) -> bool: 49 | ... 50 | 51 | DEFAULT_PRIORITY = 0 52 | 53 | def register_dispatch(*, predicate: Optional[DispatchPredicate] = None, priority: int = DEFAULT_PRIORITY): 54 | def decorator(function: FunctionType): 55 | signature = extract_function_signature(function) 56 | print(f"registering {signature}") 57 | DISPATCH_TABLE.append((signature, function, predicate, priority)) 58 | return decorator 59 | -------------------------------------------------------------------------------- /minit/core/dtype.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional 3 | 4 | 5 | class DataTypeInfo: 6 | __slots__ = [ 7 | "name", 8 | "python_type", 9 | "size_in_bytes", 10 | ] 11 | 12 | def __init__(self, name: str, python_type: Optional[type], size_in_bytes: int) -> None: 13 | self.name = name 14 | self.python_type = python_type 15 | self.size_in_bytes = size_in_bytes 16 | 17 | name: str 18 | python_type: Optional[type] 19 | size_in_bytes: int 20 | 21 | 22 | DATA_TYPES: Dict[str, DataTypeInfo] = {} 23 | 24 | 25 | def register_dtype(name: str, python_type: Optional[type], size_in_bytes: int): 26 | assert name not in DATA_TYPES 27 | DATA_TYPES[name] = DataTypeInfo( 28 | name, python_type, size_in_bytes 29 | ) 30 | 31 | 32 | def dtype_info(name: str) -> DataTypeInfo: 33 | return DATA_TYPES[name] 34 | 35 | 36 | register_dtype("float64", float, 8) 37 | register_dtype("float32", float, 4) 38 | register_dtype("float16", float, 2) 39 | register_dtype("bfloat16", float, 2) 40 | 41 | register_dtype("int64", int, 8) 42 | register_dtype("int32", int, 4) 43 | register_dtype("int16", int, 2) 44 | register_dtype("int8", int, 1) 45 | 46 | register_dtype("uint64", int, 8) 47 | register_dtype("uint32", int, 4) 48 | register_dtype("uint16", int, 2) 49 | register_dtype("uint8", int, 1) 50 | -------------------------------------------------------------------------------- /minit/core/meta.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | from .shape import to_symbolic_shape 4 | from .tensor import Tensor 5 | 6 | 7 | class MetaTensor(Tensor): 8 | def __init__(self, shape: Tuple[Tensor, ...], dtype: str) -> None: 9 | super().__init__() 10 | self._shape = shape 11 | self._dtype = dtype 12 | 13 | @property 14 | def shape(self): 15 | return to_symbolic_shape(self._shape) 16 | 17 | @property 18 | def dtype(self): 19 | return self._dtype 20 | 21 | @staticmethod 22 | def make(shape: Sequence[Optional[Tensor]], dtype: str): 23 | shape = tuple([dim if dim is not None else MetaTensor((), "int32") for dim in shape]) 24 | return MetaTensor(shape, dtype) 25 | 26 | @property 27 | def device(self): 28 | return "meta" 29 | -------------------------------------------------------------------------------- /minit/core/object.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import inspect 3 | from types import FunctionType 4 | from typing import Generic, Literal, Optional, Tuple, Type, Union, get_args, get_origin 5 | from typing_extensions import Self 6 | 7 | 8 | class Object: 9 | def type(self) -> Type[Self]: 10 | raise NotImplementedError() 11 | 12 | 13 | def simplify_type(ty: Type) -> Tuple[Type, ...]: 14 | if get_origin(ty) == Union: 15 | return tuple( 16 | simplified_arg 17 | for arg in get_args(ty) 18 | for simplified_arg in simplify_type(arg) 19 | ) 20 | else: 21 | return (ty,) 22 | 23 | 24 | @dataclass 25 | class FunctionSignature: 26 | args: Tuple[Tuple[Type, ...], ...] 27 | vaargs: Optional[Type] 28 | 29 | 30 | def extract_function_signature(fn: FunctionType) -> FunctionSignature: 31 | signature = inspect.signature(fn) 32 | args = [] 33 | vaargs = None 34 | for param in signature.parameters.values(): 35 | assert param.kind in [param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD, param.VAR_POSITIONAL] 36 | simplified = simplify_type(param.annotation) 37 | if param.kind == param.VAR_POSITIONAL: 38 | vaargs = simplified 39 | else: 40 | args.append(simplified) 41 | return FunctionSignature(args, vaargs) 42 | 43 | 44 | def is_literal(ty: Type): 45 | return get_origin_or_self(ty) == Literal 46 | 47 | 48 | def get_origin_or_self(ty: Type): 49 | return get_origin(ty) or ty 50 | 51 | 52 | def is_union(ty: Type): 53 | return get_origin_or_self(ty) == Union 54 | 55 | 56 | # mro without dedup 57 | def generic_mro(ty: Type): 58 | if not hasattr(ty, "__origin__"): 59 | return inspect.getmro(ty) 60 | orig_bases = ty.__origin__.__orig_bases__ 61 | generic = None 62 | for orig_base in orig_bases: 63 | if get_origin_or_self(orig_base) == Generic: 64 | generic = orig_base 65 | assert generic is not None, f"generic not found for {ty}" 66 | assert len(get_args(ty)) == len(get_args(generic)) 67 | args_mapping = { generic_arg: arg for arg, generic_arg in zip(get_args(ty), get_args(generic), strict=True) } 68 | if len(orig_bases) == 1: 69 | return ty 70 | def substitude(ty: Type): 71 | if not hasattr(ty, "__origin__"): 72 | return ty 73 | args = tuple(map(lambda t: args_mapping.get(t, default=t), get_args(ty))) 74 | return get_origin_or_self(ty)[args] 75 | results = [] 76 | for orig_base in orig_bases: 77 | if orig_base is generic: 78 | continue 79 | for orig_base_mro in generic_mro(substitude(orig_base)): 80 | results.append(orig_base_mro) 81 | return tuple(results) 82 | 83 | 84 | def match_pattern(pattern: Type, arg: Type): 85 | # literal 86 | if is_literal(pattern): 87 | if is_literal(arg) and pattern == arg: 88 | return pattern 89 | return None 90 | # object 91 | if is_union(pattern): 92 | for pattern_arg in get_args(pattern): 93 | result = match_pattern(pattern_arg, arg) 94 | if result is not None: 95 | return result 96 | return None 97 | # arg literal 98 | origin_pattern = get_origin_or_self(pattern) 99 | if is_literal(arg): 100 | if isinstance(get_args(arg)[0], origin_pattern): 101 | return pattern 102 | return None 103 | origin_arg = get_origin_or_self(arg) 104 | if not issubclass(origin_arg, origin_pattern): 105 | return None 106 | if origin_arg != origin_pattern: 107 | for base in generic_mro(arg): 108 | if get_origin_or_self(base) == origin_pattern: 109 | arg = base 110 | origin_arg = get_origin_or_self(base) 111 | assert origin_arg == origin_pattern 112 | if get_args(pattern) == (): 113 | return pattern 114 | if get_args(arg) == (): 115 | return None 116 | for arg_arg, pattern_arg in zip(get_args(arg), get_args(pattern), strict=True): 117 | match_result = match_pattern(pattern_arg, arg_arg) 118 | if match_result is None: 119 | return None 120 | return pattern 121 | 122 | 123 | def match_patterns(patterns: Tuple[Type, ...], arg: Type) -> Optional[Type]: 124 | for pattern in patterns: 125 | match_result = match_pattern(pattern, arg) 126 | if match_result is not None: 127 | return match_result 128 | return None 129 | 130 | 131 | def match_function_args(signature: FunctionSignature, args: Tuple[Type, ...]): 132 | result = [] 133 | if len(args) < len(signature.args): 134 | return None 135 | if len(args) > len(signature.args): 136 | if signature.vaargs is None: 137 | return None 138 | nr_args = len(signature.args) 139 | for i in range(nr_args): 140 | match_result = match_patterns(signature.args[i], args[i]) 141 | if match_result is None: 142 | return None 143 | result.append(match_result) 144 | for arg in args[nr_args:]: 145 | match_result = match_patterns(signature.vaargs, arg) 146 | if match_result is None: 147 | return None 148 | result.append(match_result) 149 | return tuple(result) 150 | -------------------------------------------------------------------------------- /minit/core/operator.py: -------------------------------------------------------------------------------- 1 | from .object import Object 2 | 3 | 4 | class Operator(Object): 5 | def type(self): 6 | return type(self) 7 | -------------------------------------------------------------------------------- /minit/core/scalar.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from typing import Tuple 3 | 4 | from .dtype import dtype_info 5 | from .tensor import Tensor 6 | 7 | 8 | class ScalarTensor(Tensor): 9 | __slots__ = [ 10 | "_value", 11 | "_shape", 12 | "_dtype", 13 | ] 14 | 15 | def __init__(self, value: Number, shape: Tuple[Tensor, ...], dtype: str) -> None: 16 | super().__init__() 17 | assert not isinstance(value, Tensor) 18 | assert shape is not None 19 | assert dtype is not None 20 | self._value = value 21 | self._shape = shape 22 | self._dtype = dtype 23 | 24 | @property 25 | def shape(self): 26 | return self._shape 27 | 28 | @property 29 | def dtype(self): 30 | return self._dtype 31 | 32 | def value(self): 33 | return self._value 34 | 35 | def item(self) -> Number: 36 | assert len(self.shape) == 0 37 | return dtype_info(self._dtype).python_type(self._value) 38 | 39 | def __repr__(self) -> str: 40 | return f"Scalar({self._value}: {self._dtype})" 41 | 42 | def type(self): 43 | return ScalarTensor 44 | -------------------------------------------------------------------------------- /minit/core/shape.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | from .tensor import Tensor 4 | from .scalar import ScalarTensor 5 | 6 | 7 | Shape = Tuple[Union[int, Tensor], ...] 8 | ImmediateShape = Tuple[int, ...] 9 | SymbolicShape = Tuple[Tensor, ...] 10 | 11 | 12 | def to_immediate_shape(shape: Shape) -> ImmediateShape: 13 | return tuple([dim.item() if isinstance(dim, Tensor) else dim for dim in shape]) 14 | 15 | 16 | def to_symbolic_shape(shape: Shape) -> SymbolicShape: 17 | return tuple([dim if isinstance(dim, Tensor) else ScalarTensor(dim, (), "int32") for dim in shape]) 18 | -------------------------------------------------------------------------------- /minit/core/tensor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from .object import Object 4 | from .array import Array 5 | 6 | class Tensor(Array["Tensor"], Object): 7 | @property 8 | def shape(self) -> Tuple["Tensor", ...]: 9 | raise NotImplementedError() 10 | 11 | @property 12 | def dtype(self) -> str: 13 | raise NotImplementedError() 14 | 15 | @property 16 | def device(self) -> str: 17 | raise NotImplementedError() 18 | 19 | def type(self): 20 | return type(self) 21 | -------------------------------------------------------------------------------- /minit/core/torch.py: -------------------------------------------------------------------------------- 1 | from .tensor import Tensor 2 | import torch 3 | 4 | 5 | class TorchTensor(Tensor): 6 | def __init__(self, value: torch.Tensor) -> None: 7 | super().__init__() 8 | self._value = value 9 | 10 | @property 11 | def value(self): 12 | return self._value 13 | 14 | @property 15 | def device(self): 16 | return "torch" 17 | -------------------------------------------------------------------------------- /minit/cuda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GetUpEarlier/minit/48d227e638c0316cf998295f7638f909fbc1b9f6/minit/cuda/__init__.py -------------------------------------------------------------------------------- /minit/cuda/allocator.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | import nvtx 4 | 5 | from ..core.cache import cached 6 | 7 | from .toolkit import find_cuda_include_directory, find_cuda_libraries 8 | from ..compiler.cxx import CXXLibrary, CXXUnit, import_symbol 9 | from ..compiler.gcc import gcc 10 | 11 | 12 | @cached() 13 | def _generate_library(): 14 | source =\ 15 | """ 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | #define CUDA_ASSERT(expr) \\ 22 | do { \\ 23 | auto _err = (expr); \\ 24 | if (_err != cudaSuccess) { \\ 25 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 26 | } \\ 27 | } while (0) 28 | 29 | 30 | extern "C" void* allocate_cuda_memory(size_t size) { 31 | void* pointer; 32 | CUDA_ASSERT(cudaMallocAsync(&pointer, size, nullptr)); 33 | return pointer; 34 | } 35 | 36 | extern "C" void copy_cuda_memory(void* dst, const void* src, size_t size) { 37 | CUDA_ASSERT(cudaMemcpyAsync(dst, src, size, cudaMemcpyDefault, nullptr)); 38 | } 39 | 40 | extern "C" void free_cuda_memory(void* pointer) { 41 | CUDA_ASSERT(cudaFreeAsync(pointer, nullptr)); 42 | } 43 | 44 | extern "C" void sync_cuda() { 45 | CUDA_ASSERT(cudaStreamSynchronize(nullptr)); 46 | } 47 | 48 | extern "C" void reset(void* pointer, size_t size) { 49 | CUDA_ASSERT(cudaMemsetAsync(pointer, 0, size, nullptr)); 50 | } 51 | """ 52 | return gcc.compile(CXXUnit(source=source, libraries=find_cuda_libraries(), includes=[ 53 | find_cuda_include_directory() 54 | ])) 55 | 56 | 57 | _library = _generate_library() 58 | 59 | 60 | @import_symbol(_library, "allocate_cuda_memory") 61 | def allocate_cuda_memory(size: ctypes.c_size_t) -> ctypes.c_void_p: 62 | ... 63 | 64 | 65 | @import_symbol(_library, "free_cuda_memory") 66 | def free_cuda_memory(pointer: ctypes.c_void_p): 67 | ... 68 | 69 | 70 | @import_symbol(_library, "copy_cuda_memory") 71 | def copy_cuda_memory(dst: ctypes.c_void_p, src: ctypes.c_void_p, size: ctypes.c_size_t): 72 | ... 73 | 74 | 75 | @import_symbol(_library, "sync_cuda") 76 | def sync_cuda(): 77 | ... 78 | 79 | 80 | @import_symbol(_library, "reset") 81 | def reset(dst: ctypes.c_void_p, size: ctypes.c_size_t): 82 | ... 83 | 84 | 85 | class CUDAMemory: 86 | __slots__ = [ 87 | "_pointer", 88 | "_size", 89 | ] 90 | 91 | _pointer: int 92 | _size: int 93 | 94 | def __init__(self, size: int) -> None: 95 | self._pointer = allocate_cuda_memory(size) 96 | self._size = size 97 | 98 | 99 | def __del__(self): 100 | free_cuda_memory(self._pointer) 101 | self._pointer = None 102 | self._size = None 103 | 104 | 105 | def copy_from(self, src: "CUDAMemory"): 106 | assert self._size == src._size 107 | copy_cuda_memory(self._pointer, src._pointer, self._size) 108 | 109 | 110 | def copy(self) -> "CUDAMemory": 111 | new = CUDAMemory(self._size) 112 | copy_cuda_memory(new._pointer, self._pointer, self._size) 113 | return new 114 | 115 | 116 | def reset(self): 117 | reset(self._pointer, self._size) 118 | -------------------------------------------------------------------------------- /minit/cuda/compiler.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | 4 | from ..compiler.cache import cached_execute 5 | from ..compiler.cxx import CXXCompiler, CXXLibrary, CXXUnit 6 | from .toolkit import get_cuda_home 7 | 8 | 9 | class NVCC(CXXCompiler): 10 | def compile(self, unit: CXXUnit) -> CXXLibrary: 11 | commands = [os.path.join(get_cuda_home(), "bin", "nvcc")] 12 | for include in unit.includes: 13 | commands += ["-I", include] 14 | for library in unit.libraries: 15 | commands += ["-l", library] 16 | for define in unit.defines: 17 | commands += ["-D", define] 18 | commands += ["-shared"] 19 | commands += ["--compiler-options", "-fPIC"] 20 | commands += ["-gencode=arch=compute_70,code=compute_70"] 21 | commands += ["main.cu"] 22 | commands += ["-o", "library.so"] 23 | result = cached_execute(commands, {"main.cu": unit.source}) 24 | library = ctypes.CDLL(os.path.join(result, "library.so")) 25 | return library 26 | 27 | 28 | nvcc = NVCC() 29 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cublas.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | import os 4 | 5 | from ...core.cache import cached 6 | 7 | from ...compiler.template import substitude 8 | from ...compiler.cxx import CXXUnit, import_symbol 9 | from ...compiler.gcc import gcc 10 | from ..toolkit import find_cuda_include_directory, get_cuda_home 11 | 12 | @cached() 13 | def generate_cublas_kernel(name: str, dtype: str): 14 | kernel_name = f"minit_{name}" 15 | kernel_template =\ 16 | """ 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | 24 | #define CUBLAS_ASSERT(expr) \\ 25 | do { \\ 26 | auto _err = (expr); \\ 27 | if (_err != CUBLAS_STATUS_SUCCESS) { \\ 28 | fprintf(stderr, "cublas errno: %d\\n", (int)_err); \\ 29 | throw std::runtime_error("cublas error at " #expr); \\ 30 | } \\ 31 | } while (0) 32 | 33 | 34 | #define CUDA_ASSERT(expr) \\ 35 | do { \\ 36 | auto _err = (expr); \\ 37 | if (_err != cudaSuccess) { \\ 38 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 39 | } \\ 40 | } while (0) 41 | 42 | using T = ${DATA_TYPE}; 43 | static constexpr cudaDataType_t kCudaDataType = ${CUDA_DATA_TYPE}; 44 | static constexpr cublasComputeType_t kCudaComputeDataType = ${CUDA_COMPUTE_DATA_TYPE}; 45 | 46 | cublasLtMatrixLayout_t create_layout(size_t batch, size_t nr_rows, size_t nr_cols) { 47 | cublasLtMatrixLayout_t layout; 48 | CUBLAS_ASSERT(cublasLtMatrixLayoutCreate(&layout, kCudaDataType, nr_rows, nr_cols, nr_rows)); 49 | int32_t batch_i32 = batch; 50 | int64_t batch_stride_i64 = nr_rows * nr_cols; 51 | CUBLAS_ASSERT(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_i32, sizeof(batch_i32))); 52 | CUBLAS_ASSERT(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_stride_i64, sizeof(batch_stride_i64))); 53 | return layout; 54 | } 55 | 56 | cublasLtHandle_t get_handle() { 57 | static thread_local std::unordered_map handles; 58 | int device; 59 | CUDA_ASSERT(cudaGetDevice(&device)); 60 | if (!handles.count(device)) { 61 | cublasLtHandle_t handle; 62 | CUBLAS_ASSERT(cublasLtCreate(&handle)); 63 | handles[device] = handle; 64 | } 65 | return handles[device]; 66 | } 67 | 68 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, void* a, void* b, void* c, size_t batch, size_t m, size_t n, size_t k, void* workspace) { 69 | auto handle = get_handle(); 70 | cublasLtMatmulDesc_t desc; 71 | CUBLAS_ASSERT(cublasLtMatmulDescCreate(&desc, kCudaComputeDataType, kCudaDataType)); 72 | cublasLtMatmulPreference_t preference; 73 | CUBLAS_ASSERT(cublasLtMatmulPreferenceCreate(&preference)); 74 | size_t workspace_size = 0; 75 | CUBLAS_ASSERT(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); 76 | cublasOperation_t transa = CUBLAS_OP_T, transb = CUBLAS_OP_N, transc = CUBLAS_OP_N; 77 | CUBLAS_ASSERT(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); 78 | CUBLAS_ASSERT(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); 79 | CUBLAS_ASSERT(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSC, &transc, sizeof(transc))); 80 | // CUBLAS_ASSERT(cublasLtMatmulPreferenceInit(preference)); 81 | int heuristics_count = 1; 82 | cublasLtMatmulHeuristicResult_t heuristics[heuristics_count]; 83 | // (k, m) @ (n, k) => (m, n) 84 | auto a_layout = create_layout(batch, k, n); 85 | auto b_layout = create_layout(batch, k, m); 86 | auto c_layout = create_layout(batch, n, m); 87 | CUBLAS_ASSERT(cublasLtMatmulAlgoGetHeuristic(handle, desc, a_layout, b_layout, c_layout, c_layout, preference, heuristics_count, heuristics, &heuristics_count)); 88 | if (heuristics_count == 0) { 89 | fprintf(stderr, "no gemm algo found\\n"); 90 | } 91 | float alpha = 1.0, beta = 1.0; 92 | CUDA_ASSERT(cudaMemset(c, 0, batch * m * n * sizeof(T))); 93 | CUBLAS_ASSERT(cublasLtMatmul( 94 | handle, 95 | desc, 96 | &alpha, 97 | b, 98 | a_layout, 99 | a, 100 | b_layout, 101 | &beta, 102 | c, 103 | c_layout, 104 | c, 105 | c_layout, 106 | &heuristics[0].algo, 107 | workspace, 108 | 0, 109 | stream 110 | )); 111 | } 112 | """ 113 | source = substitude(kernel_template, { 114 | "KERNEL_NAME": kernel_name, 115 | "DATA_TYPE": dtype, 116 | "CUDA_DATA_TYPE": { 117 | "double": "CUDA_R_64F", 118 | "float": "CUDA_R_32F", 119 | "__half": "CUDA_R_16F", 120 | "__nv_bfloat16": "CUDA_R_16BF", 121 | }[dtype], 122 | "CUDA_COMPUTE_DATA_TYPE": { 123 | "double": "CUBLAS_COMPUTE_64F", 124 | "float": "CUBLAS_COMPUTE_32F", 125 | "__half": "CUBLAS_COMPUTE_32F", 126 | "__nv_bfloat16": "CUBLAS_COMPUTE_32F", 127 | }[dtype] 128 | }) 129 | kernel = gcc.compile(CXXUnit(source=source, includes=[find_cuda_include_directory()], libraries=[ 130 | os.path.join(get_cuda_home(), "lib64", "libcublasLt.so"), 131 | os.path.join(get_cuda_home(), "lib64", "libcudart.so"), 132 | ])) 133 | @import_symbol(kernel, kernel_name) 134 | def entrance( 135 | stream: ctypes.c_void_p, 136 | a: ctypes.c_void_p, 137 | b: ctypes.c_void_p, 138 | c: ctypes.c_void_p, 139 | batch: ctypes.c_size_t, 140 | m: ctypes.c_size_t, 141 | n: ctypes.c_size_t, 142 | k: ctypes.c_size_t, 143 | workspace: ctypes.c_void_p, 144 | ): 145 | ... 146 | return entrance 147 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/broadcast.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_broadcast_kernel(name, dtype): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | #define CUDA_ASSERT(expr) \\ 22 | do { \\ 23 | auto _err = (expr); \\ 24 | if (_err != cudaSuccess) { \\ 25 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 26 | } \\ 27 | } while (0) 28 | 29 | using T = ${DATA_TYPE}; 30 | 31 | template 32 | struct TensorIterator { 33 | size_t shape[nr_ranks]; 34 | 35 | __device__ cuda::std::array to_index(size_t offset) const { 36 | cuda::std::array index; 37 | for (size_t i = 0; i < nr_ranks; ++i) { 38 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 39 | offset /= shape[nr_ranks-i-1]; 40 | } 41 | return index; 42 | } 43 | 44 | __device__ size_t to_offset(cuda::std::array index) const { 45 | size_t offset = 0; 46 | for (size_t i = 0; i < nr_ranks; ++i) { 47 | offset *= shape[i]; 48 | offset += index[i]; 49 | } 50 | return offset; 51 | } 52 | }; 53 | 54 | // input[a, 1, c] -> output[a, b, c] 55 | __global__ void kernel(T* input, T* output, size_t a, size_t b, size_t c) { 56 | size_t stride = blockDim.x * gridDim.x; 57 | TensorIterator<2> input_iterator; 58 | input_iterator.shape[0] = a; 59 | input_iterator.shape[1] = c; 60 | TensorIterator<3> output_iterator; 61 | output_iterator.shape[0] = a; 62 | output_iterator.shape[1] = b; 63 | output_iterator.shape[2] = c; 64 | size_t nr_elements = a * b * c; 65 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += stride) { 66 | size_t output_offset = offset; 67 | auto output_index = output_iterator.to_index(output_offset); 68 | size_t input_offset = input_iterator.to_offset({output_index[0], output_index[2]}); 69 | output[output_offset] = input[input_offset]; 70 | } 71 | } 72 | 73 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* input, T* output, size_t a, size_t b, size_t c) { 74 | size_t nr_elements = a * b * c; 75 | static constexpr size_t nr_sms = 108; 76 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 77 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 78 | kernel<<>>(input, output, a, b, c); 79 | CUDA_ASSERT(cudaGetLastError()); 80 | } 81 | """ 82 | source = substitude(kernel_template, { 83 | "DATA_TYPE": dtype, 84 | "KERNEL_NAME": kernel_name, 85 | }) 86 | kernel = nvcc.compile(CXXUnit(source=source)) 87 | @import_symbol(kernel, kernel_name) 88 | def entrance( 89 | stream: ctypes.c_void_p, 90 | input: ctypes.c_void_p, 91 | output: ctypes.c_void_p, 92 | a: ctypes.c_size_t, 93 | b: ctypes.c_size_t, 94 | c: ctypes.c_size_t, 95 | ): 96 | ... 97 | return entrance 98 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/cast.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_cast_kernel(name: str, source_dtype: str, target_dtype: str): 12 | kernel_name = f"minit_cast_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | #define CUDA_ASSERT(expr) \\ 26 | do { \\ 27 | auto _err = (expr); \\ 28 | if (_err != cudaSuccess) { \\ 29 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 30 | } \\ 31 | } while (0) 32 | 33 | using source_type = ${SOURCE_TYPE}; 34 | using target_type = ${TARGET_TYPE}; 35 | 36 | __global__ void kernel(source_type* input, target_type* output, size_t nr_elements) { 37 | size_t stride = blockDim.x * gridDim.x; 38 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += stride) { 39 | auto value = input[offset]; 40 | output[offset] = (target_type)value; 41 | } 42 | } 43 | 44 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, source_type* input, target_type* output, size_t nr_elements) { 45 | if (nr_elements == 0) { 46 | return; 47 | } 48 | static constexpr size_t nr_sms = 108; 49 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 50 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 51 | kernel<<>>(input, output, nr_elements); 52 | CUDA_ASSERT(cudaGetLastError()); 53 | } 54 | """ 55 | source = substitude(kernel_template, { 56 | "SOURCE_TYPE": source_dtype, 57 | "TARGET_TYPE": target_dtype, 58 | "KERNEL_NAME": kernel_name, 59 | }) 60 | kernel = nvcc.compile(CXXUnit(source=source)) 61 | @import_symbol(kernel, kernel_name) 62 | def entrance( 63 | stream: ctypes.c_void_p, 64 | input: ctypes.c_void_p, 65 | output: ctypes.c_void_p, 66 | nr_elements: ctypes.c_size_t, 67 | ): 68 | ... 69 | return entrance 70 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/elemwise.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_elemwise_kernel(name: str, nr_inputs: int, expr: str, dtype: str): 12 | kernel_name = f"minit_elemwise_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | #define CUDA_ASSERT(expr) \\ 26 | do { \\ 27 | auto _err = (expr); \\ 28 | if (_err != cudaSuccess) { \\ 29 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 30 | } \\ 31 | } while (0) 32 | 33 | using T = ${DATA_TYPE}; 34 | static constexpr size_t nr_inputs = ${NR_INPUTS}; 35 | 36 | __global__ void kernel(cuda::std::array inputs, cuda::std::array strides, T* output, size_t nr_elements) { 37 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += blockDim.x * gridDim.x) { 38 | T values[nr_inputs]; 39 | #pragma unroll 40 | for (size_t i = 0; i < nr_inputs; ++i) { 41 | values[i] = __ldg(inputs[i] + offset*strides[i]); 42 | } 43 | output[offset] = ${EXPR}; 44 | } 45 | } 46 | 47 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T** inputs, int* strides, T* output, size_t nr_elements) { 48 | if (nr_elements == 0) { 49 | return; 50 | } 51 | static constexpr size_t nr_sms = 108; 52 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 53 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 54 | cuda::std::array inputs_array; 55 | std::memcpy(&inputs_array, inputs, sizeof(inputs_array)); 56 | cuda::std::array stride_array; 57 | std::memcpy(&stride_array, strides, sizeof(stride_array)); 58 | kernel<<>>(inputs_array, stride_array, output, nr_elements); 59 | CUDA_ASSERT(cudaGetLastError()); 60 | } 61 | """ 62 | source = substitude(kernel_template, { 63 | "DATA_TYPE": dtype, 64 | "NR_INPUTS": str(nr_inputs), 65 | "EXPR": expr.format(*[f"values[{i}]" for i in range(nr_inputs)]), 66 | "KERNEL_NAME": kernel_name, 67 | }) 68 | kernel = nvcc.compile(CXXUnit(source=source)) 69 | @import_symbol(kernel, kernel_name) 70 | def entrance( 71 | stream: ctypes.c_void_p, 72 | input: ctypes.c_void_p, 73 | strides: ctypes.c_void_p, 74 | output: ctypes.c_void_p, 75 | nr_elements: ctypes.c_size_t, 76 | ): 77 | ... 78 | return entrance 79 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/fill.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_fill_kernel(name: str, dtype: str): 12 | kernel_name = f"minit_fill_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | 22 | #define CUDA_ASSERT(expr) \\ 23 | do { \\ 24 | auto _err = (expr); \\ 25 | if (_err != cudaSuccess) { \\ 26 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 27 | } \\ 28 | } while (0) 29 | 30 | using T = ${DATA_TYPE}; 31 | 32 | __global__ void kernel(T* output, size_t nr_elements, T value) { 33 | size_t stride = blockDim.x * gridDim.x; 34 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += stride) { 35 | output[offset] = value; 36 | } 37 | } 38 | 39 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* output, size_t nr_elements, double value) { 40 | if (nr_elements == 0) { 41 | return; 42 | } 43 | static constexpr size_t nr_sms = 108; 44 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 45 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 46 | kernel<<>>(output, nr_elements, (T)value); 47 | } 48 | """ 49 | source = substitude(kernel_template, { 50 | "DATA_TYPE": dtype, 51 | "KERNEL_NAME": kernel_name, 52 | }) 53 | kernel = nvcc.compile(CXXUnit(source=source)) 54 | @import_symbol(kernel, kernel_name) 55 | def entrance( 56 | stream: ctypes.c_void_p, 57 | output: ctypes.c_void_p, 58 | nr_elements: ctypes.c_size_t, 59 | value: ctypes.c_double, 60 | ): 61 | ... 62 | return entrance 63 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/generate.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_sequence_kernel(name: str, dtype: str): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | #define CUDA_ASSERT(expr) \\ 26 | do { \\ 27 | auto _err = (expr); \\ 28 | if (_err != cudaSuccess) { \\ 29 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 30 | } \\ 31 | } while (0) 32 | 33 | using T = ${DATA_TYPE}; 34 | 35 | __global__ void kernel(T* output, T* start, T* step, size_t nr_elements) { 36 | size_t stride = blockDim.x * gridDim.x; 37 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += stride) { 38 | output[offset] = (*start) + (T)((double)(*step) * (double)offset); 39 | } 40 | } 41 | 42 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* output, T* start, T* step, size_t nr_elements) { 43 | static constexpr size_t nr_sms = 108; 44 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 45 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 46 | kernel<<>>(output, start, step, nr_elements); 47 | CUDA_ASSERT(cudaGetLastError()); 48 | } 49 | """ 50 | source = substitude(kernel_template, { 51 | "DATA_TYPE": dtype, 52 | "KERNEL_NAME": kernel_name, 53 | }) 54 | kernel = nvcc.compile(CXXUnit(source=source)) 55 | @import_symbol(kernel, kernel_name) 56 | def entrance( 57 | stream: ctypes.c_void_p, 58 | output: ctypes.c_void_p, 59 | start: ctypes.c_void_p, 60 | step: ctypes.c_void_p, 61 | nr_elements: ctypes.c_size_t, 62 | ): 63 | ... 64 | return entrance 65 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/index.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_index_kernel(name: str, dtype: str): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | #define CUDA_ASSERT(expr) \\ 26 | do { \\ 27 | auto _err = (expr); \\ 28 | if (_err != cudaSuccess) { \\ 29 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 30 | } \\ 31 | } while (0) 32 | 33 | using T = ${DATA_TYPE}; 34 | 35 | template 36 | struct TensorIterator { 37 | size_t shape[nr_ranks]; 38 | 39 | __device__ cuda::std::array to_index(size_t offset) const { 40 | cuda::std::array index; 41 | for (size_t i = 0; i < nr_ranks; ++i) { 42 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 43 | offset /= shape[nr_ranks-i-1]; 44 | } 45 | return index; 46 | } 47 | 48 | __device__ size_t to_offset(cuda::std::array index) const { 49 | size_t offset = 0; 50 | for (size_t i = 0; i < nr_ranks; ++i) { 51 | offset *= shape[i]; 52 | offset += index[i]; 53 | } 54 | return offset; 55 | } 56 | }; 57 | 58 | __global__ void kernel(T* input, std::int32_t* index, T* output, size_t a, size_t b, size_t c, size_t d) { 59 | size_t nr_elements = a * c * d; 60 | size_t stride = blockDim.x * gridDim.x; 61 | TensorIterator<3> output_iterator; 62 | output_iterator.shape[0] = a; 63 | output_iterator.shape[1] = c; 64 | output_iterator.shape[2] = d; 65 | TensorIterator<3> input_iterator; 66 | input_iterator.shape[0] = a; 67 | input_iterator.shape[1] = b; 68 | input_iterator.shape[2] = d; 69 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += stride) { 70 | auto output_offset = offset; 71 | auto output_index = output_iterator.to_index(output_offset); 72 | if (index[output_index[1]] < 0) { 73 | __trap(); 74 | } 75 | if (index[output_index[1]] >= b) { 76 | __trap(); 77 | } 78 | auto input_offset = input_iterator.to_offset({output_index[0], index[output_index[1]], output_index[2]}); 79 | output[offset] = input[input_offset]; 80 | } 81 | } 82 | 83 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* input, std::int32_t* index, T* output, size_t a, size_t b, size_t c, size_t d) { 84 | size_t nr_elements = a * c * d; 85 | if (nr_elements == 0) { 86 | return; 87 | } 88 | static constexpr size_t nr_sms = 108; 89 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 90 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 91 | kernel<<>>(input, index, output, a, b, c, d); 92 | CUDA_ASSERT(cudaGetLastError()); 93 | } 94 | """ 95 | source = substitude(kernel_template, { 96 | "DATA_TYPE": dtype, 97 | "KERNEL_NAME": kernel_name, 98 | }) 99 | kernel = nvcc.compile(CXXUnit(source=source)) 100 | @import_symbol(kernel, kernel_name) 101 | def entrance( 102 | stream: ctypes.c_void_p, 103 | input: ctypes.c_void_p, 104 | index: ctypes.c_void_p, 105 | output: ctypes.c_void_p, 106 | a: ctypes.c_size_t, 107 | b: ctypes.c_size_t, 108 | c: ctypes.c_size_t, 109 | d: ctypes.c_size_t, 110 | ): 111 | ... 112 | return entrance 113 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/reduce.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_reduce_kernel(name: str, init: str, expr: str, dtype: str): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | 22 | #define CUDA_ASSERT(expr) \\ 23 | do { \\ 24 | auto _err = (expr); \\ 25 | if (_err != cudaSuccess) { \\ 26 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 27 | } \\ 28 | } while (0) 29 | 30 | using T = ${DATA_TYPE}; 31 | 32 | template 33 | struct TensorIterator { 34 | size_t shape[nr_ranks]; 35 | 36 | __device__ cuda::std::array to_index(size_t offset) const { 37 | cuda::std::array index; 38 | for (size_t i = 0; i < nr_ranks; ++i) { 39 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 40 | offset /= shape[nr_ranks-i-1]; 41 | } 42 | return index; 43 | } 44 | 45 | __device__ size_t to_offset(cuda::std::array index) const { 46 | size_t offset = 0; 47 | for (size_t i = 0; i < nr_ranks; ++i) { 48 | offset *= shape[i]; 49 | offset += index[i]; 50 | } 51 | return offset; 52 | } 53 | }; 54 | 55 | __global__ void thread_reduce_kernel(T* input, T* output, size_t a, size_t b, size_t c) { 56 | size_t stride = blockDim.x * gridDim.x; 57 | size_t nr_lines = a * c; 58 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_lines; offset += stride) { 59 | T result = ${REDUCE_INIT}; 60 | for (size_t j = 0; j < b; j++) { 61 | T x = result; 62 | T y = input[(offset / c) * (b*c) + offset%c + j * c]; 63 | result = ${REDUCE_EXPR}(x, y); 64 | } 65 | output[offset] = result; 66 | } 67 | } 68 | 69 | template 70 | __launch_bounds__(nr_threads) __global__ void block_reduce_kernel(T* input, T* output, size_t a, size_t b, size_t c) { 71 | typedef cub::BlockReduce BlockReduce; 72 | __shared__ typename BlockReduce::TempStorage temp_storage; 73 | size_t nr_lines = a * c; 74 | TensorIterator<3> input_iterator{a, b, c}; 75 | TensorIterator<2> output_iterator{a, c}; 76 | for (size_t line = blockIdx.x; line < nr_lines; line += gridDim.x) { 77 | auto output_index = output_iterator.to_index(line); 78 | float result = ${REDUCE_INIT}; 79 | size_t nr_loops = (b + blockDim.x - 1) / blockDim.x; 80 | for (size_t i = 0; i < nr_loops; ++i) { 81 | size_t index = i * blockDim.x + threadIdx.x; 82 | float value = index < b ? input[input_iterator.to_offset({output_index[0], index, output_index[1]})] : (T)${REDUCE_INIT}; 83 | float aggregate = BlockReduce(temp_storage).Reduce((float)value, ${REDUCE_EXPR}, nr_threads); 84 | if (threadIdx.x == 0) { 85 | result = ${REDUCE_EXPR}(result, aggregate); 86 | } 87 | __syncthreads(); 88 | } 89 | if (threadIdx.x == 0) { 90 | output[line] = (T)result; 91 | } 92 | } 93 | } 94 | 95 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* input, T* output, size_t a, size_t b, size_t c) { 96 | static constexpr size_t nr_sms = 108; 97 | static constexpr size_t nr_threads_per_block = 1024; 98 | size_t nr_blocks = nr_sms; 99 | block_reduce_kernel<<>>(input, output, a, b, c); 100 | } 101 | """ 102 | source = substitude(kernel_template, { 103 | "DATA_TYPE": dtype, 104 | "KERNEL_NAME": kernel_name, 105 | "REDUCE_INIT": init, 106 | "REDUCE_EXPR": expr, 107 | }) 108 | kernel = nvcc.compile(CXXUnit(source=source)) 109 | @import_symbol(kernel, kernel_name) 110 | def entrance( 111 | stream: ctypes.c_void_p, 112 | input: ctypes.c_void_p, 113 | output: ctypes.c_void_p, 114 | a: ctypes.c_size_t, 115 | b: ctypes.c_size_t, 116 | c: ctypes.c_size_t, 117 | ): 118 | ... 119 | return entrance 120 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/rms_norm.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_rms_norm_kernel(name: str, dtype: str): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | 22 | #define CUDA_ASSERT(expr) \\ 23 | do { \\ 24 | auto _err = (expr); \\ 25 | if (_err != cudaSuccess) { \\ 26 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 27 | } \\ 28 | } while (0) 29 | 30 | using T = ${DATA_TYPE}; 31 | 32 | template 33 | struct TensorIterator { 34 | size_t shape[nr_ranks]; 35 | 36 | __device__ cuda::std::array to_index(size_t offset) const { 37 | cuda::std::array index; 38 | for (size_t i = 0; i < nr_ranks; ++i) { 39 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 40 | offset /= shape[nr_ranks-i-1]; 41 | } 42 | return index; 43 | } 44 | 45 | __device__ size_t to_offset(cuda::std::array index) const { 46 | size_t offset = 0; 47 | for (size_t i = 0; i < nr_ranks; ++i) { 48 | offset *= shape[i]; 49 | offset += index[i]; 50 | } 51 | return offset; 52 | } 53 | }; 54 | 55 | template 56 | __launch_bounds__(nr_threads) __global__ void block_rms_norm_kernel(T* input, T* weight, T* output, size_t a, size_t b, size_t c, float eps) { 57 | typedef cub::BlockReduce BlockReduce; 58 | __shared__ typename BlockReduce::TempStorage temp_storage; 59 | size_t nr_lines = a * c; 60 | TensorIterator<2> line_iterator{a, c}; 61 | TensorIterator<3> iterator{a, b, c}; 62 | for (size_t line = blockIdx.x; line < nr_lines; line += gridDim.x) { 63 | auto line_index = line_iterator.to_index(line); 64 | __shared__ float sum; 65 | __syncthreads(); 66 | if (threadIdx.x == 0) { 67 | sum = 0; 68 | } 69 | size_t nr_loops = (b + blockDim.x - 1) / blockDim.x; 70 | for (size_t i = 0; i < nr_loops; ++i) { 71 | size_t index = i * blockDim.x + threadIdx.x; 72 | float value = index < b ? input[iterator.to_offset({line_index[0], index, line_index[1]})] : (T)0; 73 | value = value * value; 74 | float aggregate = BlockReduce(temp_storage).Reduce((float)value, cub::Sum(), nr_threads); 75 | if (threadIdx.x == 0) { 76 | sum = cub::Sum()(sum, aggregate); 77 | } 78 | __syncthreads(); 79 | } 80 | float mean = sum / b; 81 | float rms = sqrt(mean + eps); 82 | for (size_t i = 0; i < nr_loops; ++i) { 83 | size_t index = i * blockDim.x + threadIdx.x; 84 | if (index < b) { 85 | float value = input[iterator.to_offset({line_index[0], index, line_index[1]})]; 86 | value = value / rms; 87 | value *= (float)weight[index]; 88 | output[iterator.to_offset({line_index[0], index, line_index[1]})] = value; 89 | } 90 | } 91 | } 92 | } 93 | 94 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* input, T* weight, T* output, size_t a, size_t b, size_t c, double eps) { 95 | static constexpr size_t nr_sms = 108; 96 | static constexpr size_t nr_threads_per_block = 1024; 97 | size_t nr_blocks = nr_sms; 98 | block_rms_norm_kernel<<>>(input, weight, output, a, b, c, (float)eps); 99 | } 100 | """ 101 | source = substitude(kernel_template, { 102 | "DATA_TYPE": dtype, 103 | "KERNEL_NAME": kernel_name, 104 | }) 105 | kernel = nvcc.compile(CXXUnit(source=source)) 106 | @import_symbol(kernel, kernel_name) 107 | def entrance( 108 | stream: ctypes.c_void_p, 109 | input: ctypes.c_void_p, 110 | weight: ctypes.c_void_p, 111 | output: ctypes.c_void_p, 112 | a: ctypes.c_size_t, 113 | b: ctypes.c_size_t, 114 | c: ctypes.c_size_t, 115 | eps: ctypes.c_double, 116 | ): 117 | ... 118 | return entrance 119 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/rope.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_rope_kernel(name: str, dtype: str): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | #define CUDA_ASSERT(expr) \\ 26 | do { \\ 27 | auto _err = (expr); \\ 28 | if (_err != cudaSuccess) { \\ 29 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 30 | } \\ 31 | } while (0) 32 | 33 | using T = ${DATA_TYPE}; 34 | 35 | template 36 | struct TensorIterator { 37 | size_t shape[nr_ranks]; 38 | 39 | __device__ cuda::std::array to_index(size_t offset) const { 40 | cuda::std::array index; 41 | for (size_t i = 0; i < nr_ranks; ++i) { 42 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 43 | offset /= shape[nr_ranks-i-1]; 44 | } 45 | return index; 46 | } 47 | 48 | __device__ size_t to_offset(cuda::std::array index) const { 49 | size_t offset = 0; 50 | for (size_t i = 0; i < nr_ranks; ++i) { 51 | offset *= shape[i]; 52 | offset += index[i]; 53 | } 54 | return offset; 55 | } 56 | }; 57 | 58 | __global__ void kernel(T* input, float* freqs_cos, float* freqs_sin, T* output, size_t batch_size, size_t seqlen, size_t nr_heads, size_t head_size) { 59 | size_t nr_complexes = batch_size * seqlen * nr_heads * (head_size/2); 60 | size_t stride = blockDim.x * gridDim.x; 61 | TensorIterator<4> input_iterator {batch_size, seqlen, nr_heads, head_size/2}; 62 | TensorIterator<2> freqs_iterator {seqlen, head_size/2}; 63 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_complexes; offset += stride) { 64 | auto input_index = input_iterator.to_index(offset); 65 | auto real = input[offset*2]; 66 | auto imag = input[offset*2+1]; 67 | auto freqs_offset = freqs_iterator.to_offset({input_index[1], input_index[3]}); 68 | auto freq_cos = freqs_cos[freqs_offset]; 69 | auto freq_sin = freqs_sin[freqs_offset]; 70 | output[offset*2] = real * freq_cos - imag * freq_sin; 71 | output[offset*2+1] = imag * freq_cos + real * freq_sin; 72 | } 73 | } 74 | 75 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* input, float* freqs_cos, float* freqs_sin, T* output, size_t batch_size, size_t seqlen, size_t nr_heads, size_t head_size) { 76 | size_t nr_complexes = batch_size * seqlen * nr_heads * (head_size/2); 77 | if (nr_complexes == 0) { 78 | return; 79 | } 80 | static constexpr size_t nr_sms = 108; 81 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_complexes + nr_sms - 1) / nr_sms)); 82 | size_t nr_blocks = (nr_complexes + nr_threads_per_block - 1) / nr_threads_per_block; 83 | kernel<<>>(input, freqs_cos, freqs_sin, output, batch_size, seqlen, nr_heads, head_size); 84 | CUDA_ASSERT(cudaGetLastError()); 85 | } 86 | """ 87 | source = substitude(kernel_template, { 88 | "DATA_TYPE": dtype, 89 | "KERNEL_NAME": kernel_name, 90 | }) 91 | kernel = nvcc.compile(CXXUnit(source=source)) 92 | @import_symbol(kernel, kernel_name) 93 | def entrance( 94 | stream: ctypes.c_void_p, 95 | input: ctypes.c_void_p, 96 | freqs_cos: ctypes.c_void_p, 97 | freqs_sin: ctypes.c_void_p, 98 | output: ctypes.c_void_p, 99 | batch_size: ctypes.c_size_t, 100 | seqlen: ctypes.c_size_t, 101 | nr_heads: ctypes.c_size_t, 102 | head_size: ctypes.c_size_t, 103 | ): 104 | ... 105 | return entrance 106 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/slice.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_slice_kernel(name: str, dtype: str): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | #define CUDA_ASSERT(expr) \\ 26 | do { \\ 27 | auto _err = (expr); \\ 28 | if (_err != cudaSuccess) { \\ 29 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 30 | } \\ 31 | } while (0) 32 | 33 | using T = ${DATA_TYPE}; 34 | 35 | template 36 | struct TensorIterator { 37 | size_t shape[nr_ranks]; 38 | 39 | __device__ cuda::std::array to_index(size_t offset) const { 40 | cuda::std::array index; 41 | for (size_t i = 0; i < nr_ranks; ++i) { 42 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 43 | offset /= shape[nr_ranks-i-1]; 44 | } 45 | return index; 46 | } 47 | 48 | __device__ size_t to_offset(cuda::std::array index) const { 49 | size_t offset = 0; 50 | for (size_t i = 0; i < nr_ranks; ++i) { 51 | offset *= shape[i]; 52 | offset += index[i]; 53 | } 54 | return offset; 55 | } 56 | }; 57 | 58 | __global__ void kernel(T* input, T* output, size_t a, size_t b, size_t c, size_t start, size_t stop) { 59 | size_t nr_elements = a * c * (stop - start); 60 | size_t stride = blockDim.x * gridDim.x; 61 | TensorIterator<3> input_iterator; 62 | input_iterator.shape[0] = a; 63 | input_iterator.shape[1] = b; 64 | input_iterator.shape[2] = c; 65 | TensorIterator<3> output_iterator; 66 | output_iterator.shape[0] = a; 67 | output_iterator.shape[1] = stop - start; 68 | output_iterator.shape[2] = c; 69 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += stride) { 70 | auto index = output_iterator.to_index(offset); 71 | index[1] += start; 72 | output[offset] = input[input_iterator.to_offset(index)]; 73 | } 74 | } 75 | 76 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* input, T* output, size_t a, size_t b, size_t c, size_t start, size_t stop) { 77 | size_t nr_elements = a * c * (stop - start); 78 | if (nr_elements == 0) { 79 | return; 80 | } 81 | static constexpr size_t nr_sms = 108; 82 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 83 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 84 | kernel<<>>(input, output, a, b, c, start, stop); 85 | CUDA_ASSERT(cudaGetLastError()); 86 | } 87 | """ 88 | source = substitude(kernel_template, { 89 | "DATA_TYPE": dtype, 90 | "KERNEL_NAME": kernel_name, 91 | }) 92 | kernel = nvcc.compile(CXXUnit(source=source)) 93 | @import_symbol(kernel, kernel_name) 94 | def entrance( 95 | stream: ctypes.c_void_p, 96 | input: ctypes.c_void_p, 97 | output: ctypes.c_void_p, 98 | a: ctypes.c_size_t, 99 | b: ctypes.c_size_t, 100 | c: ctypes.c_size_t, 101 | start: ctypes.c_size_t, 102 | stop: ctypes.c_size_t, 103 | ): 104 | ... 105 | return entrance 106 | 107 | 108 | @cached() 109 | def generate_slice_set_kernel(name: str, dtype: str): 110 | kernel_name = f"minit_{name}" 111 | kernel_template =\ 112 | """ 113 | #include 114 | #include 115 | #include 116 | #include 117 | #include 118 | #include 119 | #include 120 | #include 121 | 122 | 123 | #define CUDA_ASSERT(expr) \\ 124 | do { \\ 125 | auto _err = (expr); \\ 126 | if (_err != cudaSuccess) { \\ 127 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 128 | } \\ 129 | } while (0) 130 | 131 | using T = ${DATA_TYPE}; 132 | 133 | template 134 | struct TensorIterator { 135 | size_t shape[nr_ranks]; 136 | 137 | __device__ cuda::std::array to_index(size_t offset) const { 138 | cuda::std::array index; 139 | for (size_t i = 0; i < nr_ranks; ++i) { 140 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 141 | offset /= shape[nr_ranks-i-1]; 142 | } 143 | return index; 144 | } 145 | 146 | __device__ size_t to_offset(cuda::std::array index) const { 147 | size_t offset = 0; 148 | for (size_t i = 0; i < nr_ranks; ++i) { 149 | offset *= shape[i]; 150 | offset += index[i]; 151 | } 152 | return offset; 153 | } 154 | }; 155 | 156 | __global__ void kernel(T* input, T* output, size_t a, size_t b, size_t c, size_t start, size_t stop) { 157 | size_t nr_elements = a * c * (stop - start); 158 | size_t stride = blockDim.x * gridDim.x; 159 | TensorIterator<3> input_iterator; 160 | input_iterator.shape[0] = a; 161 | input_iterator.shape[1] = stop - start; 162 | input_iterator.shape[2] = c; 163 | TensorIterator<3> output_iterator; 164 | output_iterator.shape[0] = a; 165 | output_iterator.shape[1] = b; 166 | output_iterator.shape[2] = c; 167 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += stride) { 168 | auto index = input_iterator.to_index(offset); 169 | index[1] += start; 170 | output[output_iterator.to_offset(index)] = input[offset]; 171 | } 172 | } 173 | 174 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* input, T* output, size_t a, size_t b, size_t c, size_t start, size_t stop) { 175 | size_t nr_elements = a * c * (stop - start); 176 | if (nr_elements == 0) { 177 | return; 178 | } 179 | static constexpr size_t nr_sms = 108; 180 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 181 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 182 | kernel<<>>(input, output, a, b, c, start, stop); 183 | CUDA_ASSERT(cudaGetLastError()); 184 | } 185 | """ 186 | source = substitude(kernel_template, { 187 | "DATA_TYPE": dtype, 188 | "KERNEL_NAME": kernel_name, 189 | }) 190 | kernel = nvcc.compile(CXXUnit(source=source)) 191 | @import_symbol(kernel, kernel_name) 192 | def entrance( 193 | stream: ctypes.c_void_p, 194 | input: ctypes.c_void_p, 195 | output: ctypes.c_void_p, 196 | a: ctypes.c_size_t, 197 | b: ctypes.c_size_t, 198 | c: ctypes.c_size_t, 199 | start: ctypes.c_size_t, 200 | stop: ctypes.c_size_t, 201 | ): 202 | ... 203 | return entrance 204 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/softmax.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_softmax_kernel(name: str, dtype: str): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | 22 | #define CUDA_ASSERT(expr) \\ 23 | do { \\ 24 | auto _err = (expr); \\ 25 | if (_err != cudaSuccess) { \\ 26 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 27 | } \\ 28 | } while (0) 29 | 30 | using T = ${DATA_TYPE}; 31 | 32 | template 33 | struct TensorIterator { 34 | size_t shape[nr_ranks]; 35 | 36 | __device__ cuda::std::array to_index(size_t offset) const { 37 | cuda::std::array index; 38 | for (size_t i = 0; i < nr_ranks; ++i) { 39 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 40 | offset /= shape[nr_ranks-i-1]; 41 | } 42 | return index; 43 | } 44 | 45 | __device__ size_t to_offset(cuda::std::array index) const { 46 | size_t offset = 0; 47 | for (size_t i = 0; i < nr_ranks; ++i) { 48 | offset *= shape[i]; 49 | offset += index[i]; 50 | } 51 | return offset; 52 | } 53 | }; 54 | 55 | template 56 | __launch_bounds__(nr_threads) __global__ void softmax_kernel(T* input, T* output, size_t a, size_t b, size_t c) { 57 | typedef cub::BlockReduce BlockReduce; 58 | __shared__ typename BlockReduce::TempStorage temp_storage; 59 | size_t nr_lines = a * c; 60 | TensorIterator<3> iterator{a, b, c}; 61 | TensorIterator<2> line_iterator{a, c}; 62 | for (size_t line = blockIdx.x; line < nr_lines; line += gridDim.x) { 63 | auto line_index = line_iterator.to_index(line); 64 | __shared__ float max; 65 | __shared__ float sum; 66 | __syncthreads(); 67 | if (threadIdx.x == 0) { 68 | max = -INFINITY; 69 | sum = 0; 70 | } 71 | size_t nr_loops = (b + blockDim.x - 1) / blockDim.x; 72 | for (size_t i = 0; i < nr_loops; ++i) { 73 | size_t index = i * blockDim.x + threadIdx.x; 74 | float value = index < b ? (float)input[iterator.to_offset({line_index[0], index, line_index[1]})] : -INFINITY; 75 | float aggregate = BlockReduce(temp_storage).Reduce((float)value, cub::Max(), nr_threads); 76 | if (threadIdx.x == 0) { 77 | max = cub::Max()(max, aggregate); 78 | } 79 | __syncthreads(); 80 | } 81 | for (size_t i = 0; i < nr_loops; ++i) { 82 | size_t index = i * blockDim.x + threadIdx.x; 83 | float value = 0; 84 | if (index < b) { 85 | value = (float)input[iterator.to_offset({line_index[0], index, line_index[1]})]; 86 | value -= max; 87 | value = expf(value); 88 | } 89 | float aggregate = BlockReduce(temp_storage).Reduce((float)value, cub::Sum(), nr_threads); 90 | if (threadIdx.x == 0) { 91 | sum = cub::Sum()(sum, aggregate); 92 | } 93 | __syncthreads(); 94 | } 95 | for (size_t i = 0; i < nr_loops; ++i) { 96 | size_t index = i * blockDim.x + threadIdx.x; 97 | if (index < b) { 98 | float value = input[iterator.to_offset({line_index[0], index, line_index[1]})]; 99 | value -= max; 100 | value = expf(value); 101 | value = value / sum; 102 | output[iterator.to_offset({line_index[0], index, line_index[1]})] = (T)value; 103 | } 104 | __syncthreads(); 105 | } 106 | } 107 | } 108 | 109 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* input, T* output, size_t a, size_t b, size_t c) { 110 | static constexpr size_t nr_sms = 108; 111 | static constexpr size_t nr_threads_per_block = 1024; 112 | size_t nr_blocks = nr_sms; 113 | softmax_kernel<<>>(input, output, a, b, c); 114 | } 115 | """ 116 | source = substitude(kernel_template, { 117 | "DATA_TYPE": dtype, 118 | "KERNEL_NAME": kernel_name, 119 | }) 120 | kernel = nvcc.compile(CXXUnit(source=source)) 121 | @import_symbol(kernel, kernel_name) 122 | def entrance( 123 | stream: ctypes.c_void_p, 124 | input: ctypes.c_void_p, 125 | output: ctypes.c_void_p, 126 | a: ctypes.c_size_t, 127 | b: ctypes.c_size_t, 128 | c: ctypes.c_size_t, 129 | ): 130 | ... 131 | return entrance 132 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/transpose.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_transpose_kernel(name: str, dtype: str): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | #define CUDA_ASSERT(expr) \\ 22 | do { \\ 23 | auto _err = (expr); \\ 24 | if (_err != cudaSuccess) { \\ 25 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 26 | } \\ 27 | } while (0) 28 | 29 | using T = ${DATA_TYPE}; 30 | 31 | template 32 | struct TensorIterator { 33 | size_t shape[nr_ranks]; 34 | 35 | __device__ cuda::std::array to_index(size_t offset) const { 36 | cuda::std::array index; 37 | for (size_t i = 0; i < nr_ranks; ++i) { 38 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 39 | offset /= shape[nr_ranks-i-1]; 40 | } 41 | return index; 42 | } 43 | 44 | __device__ size_t to_offset(cuda::std::array index) const { 45 | size_t offset = 0; 46 | for (size_t i = 0; i < nr_ranks; ++i) { 47 | offset *= shape[i]; 48 | offset += index[i]; 49 | } 50 | return offset; 51 | } 52 | }; 53 | 54 | // input[a, b, c, d, e] -> output[a, d, c, b, e] 55 | __global__ void kernel(T* input, T* output, size_t a, size_t b, size_t c, size_t d, size_t e) { 56 | size_t stride = blockDim.x * gridDim.x; 57 | size_t nr_elements = a * b * c * d * e; 58 | TensorIterator<5> input_iterator{a, b, c, d, e}; 59 | TensorIterator<5> output_iterator{a, d, c, b, e}; 60 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += stride) { 61 | size_t output_offset = offset; 62 | auto index = output_iterator.to_index(output_offset); 63 | cuda::std::swap(index[1], index[3]); 64 | size_t input_offset = input_iterator.to_offset(index); 65 | output[output_offset] = input[input_offset]; 66 | } 67 | } 68 | 69 | extern "C" void ${KERNEL_NAME} (cudaStream_t stream, T* input, T* output, size_t a, size_t b, size_t c, size_t d, size_t e) { 70 | static constexpr size_t nr_sms = 64; 71 | size_t nr_elements = a * b * c * d * e; 72 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 73 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 74 | kernel<<>>(input, output, a, b, c, d, e); 75 | } 76 | """ 77 | source = substitude(kernel_template, { 78 | "DATA_TYPE": dtype, 79 | "KERNEL_NAME": kernel_name, 80 | }) 81 | kernel = nvcc.compile(CXXUnit(source=source)) 82 | @import_symbol(kernel, kernel_name) 83 | def entrance( 84 | stream: ctypes.c_void_p, 85 | input: ctypes.c_void_p, 86 | output: ctypes.c_void_p, 87 | a: ctypes.c_size_t, 88 | b: ctypes.c_size_t, 89 | c: ctypes.c_size_t, 90 | d: ctypes.c_size_t, 91 | e: ctypes.c_size_t, 92 | ): 93 | ... 94 | return entrance 95 | -------------------------------------------------------------------------------- /minit/cuda/kernel/cuda/triangle.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | 4 | from ....core.cache import cached 5 | 6 | from ....compiler.template import substitude 7 | from ....compiler.cxx import CXXUnit, import_symbol 8 | from ...compiler import nvcc 9 | 10 | @cached() 11 | def generate_triangle_kernel(name: str, predicate: str, dtype: str): 12 | kernel_name = f"minit_{name}" 13 | kernel_template =\ 14 | """ 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | #define CUDA_ASSERT(expr) \\ 22 | do { \\ 23 | auto _err = (expr); \\ 24 | if (_err != cudaSuccess) { \\ 25 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 26 | } \\ 27 | } while (0) 28 | 29 | using T = ${DATA_TYPE}; 30 | 31 | template 32 | struct TensorIterator { 33 | size_t shape[nr_ranks]; 34 | 35 | __device__ cuda::std::array to_index(size_t offset) const { 36 | cuda::std::array index; 37 | for (size_t i = 0; i < nr_ranks; ++i) { 38 | index[nr_ranks-i-1] = offset % shape[nr_ranks-i-1]; 39 | offset /= shape[nr_ranks-i-1]; 40 | } 41 | return index; 42 | } 43 | 44 | __device__ size_t to_offset(cuda::std::array index) const { 45 | size_t offset = 0; 46 | for (size_t i = 0; i < nr_ranks; ++i) { 47 | offset *= shape[i]; 48 | offset += index[i]; 49 | } 50 | return offset; 51 | } 52 | }; 53 | 54 | // input[a, b, c, d, e] -> output[a, d, c, b, e] 55 | __global__ void kernel(T* input, T* output, size_t diagonal, size_t a, size_t b, size_t c, size_t d) { 56 | size_t stride = blockDim.x * gridDim.x; 57 | size_t nr_elements = a * b * c * d; 58 | TensorIterator<4> iterator{a, b, c, d}; 59 | for (size_t offset = blockIdx.x * blockDim.x + threadIdx.x; offset < nr_elements; offset += stride) { 60 | auto index = iterator.to_index(offset); 61 | if (${PREDICATE}) { 62 | output[offset] = input[offset]; 63 | } else { 64 | output[offset] = 0.0; 65 | } 66 | } 67 | } 68 | 69 | extern "C" void ${KERNEL_NAME}(cudaStream_t stream, T* input, T* output, size_t diagonal, size_t a, size_t b, size_t c, size_t d) { 70 | static constexpr size_t nr_sms = 112; 71 | size_t nr_elements = a * b * c * d; 72 | size_t nr_threads_per_block = std::min((size_t)1024, (size_t)((nr_elements + nr_sms - 1) / nr_sms)); 73 | size_t nr_blocks = (nr_elements + nr_threads_per_block - 1) / nr_threads_per_block; 74 | kernel<<>>(input, output, diagonal, a, b, c, d); 75 | } 76 | """ 77 | source = substitude(kernel_template, { 78 | "DATA_TYPE": dtype, 79 | "KERNEL_NAME": kernel_name, 80 | "PREDICATE": predicate, 81 | }) 82 | kernel = nvcc.compile(CXXUnit(source=source)) 83 | @import_symbol(kernel, kernel_name) 84 | def entrance( 85 | stream: ctypes.c_void_p, 86 | input: ctypes.c_void_p, 87 | output: ctypes.c_void_p, 88 | diagonal: ctypes.c_size_t, 89 | a: ctypes.c_size_t, 90 | b: ctypes.c_size_t, 91 | c: ctypes.c_size_t, 92 | d: ctypes.c_size_t, 93 | ): 94 | ... 95 | return entrance 96 | -------------------------------------------------------------------------------- /minit/cuda/kernel/utils.py: -------------------------------------------------------------------------------- 1 | CUDA_DTYPE_MAPPING = { 2 | "bfloat16": "__nv_bfloat16", 3 | "float16": "__half", 4 | "float32": "float", 5 | "float64": "double", 6 | "int8": "std::int8_t", 7 | "int16": "std::int16_t", 8 | "int32": "std::int32_t", 9 | "int64": "std::int64_t", 10 | "uint8": "std::uint8_t", 11 | "uint16": "std::uint16_t", 12 | "uint32": "std::uint32_t", 13 | "uint64": "std::uint64_t", 14 | } 15 | 16 | def get_cuda_dtype(name: str): 17 | return CUDA_DTYPE_MAPPING[name] 18 | -------------------------------------------------------------------------------- /minit/cuda/lib/cuda_runtime.py: -------------------------------------------------------------------------------- 1 | from ctypes import CDLL 2 | import functools 3 | import os 4 | from minit.compiler.cxx import CXXLibrary 5 | from ...core.cache import cached 6 | from ..toolkit import get_cuda_home 7 | 8 | 9 | @cached() 10 | def load_cuda_runtime(): 11 | path = os.path.join(get_cuda_home(), "lib64", "libcudart.so") 12 | library = CDLL(path) 13 | return CXXLibrary(library) 14 | -------------------------------------------------------------------------------- /minit/cuda/tensor.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from typing import Optional, Tuple 3 | 4 | from ..core.scalar import ScalarTensor 5 | 6 | from ..core.shape import to_immediate_shape, to_symbolic_shape 7 | from ..core.dtype import dtype_info 8 | from ..core.tensor import Tensor 9 | from .allocator import CUDAMemory, copy_cuda_memory, sync_cuda 10 | 11 | import numpy 12 | import nvtx 13 | 14 | 15 | class CUDATensor(Tensor): 16 | __slots__ = [ 17 | "_memory", 18 | "_shape", 19 | "_item", 20 | "_dtype", 21 | ] 22 | 23 | _memory: Optional[CUDAMemory] 24 | _shape: Tuple[int, ...] 25 | _item: Optional[Number] 26 | _dtype: str 27 | 28 | def __init__(self, shape: Tuple[int, ...], dtype: str) -> None: 29 | self._memory = None 30 | self._item = None 31 | self._shape = shape 32 | self._dtype = dtype 33 | 34 | @property 35 | def shape(self) -> Tuple[Tensor, ...]: 36 | return tuple([ScalarTensor(dim, (), "int32") for dim in self._shape]) 37 | 38 | @property 39 | def dtype(self) -> str: 40 | return self._dtype 41 | 42 | @property 43 | def device(self): 44 | return "cuda" 45 | 46 | @property 47 | def data_ptr(self) -> int: 48 | return self.memory._pointer 49 | 50 | @property 51 | def memory(self) -> CUDAMemory: 52 | if self._memory is None and self._item is not None: 53 | self._memory = CUDAMemory(dtype_info(self.dtype).size_in_bytes) 54 | self.copy_from_numpy(numpy.full((), self._item, getattr(numpy, self.dtype))) 55 | return self._memory 56 | 57 | @staticmethod 58 | def allocate(shape: Tuple[int, ...], dtype: str) -> "CUDATensor": 59 | size = dtype_info(dtype).size_in_bytes 60 | for dim in shape: 61 | assert isinstance(dim, int) 62 | size *= dim 63 | memory = CUDAMemory(size) 64 | result = CUDATensor(shape, dtype) 65 | result._memory = memory 66 | return result 67 | 68 | @staticmethod 69 | def wrap(shape: Tuple[int, ...], dtype: str, memory: CUDAMemory) -> "CUDATensor": 70 | size = dtype_info(dtype).size_in_bytes 71 | for dim in shape: 72 | assert isinstance(dim, int) 73 | size *= dim 74 | assert memory._size == size 75 | result = CUDATensor(shape, dtype) 76 | result._memory = memory 77 | return result 78 | 79 | @staticmethod 80 | def from_numpy(array: numpy.ndarray): 81 | dtype = str(array.dtype).split(".")[-1] 82 | device_array = CUDATensor.allocate(array.shape, dtype) 83 | device_array.copy_from_numpy(array) 84 | return device_array 85 | 86 | @staticmethod 87 | def from_item(item: Number, dtype: str): 88 | result = CUDATensor((), dtype) 89 | result._item = item 90 | return result 91 | 92 | def copy_from_numpy(self, array: numpy.ndarray): 93 | assert array.shape == self._shape, f"{array.shape} vs {self._shape}" 94 | import torch 95 | if isinstance(array, torch.Tensor): 96 | array: torch.Tensor 97 | if not array.is_cpu: 98 | array = array.cpu() 99 | array = array.contiguous() 100 | pointer = array.data_ptr() 101 | else: 102 | array: numpy.ndarray 103 | array = numpy.ascontiguousarray(array) 104 | pointer, _read_only_flag = array.__array_interface__['data'] 105 | dtype = str(array.dtype).split(".")[-1] 106 | assert dtype == self.dtype 107 | size = dtype_info(dtype).size_in_bytes 108 | for dim in array.shape: 109 | size *= dim 110 | sync_cuda() 111 | copy_cuda_memory(self.data_ptr, pointer, size) 112 | 113 | def numpy(self): 114 | host_data = numpy.full(self._shape, 0, self.dtype) 115 | host_data = numpy.ascontiguousarray(host_data) 116 | pointer, _read_only_flag = host_data.__array_interface__['data'] 117 | size = dtype_info(self.dtype).size_in_bytes 118 | for dim in self._shape: 119 | size *= dim 120 | copy_cuda_memory(pointer, self.data_ptr, size) 121 | sync_cuda() 122 | return host_data 123 | 124 | def item(self): 125 | if self._item is not None: 126 | return dtype_info(self.dtype).python_type(self._item) 127 | else: 128 | return self.numpy().item() 129 | 130 | def type(self): 131 | return CUDATensor 132 | -------------------------------------------------------------------------------- /minit/cuda/toolkit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | _cuda_home = None 5 | 6 | 7 | def get_cuda_home(): 8 | global _cuda_home 9 | if _cuda_home is None: 10 | _cuda_home = os.getenv("CUDA_HOME", "/usr/local/cuda") 11 | return _cuda_home 12 | 13 | 14 | def find_cuda_library_directory(): 15 | return os.path.join(get_cuda_home(), "lib64") 16 | 17 | 18 | def find_cuda_include_directory(): 19 | return os.path.join(get_cuda_home(), "include") 20 | 21 | 22 | def find_cuda_libraries(): 23 | libraries = [ 24 | os.path.join(find_cuda_library_directory(), "libcudart.so"), 25 | ] 26 | return libraries 27 | -------------------------------------------------------------------------------- /minit/distributed/communicator.py: -------------------------------------------------------------------------------- 1 | from .operator import ( 2 | DistributedSend, 3 | DistributedRecv, 4 | DistributedBroadcast, 5 | DistributedAllGather, 6 | DistributedReduceScatter, 7 | DistributedScatter, 8 | DistributedAllToAll, 9 | DistributedAllReduce, 10 | ) 11 | from ..core.dispatch import dispatch 12 | from ..core.tensor import Tensor 13 | 14 | 15 | class DistributedCommunicator: 16 | version: Tensor 17 | 18 | def __init__(self, version: Tensor) -> None: 19 | self.version = version 20 | 21 | def send(self, x: Tensor, target: int): 22 | (self.version,) = dispatch(DistributedSend(target), self.version, x) 23 | 24 | def recv(self, source: int): 25 | (self.version, z) = dispatch(DistributedRecv(source), self.version) 26 | return z 27 | 28 | def broadcast(self, x: Tensor, source: int): 29 | (self.version, z) = dispatch(DistributedBroadcast(source), self.version, x) 30 | return z 31 | 32 | def all_reduce(self, x: Tensor): 33 | (self.version, z) = dispatch(DistributedAllReduce(), self.version, x) 34 | return z 35 | 36 | def all_gather(self, x: Tensor, axis: int): 37 | (self.version, z) = dispatch(DistributedAllGather(axis), self.version, x) 38 | return z 39 | 40 | def reduce_scatter(self, x: Tensor, axis: int): 41 | (self.version, z) = dispatch(DistributedReduceScatter(axis), self.version, x) 42 | return z 43 | 44 | def scatter(self, x: Tensor, source: int, axis: int): 45 | (self.version, z) = dispatch(DistributedScatter(source, axis), self.version, x) 46 | return z 47 | 48 | def all_to_all(self, x: Tensor, gather_axis: int, scatter_axis: int): 49 | (self.version, z) = dispatch(DistributedAllToAll(gather_axis, scatter_axis), self.version, x) 50 | return z 51 | -------------------------------------------------------------------------------- /minit/distributed/group.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable 3 | 4 | 5 | @dataclass(frozen=True) 6 | class DistributedGroup: 7 | size: int 8 | rank: int 9 | 10 | 11 | _WORLD = None 12 | 13 | 14 | def get_world() -> DistributedGroup: 15 | return _WORLD 16 | 17 | 18 | def initialize_world(rank: int, size: int): 19 | global _WORLD 20 | assert _WORLD is None 21 | _WORLD = DistributedGroup(size, rank) 22 | -------------------------------------------------------------------------------- /minit/distributed/operator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from ..core.operator import Operator 4 | 5 | 6 | @dataclass 7 | class DistributedSend(Operator): 8 | target: int 9 | 10 | @dataclass 11 | class DistributedRecv(Operator): 12 | source: int 13 | 14 | @dataclass 15 | class DistributedBroadcast(Operator): 16 | source: int 17 | 18 | @dataclass 19 | class DistributedAllReduce(Operator): 20 | ... 21 | 22 | @dataclass 23 | class DistributedAllGather(Operator): 24 | axis: int 25 | 26 | @dataclass 27 | class DistributedReduceScatter(Operator): 28 | axis: int 29 | 30 | @dataclass 31 | class DistributedScatter(Operator): 32 | source: int 33 | axis: int 34 | 35 | @dataclass 36 | class DistributedAllToAll(Operator): 37 | gather_axis: int 38 | scatter_axis: int 39 | -------------------------------------------------------------------------------- /minit/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .arith import * 2 | from .generate import * 3 | from .index import * 4 | from .linalg import * 5 | from .reduce import * 6 | from .shape import * 7 | from .special import * 8 | from .einops import * 9 | from .control_flow import * 10 | -------------------------------------------------------------------------------- /minit/functional/arith.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from typing import Optional, Union 3 | from ..core.tensor import Tensor 4 | from ..core.dispatch import dispatch 5 | from ..operator.arith import Add, And, Cast, Constant, Cosine, Equal, Exponential, FloorDivide, GreaterThan, Logarithm, Modulo, Not, Power, Select, SelectMax, Sine, Subtract, Multiply, Divide 6 | from .utils import _broadcast_constant 7 | 8 | 9 | def add(x: Tensor, y: Tensor): 10 | x, y = _broadcast_constant(x, y) 11 | (z,) = dispatch(Add(), x, y) 12 | return z 13 | 14 | 15 | def subtract(x: Tensor, y: Tensor): 16 | x, y = _broadcast_constant(x, y) 17 | (z,) = dispatch(Subtract(), x, y) 18 | return z 19 | 20 | 21 | def multiply(x: Tensor, y: Tensor): 22 | x, y = _broadcast_constant(x, y) 23 | (z,) = dispatch(Multiply(), x, y) 24 | return z 25 | 26 | 27 | def divide(x: Tensor, y: Tensor): 28 | x, y = _broadcast_constant(x, y) 29 | (z,) = dispatch(Divide(), x, y) 30 | return z 31 | 32 | 33 | def floor_divide(x: Tensor, y: Tensor): 34 | x, y = _broadcast_constant(x, y) 35 | (z,) = dispatch(FloorDivide(), x, y) 36 | return z 37 | 38 | 39 | def modulo(x: Tensor, y: Tensor): 40 | x, y = _broadcast_constant(x, y) 41 | (z,) = dispatch(Modulo(), x, y) 42 | return z 43 | 44 | 45 | def power(base: Tensor, exponent: Tensor): 46 | base, exponent = _broadcast_constant(base, exponent) 47 | (z,) = dispatch(Power(), base, exponent) 48 | return z 49 | 50 | 51 | def exponential(x: Tensor): 52 | (z,) = dispatch(Exponential(), x) 53 | return z 54 | 55 | 56 | def logarithm(x: Tensor): 57 | (z,) = dispatch(Logarithm(), x) 58 | return z 59 | 60 | 61 | def square(x: Tensor): 62 | from .generate import fill 63 | return power(x, fill(2, x.shape, x.dtype)) 64 | 65 | 66 | def square_root(x: Tensor): 67 | from .generate import fill 68 | return power(x, fill(1/2, x.shape, x.dtype)) 69 | 70 | 71 | def sine(x: Tensor): 72 | (z,) = dispatch(Sine(), x) 73 | return z 74 | 75 | 76 | def cosine(x: Tensor): 77 | (z,) = dispatch(Cosine(), x) 78 | return z 79 | 80 | 81 | def constant(x: Number, dtype: str): 82 | opr = Constant(value=x, dtype=dtype) 83 | (z,) = dispatch(opr) 84 | assert isinstance(z, Tensor) 85 | return z 86 | 87 | 88 | def cast(x: Tensor, dtype: str): 89 | if x.dtype == dtype: 90 | return x 91 | (z,) = dispatch(Cast(dtype), x) 92 | return z 93 | 94 | 95 | def greater_than(x: Tensor, y: Tensor): 96 | x, y = _broadcast_constant(x, y) 97 | (z,) = dispatch(GreaterThan(), x, y) 98 | return z 99 | 100 | 101 | def equal(x: Tensor, y: Tensor): 102 | x, y = _broadcast_constant(x, y) 103 | (z,) = dispatch(Equal(), x, y) 104 | return z 105 | 106 | 107 | def not_equal(x: Tensor, y: Tensor): 108 | return logical_not(equal(x, y)) 109 | 110 | 111 | def less_than(x: Tensor, y: Tensor): 112 | x, y = _broadcast_constant(x, y) 113 | return logical_and(logical_not(greater_than(x, y)), logical_not(equal(x, y))) 114 | 115 | 116 | def logical_and(x: Tensor, y: Tensor): 117 | x, y = _broadcast_constant(x, y) 118 | (z,) = dispatch(And(), x, y) 119 | return z 120 | 121 | 122 | def logical_not(x: Tensor): 123 | (x,) = _broadcast_constant(x) 124 | (z,) = dispatch(Not(), x) 125 | return z 126 | 127 | 128 | def logical_or(x: Tensor, y: Tensor): 129 | x, y = _broadcast_constant(x, y) 130 | return logical_and(logical_not(x), logical_not(y)) 131 | 132 | 133 | def greater_equal(x: Tensor, y: Tensor): 134 | x, y = _broadcast_constant(x, y) 135 | return logical_or(greater_than(x, y), equal(x, y)) 136 | 137 | 138 | def less_equal(x: Tensor, y: Tensor): 139 | x, y = _broadcast_constant(x, y) 140 | return logical_not(greater_than(x, y)) 141 | 142 | 143 | def select_max(x: Tensor, y: Tensor): 144 | x, y = _broadcast_constant(x, y) 145 | (z,) = dispatch(SelectMax(), x, y) 146 | return z 147 | 148 | 149 | def select(index: Tensor, *args: Tensor, dtype: Optional[str]=None): 150 | args = _broadcast_constant(*args, shape=index.shape, dtype=dtype) 151 | (z,) = dispatch(Select(), index, *args) 152 | return z 153 | 154 | 155 | def where(condition: Tensor, true_value: Tensor, false_value: Tensor, *, dtype: Optional[str]=None): 156 | assert condition.dtype == "bool" 157 | return select(condition, false_value, true_value, dtype=dtype) 158 | -------------------------------------------------------------------------------- /minit/functional/control_flow.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from typing import Callable, Sequence, Tuple, Union 3 | 4 | from .arith import constant 5 | 6 | from ..core.scalar import ScalarTensor 7 | from ..core.meta import MetaTensor 8 | from ..core.tensor import Tensor 9 | from ..core.dispatch import dispatch, register_dispatch 10 | from ..operator.control_flow import ForLoop, WhileLoop, IfBlock, Block 11 | from .utils import _broadcast_constant 12 | from ..trace.function import trace_function 13 | from ..trace.executor import TraceGraphExecutor 14 | 15 | 16 | def _make_meta_input(arg: Tensor): 17 | return MetaTensor(tuple(MetaTensor((), dim.dtype) for dim in arg.shape), arg.dtype) 18 | 19 | 20 | def _make_meta_inputs(*args: Tensor): 21 | return [_make_meta_input(arg) for arg in args] 22 | 23 | 24 | def for_loop(count: Tensor, variables: Sequence[Tensor], body: Block) -> Tuple[Tensor, ...]: 25 | graph = trace_function(body, _make_meta_inputs(count, *variables)) 26 | results = dispatch(ForLoop(graph), count, *variables) 27 | return results 28 | 29 | 30 | def while_loop(condition: Tensor, variables: Sequence[Tensor], body: Block) -> Tuple[Tensor, ...]: 31 | graph = trace_function(body, _make_meta_inputs(condition, *variables)) 32 | results = dispatch(WhileLoop(graph), condition, *variables) 33 | return results 34 | 35 | 36 | def if_block(condition: Tensor, variables: Sequence[Tensor], true_body: Block, false_body: Block) -> Tuple[Tensor, ...]: 37 | true_graph = trace_function(true_body, _make_meta_inputs(*variables)) 38 | false_graph = trace_function(false_body, _make_meta_inputs(*variables)) 39 | results = dispatch(IfBlock(true_graph, false_graph), condition, *variables) 40 | return results 41 | 42 | 43 | @register_dispatch() 44 | def dispatch_for_loop(op: ForLoop, count: MetaTensor, *args: Tensor): 45 | outputs = TraceGraphExecutor(op.body)(count, *args) 46 | return outputs 47 | 48 | 49 | @register_dispatch() 50 | def dispatch_for_loop(op: ForLoop, count: ScalarTensor, *args: Tensor): 51 | body_executor = TraceGraphExecutor(op.body) 52 | for i in count.item(): 53 | outputs = tuple(body_executor(constant(i, count.dtype), *args)) 54 | args = [*args[-len(outputs):], *outputs] 55 | return outputs 56 | -------------------------------------------------------------------------------- /minit/functional/einops.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, List, Optional, Tuple, Union 3 | import nvtx 4 | 5 | from ..core.meta import MetaTensor 6 | 7 | from .linalg import batch_matrix_multiply, matrix_multiply 8 | from ..graph import SubGraph 9 | from ..core.tensor import Tensor 10 | 11 | 12 | __all__ = [ 13 | "einsum", 14 | "rearrange", 15 | ] 16 | 17 | ShapeTerm = Union[List[str], str] 18 | ShapeSpec = List[ShapeTerm] 19 | ShapeSpecs = List[ShapeSpec] 20 | 21 | 22 | def parse_equation_term(term: str): 23 | dims = [] 24 | i = 0 25 | while i < len(term): 26 | if term[i] =="(": 27 | i += 1 28 | subdims = [] 29 | while term[i] != ")": 30 | assert term[i].isalpha() 31 | subdims.append(term[i]) 32 | i += 1 33 | dims.append(subdims) 34 | elif term[i] == ".": 35 | assert term[i:i+3] == "..." 36 | i += 2 37 | dims.append("...") 38 | else: 39 | assert term[i].isalpha() 40 | dims.append(term[i]) 41 | i += 1 42 | return dims 43 | 44 | 45 | def serialize_spec(spec: ShapeSpec): 46 | term = "" 47 | for dim in spec: 48 | if isinstance(dim, list): 49 | term += "(" 50 | for subdim in dim: 51 | term += subdim 52 | term += ")" 53 | else: 54 | term += dim 55 | return term 56 | 57 | 58 | def flatten_spec(spec: ShapeSpec): 59 | flatten_spec = [] 60 | for dim in spec: 61 | if isinstance(dim, list): 62 | dim: List[str] 63 | flatten_spec.extend(dim) 64 | else: 65 | dim: str 66 | flatten_spec.append(dim) 67 | return flatten_spec 68 | 69 | 70 | def parse_equation(equation: str) -> List[ShapeSpecs]: 71 | [args, output] = equation.split("->") 72 | args = args.split(",") 73 | input_specs = [parse_equation_term(arg) for arg in args] 74 | output_spec = parse_equation_term(output) 75 | return [input_specs, [output_spec]] 76 | 77 | 78 | def deduce_ellipsis(arg_specs: ShapeSpecs, args: Tuple[Tensor, ...]) -> Optional[List[str]]: 79 | ellipsis = None 80 | for arg_spec, arg in zip(arg_specs, args): 81 | for dim in arg_spec: 82 | if dim == "...": 83 | ellipsis_size = len(arg.shape) - (len(arg_spec)-1) 84 | assert ellipsis_size >= 0 85 | assert ellipsis is None 86 | ellipsis = [f".{i}" for i in range(ellipsis_size)] 87 | return ellipsis 88 | 89 | 90 | def expand_ellipsis(specs: ShapeSpecs, ellipsis: Optional[List[str]]): 91 | if ellipsis is None: 92 | return 93 | count = 0 94 | for spec in specs: 95 | for i in range(len(spec)): 96 | if spec[i] == "...": 97 | new_spec = spec[:i] + ellipsis + spec[i+1:] 98 | spec.clear() 99 | spec.extend(new_spec) 100 | count += 1 101 | return count 102 | 103 | 104 | @nvtx.annotate("einsum") 105 | def einsum(equation: str, *args: Tensor, variables: Optional[Dict[str, Tensor]]=None): 106 | if variables is None: 107 | variables = {} 108 | input_specs, output_specs = parse_equation(equation) 109 | ellipsis = deduce_ellipsis(input_specs, args) 110 | expand_ellipsis(input_specs, ellipsis) 111 | expand_ellipsis(output_specs, ellipsis) 112 | [output_spec] = output_specs 113 | flatten_input_specs = list(map(flatten_spec, input_specs)) 114 | flatten_output_spec = flatten_spec(output_spec) 115 | dim_counts = defaultdict(lambda: 0) 116 | for spec in [*flatten_input_specs[1:], flatten_output_spec]: 117 | for dim in spec: 118 | dim_counts[dim] += 1 119 | dims = flatten_input_specs[0] 120 | x = args[0] 121 | x_spec = input_specs[0] 122 | for i, (arg, input_spec, flatten_input_spec) in enumerate(zip(args[1:], input_specs[1:], flatten_input_specs[1:])): 123 | next_dims = flatten_input_spec 124 | bs = [] 125 | ms = [] 126 | ns = [] 127 | ks = [] 128 | for dim in next_dims: 129 | dim_counts[dim] -= 1 130 | count = dim_counts[dim] 131 | if dim in dims: 132 | if count > 0: 133 | bs.append(dim) 134 | else: 135 | ks.append(dim) 136 | else: 137 | ns.append(dim) 138 | for dim in dims: 139 | if dim not in next_dims: 140 | ms.append(dim) 141 | if len(bs) == 0: 142 | a_spec = [ms, ks] 143 | b_spec = [ns, ks] 144 | c_spec = [ms, ns] 145 | a = rearrange_impl(x_spec, a_spec, x, variables) 146 | b = rearrange_impl(input_spec, b_spec, arg, variables) 147 | c = matrix_multiply(a, b) 148 | x, x_spec = c, c_spec 149 | else: 150 | a_spec = [bs, ms, ks] 151 | b_spec = [bs, ns, ks] 152 | c_spec = [bs, ms, ns] 153 | a = rearrange_impl(x_spec, a_spec, x, variables) 154 | b = rearrange_impl(input_spec, b_spec, arg, variables) 155 | c = batch_matrix_multiply(a, b) 156 | x, x_spec = c, c_spec 157 | x = rearrange_impl(x_spec, output_spec, x, variables) 158 | return x 159 | 160 | 161 | def rearrange_impl(input_spec, output_spec, x: Tensor, variables: Dict[str, Tensor]): 162 | # TODO: handle broadcast 163 | # expand axes 164 | for i, (dim, dim_size) in reversed(list(enumerate(zip(input_spec, x.shape)))): 165 | if isinstance(dim, list): 166 | subdim_sizes = [] 167 | unknown_dim = None 168 | for subdim in dim: 169 | if subdim in variables: 170 | subdim_sizes.append(variables[subdim]) 171 | else: 172 | subdim_sizes.append(None) 173 | assert unknown_dim is None 174 | unknown_dim = subdim 175 | if unknown_dim is not None: 176 | for subdim_size in subdim_sizes: 177 | dim_size //= subdim_size 178 | variables[unknown_dim] = dim_size 179 | for i in range(len(subdim_sizes)): 180 | if subdim_sizes[i] is None: 181 | subdim_sizes[i] = dim_size 182 | x = x.expand(i, subdim_sizes) 183 | else: 184 | if dim not in variables: 185 | variables[dim] = dim_size 186 | else: 187 | # TODO: assertion? 188 | pass 189 | # shuffle axes 190 | flatten_input_spec = flatten_spec(input_spec) 191 | flatten_output_spec = flatten_spec(output_spec) 192 | assert not isinstance(x, MetaTensor) 193 | for i, dim in enumerate(flatten_output_spec): 194 | if dim not in flatten_input_spec: 195 | flatten_input_spec.insert(i, dim) 196 | x = x.add_axis(i, variables[dim]) 197 | assert not isinstance(x, MetaTensor) 198 | if flatten_input_spec[i] == dim: 199 | continue 200 | j = flatten_input_spec.index(dim) 201 | flatten_input_spec[i], flatten_input_spec[j] = flatten_input_spec[j], flatten_input_spec[i] 202 | x = x.transpose(i, j) 203 | assert not isinstance(x, MetaTensor) 204 | # fold axes 205 | for i, dim in enumerate(output_spec): 206 | if isinstance(dim, list): 207 | dim: List[str] 208 | x = x.fold(i, i+len(dim)) 209 | assert not isinstance(x, MetaTensor) 210 | return x 211 | 212 | 213 | @nvtx.annotate("rearrange") 214 | def rearrange(equation: str, x: Tensor, variables: Optional[Dict[str, Tensor]]=None) -> Tensor: 215 | if variables is None: 216 | variables = {} 217 | input_specs, output_specs = parse_equation(equation) 218 | ellipsis = deduce_ellipsis(input_specs, (x,)) 219 | expand_ellipsis(input_specs, ellipsis) 220 | expand_ellipsis(output_specs, ellipsis) 221 | [output_spec] = output_specs 222 | assert len(input_specs) == 1 223 | input_spec = input_specs[0] 224 | return rearrange_impl(input_spec, output_spec, x, variables) 225 | -------------------------------------------------------------------------------- /minit/functional/generate.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from typing import Optional, Sequence, Tuple, Union 3 | from ..core.dispatch import dispatch 4 | from ..operator.generate import Fill, GenerateInterval, GenerateSequence 5 | from ..core.tensor import Tensor 6 | from .utils import _broadcast_constant, _convert_constant 7 | 8 | 9 | def generate_interval(start: Union[Tensor, Number], stop: Union[Tensor, Number], step: Union[Tensor, Number]=1, *, dtype: str): 10 | start, stop, step = _convert_constant(start, stop, step, dtype=dtype) 11 | (z,) = dispatch(GenerateInterval(), start, stop, step) 12 | return z 13 | 14 | 15 | def generate_sequence(start: Union[Tensor, Number], size: Union[Tensor, Number], step: Union[Tensor, Number]=1, *, dtype: str): 16 | start, step = _convert_constant(start, step, dtype=dtype) 17 | (size,) = _convert_constant(size) 18 | (z,) = dispatch(GenerateSequence(), start, size, step) 19 | return z 20 | 21 | 22 | def fill(value: Union[Number, Tensor], shape: Sequence[Tensor], dtype: Optional[str]=None): 23 | shape = _convert_constant(*shape) 24 | (value,) = _broadcast_constant(value, dtype=dtype, shape=()) 25 | if dtype is not None: 26 | assert dtype == value.dtype 27 | assert len(value.shape) == 0 28 | (z,) = dispatch(Fill(), value, *shape) 29 | return z 30 | -------------------------------------------------------------------------------- /minit/functional/index.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Tuple 3 | from ..core.tensor import Tensor 4 | from ..core.dispatch import dispatch 5 | from ..operator.index import Slice, SliceSet, Index, IndexSet, Split, Tie 6 | from .utils import _convert_constant 7 | 8 | 9 | def slice(x: Tensor, start: Tensor, stop: Tensor, axis: int): 10 | start, stop = _convert_constant(start, stop) 11 | (z,) = dispatch(Slice(axis=axis), x, start, stop) 12 | return z 13 | 14 | 15 | def slice_set(x: Tensor, start: Tensor, stop: Tensor, axis: int, value: Tensor): 16 | start, stop = _convert_constant(start, stop) 17 | (z,) = dispatch(SliceSet(axis=axis), x, start, stop, value) 18 | return z 19 | 20 | 21 | def index(x: Tensor, index: Tensor, axis: int): 22 | (z,) = dispatch(Index(axis=axis), x, index) 23 | return z 24 | 25 | 26 | def index_set(x: Tensor, index: Tensor, axis: int, value: Tensor): 27 | (z,) = dispatch(IndexSet(axis=axis), x, index, value) 28 | return z 29 | 30 | 31 | def split(x: Tensor, axis: int, sizes: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: 32 | sizes = _convert_constant(*sizes) 33 | from .arith import constant 34 | zs = [] 35 | offset = 0 36 | for size in sizes: 37 | offset += size.item() 38 | assert offset == x.shape[axis].item() 39 | offset = 0 40 | for size in sizes: 41 | start = offset 42 | stop = start + size.item() 43 | zs.append(slice(x, constant(start, dtype="int32"), constant(stop, dtype="int32"), axis)) 44 | offset = stop 45 | return tuple(zs) 46 | 47 | 48 | def tie(xs: Tuple[Tensor, ...], axis: int): 49 | (z,) = dispatch(Tie(axis=axis), *xs) 50 | return z 51 | -------------------------------------------------------------------------------- /minit/functional/linalg.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from ..core.tensor import Tensor 4 | from ..core.dispatch import dispatch 5 | from ..operator.linalg import MatrixMultiply, BatchMatrixMultiply, TriangleLower, TriangleUpper 6 | from .utils import _convert_constant 7 | 8 | 9 | def matrix_multiply(x: Tensor, y: Tensor): 10 | from .shape import fold, expand 11 | assert len(y.shape) == 2 12 | if len(x.shape) > 2: 13 | ms = x.shape[:len(x.shape)-1] 14 | x = fold(x, 0, len(ms)) 15 | else: 16 | ms = None 17 | (z,) = dispatch(MatrixMultiply(), x, y) 18 | if ms is not None: 19 | z = expand(z, 0, _convert_constant(*ms)) 20 | return z 21 | 22 | 23 | def batch_matrix_multiply(x: Tensor, y: Tensor): 24 | from .shape import fold, expand 25 | assert len(x.shape) > 2 26 | assert len(y.shape) > 2 27 | bs = x.shape[:-2] 28 | # assert bs == y.shape[:-2] 29 | x = fold(x, 0, len(bs)) 30 | y = fold(y, 0, len(bs)) 31 | (z,) = dispatch(BatchMatrixMultiply(), x, y) 32 | z = expand(z, 0, _convert_constant(*bs)) 33 | return z 34 | 35 | 36 | def triangle_upper(x: Tensor, diagonal: Tensor): 37 | (diagonal,) = _convert_constant(diagonal) 38 | (z,) = dispatch(TriangleUpper(), x, diagonal) 39 | return z 40 | 41 | 42 | def triangle_lower(x: Tensor, diagonal: Tensor): 43 | (diagonal,) = _convert_constant(diagonal) 44 | (z,) = dispatch(TriangleLower(), x, diagonal) 45 | return z 46 | -------------------------------------------------------------------------------- /minit/functional/reduce.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from ..core.scalar import ScalarTensor 3 | from ..core.meta import MetaTensor 4 | from ..core.dispatch import dispatch, register_dispatch 5 | from ..core.tensor import Tensor 6 | from ..operator.reduce import Max, Sum 7 | 8 | 9 | def sum(x: Tensor, axis: int) -> Tensor: 10 | (z,) = dispatch(Sum(axis=axis), x) 11 | return z 12 | 13 | 14 | def mean(x: Tensor, axis: int) -> Tensor: 15 | from .arith import divide 16 | from .generate import fill 17 | size = x.shape[axis] 18 | return divide(sum(x, axis), fill(size, x.shape[:axis] + (1,) + x.shape[axis+1:], x.dtype)) 19 | 20 | 21 | def max(x: Tensor, axis: int) -> Tensor: 22 | (z,) = dispatch(Max(axis=axis), x) 23 | return z 24 | 25 | 26 | @register_dispatch() 27 | def dispatch_max(op: Union[Max, Sum], x: MetaTensor): 28 | shape = x.shape[:op.axis] + (ScalarTensor(1, (), "int32"),) + x.shape[op.axis+1:] 29 | z = MetaTensor(shape, x.dtype) 30 | return (z,) 31 | -------------------------------------------------------------------------------- /minit/functional/shape.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | from ..core.scalar import ScalarTensor 4 | 5 | from ..core.tensor import Tensor 6 | from ..core.dispatch import dispatch, register_dispatch 7 | from ..operator.shape import AddAxis, Broadcast, Fold, Expand, Reinterpret, RemoveAxis, Transpose 8 | from .utils import _convert_constant 9 | 10 | 11 | def fold(x: Tensor, start: int, stop: int) -> Tensor: 12 | if stop - start == 1: 13 | return x 14 | (z,) = dispatch(Fold(start=start, stop=stop), x) 15 | return z 16 | 17 | 18 | def expand(x: Tensor, axis: int, sizes: Tuple[Tensor, ...]) -> Tensor: 19 | if len(sizes) == 1: 20 | # TODO: assertion 21 | return x 22 | sizes = _convert_constant(*sizes) 23 | (z,) = dispatch(Expand(axis=axis), x, *sizes) 24 | return z 25 | 26 | 27 | def add_axis(x: Tensor, axis: int, size: Optional[Tensor] = None) -> Tensor: 28 | if size is not None: 29 | (size,) = _convert_constant(size) 30 | (z,) = dispatch(AddAxis(axis=axis), x) 31 | if size is not None: 32 | z = broadcast(z, axis, size) 33 | return z 34 | 35 | 36 | def remove_axis(x: Tensor, axis: int) -> Tensor: 37 | (z,) = dispatch(RemoveAxis(axis=axis), x) 38 | return z 39 | 40 | 41 | def broadcast(x: Tensor, axis: int, size: Tensor) -> Tensor: 42 | (size,) = _convert_constant(size) 43 | assert axis < len(x.shape) 44 | (z,) = dispatch(Broadcast(axis=axis), x, size) 45 | return z 46 | 47 | 48 | def transpose(x: Tensor, axis_a: int, axis_b: int) -> Tensor: 49 | if axis_a == axis_b: 50 | return x 51 | if axis_a > axis_b: 52 | axis_a, axis_b = axis_b, axis_a 53 | (z,) = dispatch(Transpose(axis_a, axis_b), x) 54 | return z 55 | 56 | 57 | def repeat(x: Tensor, axis: int, size: Tensor) -> Tensor: 58 | (size,) = _convert_constant(size) 59 | return fold(broadcast(add_axis(x, axis), axis, size), axis, axis+2) 60 | 61 | 62 | def repeat_interleaved(x: Tensor, axis: int, size: Tensor) -> Tensor: 63 | (size,) = _convert_constant(size) 64 | return fold(broadcast(add_axis(x, axis+1), axis+1, size), axis, axis+2) 65 | 66 | 67 | def reinterpret(x: Tensor, target: str) -> Tensor: 68 | (z,) = dispatch(Reinterpret(target), x) 69 | return z 70 | 71 | 72 | @register_dispatch() 73 | def dispatch_add_axis(op: AddAxis, x: ScalarTensor): 74 | shape = x._shape[:op.axis] + (ScalarTensor(1, (), "int32"),) + x._shape[op.axis:] 75 | z = ScalarTensor(x._value, shape, x._dtype) 76 | return (z,) 77 | -------------------------------------------------------------------------------- /minit/functional/special.py: -------------------------------------------------------------------------------- 1 | from ..operator.special import RMSNorm, RoPE, Sigmoid, Softmax 2 | from ..core.dispatch import dispatch 3 | from ..core.tensor import Tensor 4 | 5 | 6 | def sigmoid(x: Tensor): 7 | (z,) = dispatch(Sigmoid(), x) 8 | return z 9 | 10 | 11 | def rms_norm(x: Tensor, weight: Tensor, axis: int, eps: float) -> Tensor: 12 | assert isinstance(eps, float) 13 | assert eps > 0 14 | assert eps < 0.5 15 | (z,) = dispatch(RMSNorm(axis=axis, eps=eps), x, weight) 16 | return z 17 | 18 | 19 | def rope(x: Tensor, freqs_cos: Tensor, freqs_sin: Tensor) -> Tensor: 20 | (z,) = dispatch(RoPE(), x, freqs_cos, freqs_sin) 21 | return z 22 | 23 | 24 | def softmax(x: Tensor, axis: int) -> Tensor: 25 | (z,) = dispatch(Softmax(axis=axis), x) 26 | return z 27 | -------------------------------------------------------------------------------- /minit/functional/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from ..operator.arith import Constant 4 | from ..core.dispatch import dispatch 5 | from ..core.tensor import Tensor 6 | 7 | 8 | def _broadcast_constant(*args, dtype=None, shape=None) -> Tuple[Tensor, ...]: 9 | for arg in args: 10 | if isinstance(arg, Tensor): 11 | if shape is None: 12 | shape = arg.shape 13 | if dtype is None: 14 | dtype = arg.dtype 15 | # assert tensor.shape == arg.shape 16 | assert arg.dtype == dtype, f"{arg.dtype} vs {dtype}" 17 | def broadcast_scalar(scalar): 18 | assert dtype is not None 19 | (constant,) = dispatch(Constant(scalar, dtype), *shape) 20 | return constant 21 | return tuple([arg if isinstance(arg, Tensor) else broadcast_scalar(arg) for arg in args]) 22 | 23 | 24 | def _convert_constant(*args, dtype="int32") -> Tuple[Tensor, ...]: 25 | from .arith import constant 26 | for arg in args: 27 | if isinstance(arg, Tensor): 28 | assert arg.dtype == dtype 29 | return tuple([arg if isinstance(arg, Tensor) else constant(arg, dtype) for arg in args]) 30 | -------------------------------------------------------------------------------- /minit/graph/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import TYPE_CHECKING, Generic, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union 3 | import weakref 4 | 5 | from ..core.scalar import ScalarTensor 6 | from ..core.meta import MetaTensor 7 | from ..core.operator import Operator 8 | from ..core.tensor import Tensor 9 | 10 | 11 | @dataclass(frozen=True) 12 | class GraphRef: 13 | if TYPE_CHECKING: 14 | graph: weakref.ref["Graph"] 15 | else: 16 | graph: weakref.ref 17 | 18 | def __call__(self) -> "Graph": 19 | graph = self.graph() 20 | assert graph is not None 21 | return graph 22 | 23 | @property 24 | def valid(self): 25 | return self.graph() is not None 26 | 27 | def __post_init__(self): 28 | assert isinstance(self.graph, weakref.ref) 29 | 30 | 31 | class NodeBase: 32 | __slots__ = [ 33 | "graph", 34 | "valid", 35 | ] 36 | 37 | graph: "GraphRef" 38 | valid: bool 39 | 40 | def __init__(self, graph: "GraphRef") -> None: 41 | self.graph = graph 42 | self.valid = True 43 | 44 | 45 | class ValueUse: 46 | target: "ValueNode" 47 | user: Optional["OperatorNode"] 48 | axis = None 49 | 50 | def __init__(self, target: "ValueNode", user: Optional["OperatorNode"]) -> None: 51 | assert isinstance(target, ValueNode) 52 | self.target = target 53 | self.user = user 54 | 55 | def clone(self, user: "OperatorNode"): 56 | return self.target.use_value(user) 57 | 58 | def shape(self, axis: int): 59 | return self.target.use_shape(None, axis) 60 | 61 | def __call__(self): 62 | return self.target 63 | 64 | def __del__(self): 65 | if self.target.valid: 66 | self.target.uses.remove(weakref.ref(self)) 67 | 68 | def __repr__(self): 69 | return f"{self.target}.value" 70 | 71 | 72 | class ShapeUse: 73 | target: "ValueNode" 74 | user: Optional["OperatorNode"] 75 | axis: int 76 | 77 | def __init__(self, target: "ValueNode", user: Optional["OperatorNode"], axis: int) -> None: 78 | assert isinstance(target, ValueNode) 79 | self.target = target 80 | self.user = user 81 | self.axis = axis 82 | 83 | def clone(self, user: "OperatorNode"): 84 | return self.target.use_shape(user, self.axis) 85 | 86 | def __call__(self): 87 | return self.target 88 | 89 | def __del__(self): 90 | if self.target.valid: 91 | self.target.uses.remove(weakref.ref(self)) 92 | 93 | def __repr__(self): 94 | return f"{self.target}.shape[{self.axis}]" 95 | 96 | 97 | Use = Union[ValueUse, ShapeUse] 98 | 99 | 100 | class ValueNode(NodeBase): 101 | if TYPE_CHECKING: 102 | uses: List[weakref.ref[Use]] 103 | else: 104 | uses: List[weakref.ref] 105 | 106 | def __init__(self, graph: GraphRef) -> None: 107 | super().__init__(graph) 108 | self.uses = [] 109 | 110 | def use_shape(self, user: Optional["OperatorNode"], axis: int): 111 | use = ShapeUse(self, user, axis) 112 | self.uses.append(weakref.ref(use)) 113 | return use 114 | 115 | def use_value(self, user: Optional["OperatorNode"]): 116 | if isinstance(self, InternalNode): 117 | assert user != self.producer 118 | use = ValueUse(self, user) 119 | self.uses.append(weakref.ref(use)) 120 | return use 121 | 122 | def replace(self, target: "ValueNode"): 123 | if self is target: 124 | return 125 | for use in self.uses: 126 | use().target = target 127 | target.uses += self.uses 128 | self.uses.clear() 129 | self.valid = False 130 | 131 | 132 | class OperatorNode(NodeBase): 133 | __slots__ = [ 134 | "graph", 135 | "valid", 136 | "operator", 137 | "args", 138 | "outputs", 139 | ] 140 | 141 | operator: Operator 142 | args: List[Use] 143 | outputs: Tuple["InternalNode", ...] 144 | 145 | def __init__(self, graph: "GraphRef", operator: Operator, args: Sequence[Use], output_metas: Tuple[MetaTensor, ...]) -> None: 146 | super().__init__(graph) 147 | self.operator = operator 148 | self.args = [arg.clone(self) for arg in args] 149 | self.outputs = tuple(InternalNode(graph, self, i, meta) for i, meta in enumerate(output_metas)) 150 | 151 | def destroy(self): 152 | del self.operator 153 | del self.args 154 | del self.outputs 155 | self.valid = False 156 | 157 | def __repr__(self) -> str: 158 | return repr({ 159 | "operator": self.operator, 160 | "args": self.args, 161 | "outputs": self.outputs, 162 | }) 163 | 164 | 165 | class InternalNode(ValueNode): 166 | __slots__ = [ 167 | "graph", 168 | "valid", 169 | "uses", 170 | "producer", 171 | "index", 172 | "meta", 173 | ] 174 | 175 | producer: "OperatorNode" 176 | index: int 177 | meta: MetaTensor 178 | 179 | def __init__(self, graph: "GraphRef", producer: "OperatorNode", index: int, meta: MetaTensor): 180 | super().__init__(graph) 181 | self.producer = producer 182 | self.index = index 183 | self.meta = meta 184 | 185 | @property 186 | def shape(self): 187 | return self.meta.shape 188 | 189 | @property 190 | def dtype(self): 191 | return self.meta.dtype 192 | 193 | def __repr__(self): 194 | return "InternalNode()" 195 | 196 | 197 | class ConstantNode(ValueNode): 198 | __slots__ = [ 199 | "graph", 200 | "valid", 201 | "uses", 202 | "value", 203 | ] 204 | 205 | value: Tensor 206 | 207 | def __init__(self, graph: "GraphRef", value: Tensor): 208 | super().__init__(graph) 209 | assert isinstance(value, ScalarTensor), f"{value} is not ScalarTensor" 210 | self.value = value 211 | 212 | @property 213 | def shape(self): 214 | return self.value.shape 215 | 216 | @property 217 | def dtype(self): 218 | return self.value.dtype 219 | 220 | 221 | class PlaceholderNode(ValueNode): 222 | __slots__ = [ 223 | "graph", 224 | "valid", 225 | "uses", 226 | "value", 227 | "meta", 228 | ] 229 | 230 | def __init__(self, graph: "GraphRef", meta: MetaTensor): 231 | super().__init__(graph) 232 | self.meta = meta 233 | 234 | @property 235 | def shape(self): 236 | return self.meta.shape 237 | 238 | @property 239 | def dtype(self): 240 | return self.meta.dtype 241 | 242 | def __repr__(self): 243 | return "PlaceholderNode()" 244 | 245 | 246 | TensorNode = Union[InternalNode, ConstantNode, PlaceholderNode] 247 | Node = Union[TensorNode, OperatorNode] 248 | 249 | 250 | _T = TypeVar("_T") 251 | 252 | class LinkedListVertex(Generic[_T]): 253 | prev: "LinkedListEdge[_T]" 254 | next: "LinkedListEdge[_T]" 255 | 256 | class LinkedListEdge(Generic[_T]): 257 | prev: "LinkedListVertex[_T]" 258 | next: "LinkedListVertex[_T]" 259 | value: Optional[_T] 260 | 261 | def __init__(self, prev: "LinkedListVertex[_T]", next: "LinkedListVertex[_T]", value: Optional[_T] = None) -> None: 262 | super().__init__() 263 | self.prev = prev 264 | self.next = next 265 | self.value = value 266 | prev.next = self 267 | next.prev = self 268 | 269 | class LinkedListIterator(Generic[_T]): 270 | current: LinkedListVertex[_T] 271 | tail: LinkedListVertex[_T] 272 | 273 | def __init__(self, head: LinkedListVertex[_T], tail: LinkedListVertex[_T]): 274 | self.current = head 275 | self.tail = tail 276 | 277 | def clone(self): 278 | return LinkedListIterator(self.current, self.tail) 279 | 280 | def __next__(self) -> _T: 281 | while self.current is not self.tail: 282 | value = self.current.next.value 283 | self.current = self.current.next.next 284 | if value is not None: 285 | return value 286 | raise StopIteration 287 | 288 | 289 | class LinkedList(Generic[_T], Iterable[_T]): 290 | head: LinkedListVertex[_T] 291 | tail: LinkedListVertex[_T] 292 | 293 | def __init__(self): 294 | self.head = LinkedListVertex() 295 | self.tail = LinkedListVertex() 296 | LinkedListEdge(self.head, self.tail, None) 297 | LinkedListEdge(self.tail, self.head, None) 298 | 299 | def view(self): 300 | return LinkedListView(self.head, self.tail) 301 | 302 | def tolist(self) -> List[_T]: 303 | result = [] 304 | node = self.head 305 | while node is not self.tail: 306 | if node.next.value is not None: 307 | result.append(node.next.value) 308 | node = node.next.next 309 | return result 310 | 311 | def __iter__(self) -> LinkedListIterator[_T]: 312 | return LinkedListIterator(self.head, self.tail) 313 | 314 | 315 | class LinkedListView(Generic[_T], Iterable[_T]): 316 | head: LinkedListVertex[_T] 317 | tail: LinkedListVertex[_T] 318 | 319 | def __init__(self, head: LinkedListVertex[_T], tail: LinkedListVertex[_T]) -> None: 320 | super().__init__() 321 | assert head is not tail 322 | self.head = head 323 | self.tail = tail 324 | 325 | def tolist(self) -> List[_T]: 326 | result = [] 327 | node = self.head 328 | while node is not self.tail: 329 | if node.next.value is not None: 330 | result.append(node.next.value) 331 | node = node.next.next 332 | return result 333 | 334 | def clear(self): 335 | node = self.head.next.next 336 | while node is not self.tail: 337 | next = node.next.next 338 | node.prev = None 339 | node.next = None 340 | node = next 341 | LinkedListEdge(self.head, self.tail, None) 342 | 343 | def fill(self, list: List[_T]): 344 | node = self.head.next.next 345 | while node is not self.tail: 346 | next = node.next.next 347 | node.prev = None 348 | node.next = None 349 | node = next 350 | if len(list) == 0: 351 | LinkedListEdge(self.head, self.tail, None) 352 | else: 353 | node = self.head 354 | for item in list: 355 | new_node = LinkedListVertex() 356 | LinkedListEdge(node, new_node, item) 357 | node = new_node 358 | 359 | def append(self, value: _T): 360 | node = LinkedListVertex() 361 | self.tail.prev.next = node 362 | node.prev = self.tail.prev 363 | LinkedListEdge(node, self.tail, value) 364 | 365 | def __iter__(self) -> LinkedListIterator[_T]: 366 | return LinkedListIterator(self.head, self.tail) 367 | 368 | 369 | class Graph: 370 | __slots__ = [ 371 | "inputs", 372 | "operators", 373 | "outputs", 374 | "__weakref__", 375 | ] 376 | 377 | inputs: List[PlaceholderNode] 378 | operators: "LinkedList[OperatorNode]" 379 | outputs: List[Use] 380 | 381 | def __init__(self) -> None: 382 | self.inputs = [] 383 | self.operators = LinkedList() 384 | self.outputs = [] 385 | 386 | def create_input(self, meta: MetaTensor): 387 | self.inputs.append(PlaceholderNode(self.ref, meta)) 388 | return self.inputs[-1] 389 | 390 | @property 391 | def ref(self): 392 | return GraphRef(weakref.ref(self)) 393 | 394 | def __del__(self): 395 | del self.inputs 396 | del self.outputs 397 | for operator in self.operators: 398 | operator.destroy() 399 | del self.operators 400 | 401 | 402 | class SubGraph: 403 | __slots__ = [ 404 | "graph", 405 | "inputs", 406 | "operators", 407 | "outputs", 408 | ] 409 | 410 | graph: Graph 411 | inputs: List[Use] 412 | operators: LinkedListView[OperatorNode] 413 | outputs: List[Use] 414 | 415 | def __init__(self, graph: Graph, inputs: Sequence[Use], operators: LinkedListView[OperatorNode], outputs: Sequence[Use]) -> None: 416 | for input in inputs: 417 | assert isinstance(input, (ShapeUse, ValueUse)) 418 | self.graph = graph 419 | self.inputs = list(inputs) 420 | self.operators = operators 421 | self.outputs = list(outputs) 422 | 423 | def __repr__(self): 424 | return repr({ 425 | "inputs": self.inputs, 426 | "operators": list(self.operators), 427 | "outputs": self.outputs, 428 | }) 429 | 430 | 431 | class GraphBuilder: 432 | graph: Graph 433 | inputs: List[Use] 434 | operators: LinkedListView[OperatorNode] 435 | 436 | def __init__(self, graph: Graph, inputs: List[Use], operators: LinkedListView[OperatorNode]) -> None: 437 | assert isinstance(graph, Graph) 438 | assert isinstance(operators, LinkedListView) 439 | self.graph = graph 440 | self.inputs = list(inputs) 441 | self.operators = operators 442 | 443 | def create_operator(self, operator: Operator, args: Sequence[Use], output_metas: Tuple[MetaTensor, ...]) -> Tuple[ValueNode, ...]: 444 | operator_node = OperatorNode(self.graph.ref, operator, args, output_metas) 445 | self.operators.append(operator_node) 446 | return tuple(output.use_value(None) for output in operator_node.outputs) 447 | 448 | def create_constant(self, value: Tensor): 449 | assert not isinstance(value, MetaTensor) 450 | constant = ConstantNode(self.graph.ref, value) 451 | return constant.use_value(None) 452 | 453 | def build(self, *outputs: Use) -> Tuple[Use, ...]: 454 | return outputs 455 | -------------------------------------------------------------------------------- /minit/graph/optimize.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Generic, List, Optional, Sequence, Set, Tuple, TypeVar 2 | 3 | from ..trace.function import trace_function_on_graph 4 | from . import Graph, GraphBuilder, LinkedListIterator, LinkedListView, OperatorNode, ShapeUse, SubGraph, TensorNode, ValueUse 5 | from ..core.tensor import Tensor 6 | 7 | 8 | class OptimizationPass: 9 | _pattern: Optional[SubGraph] = None 10 | 11 | def make_inputs(self) -> Sequence[Tensor]: 12 | raise NotImplementedError() 13 | 14 | def describe(self, *inputs: Tensor) -> Tuple[Tensor, ...]: 15 | raise NotImplementedError() 16 | 17 | def predicate(self, *inputs: Tensor) -> bool: 18 | raise NotImplementedError() 19 | 20 | def rewrite(self, *inputs: Tensor) -> Tuple[Tensor, ...]: 21 | raise NotImplementedError() 22 | 23 | @property 24 | def pattern(self) -> SubGraph: 25 | from ..trace.function import trace_function 26 | if self._pattern is not None: 27 | return self._pattern 28 | inputs = self.make_inputs() 29 | self._pattern = trace_function(self.describe, inputs) 30 | return self._pattern 31 | 32 | 33 | _T = TypeVar("_T") 34 | 35 | 36 | class ListIterator(Generic[_T]): 37 | items: List[_T] 38 | index: int = 0 39 | 40 | def __init__(self, items: List[_T]): 41 | self.items = items 42 | 43 | def __next__(self): 44 | if self.index >= len(self.items): 45 | raise StopIteration 46 | value = self.items[self.index] 47 | self.index += 1 48 | return value 49 | 50 | def __iter__(self): 51 | return self 52 | 53 | 54 | class Matcher: 55 | pattern: Graph 56 | 57 | def __init__(self, pattern: Graph) -> None: 58 | self.pattern = pattern 59 | 60 | def match(self, graph: Graph, operators_head: LinkedListIterator[OperatorNode]) -> Optional[SubGraph]: 61 | pattern = self.pattern 62 | nr_inputs = len(pattern.inputs) 63 | operators = operators_head.clone() 64 | variables = {} 65 | pattern_variables = {} 66 | for i, input in enumerate(pattern.inputs): 67 | pattern_variables[input] = i 68 | variable_id = nr_inputs 69 | users: Set[OperatorNode] = set() 70 | id2variable: Dict[int, TensorNode] = {} 71 | pattern_outputs = set(pattern_output.target for pattern_output in pattern.outputs) 72 | for i, pattern_operator in enumerate(pattern.operators): 73 | try: 74 | operator = next(operators) 75 | except StopIteration: 76 | return None 77 | if operator.operator != pattern_operator.operator: 78 | return None 79 | if len(operator.args) != len(pattern_operator.args): 80 | return None 81 | for arg, pattern_arg in zip(operator.args, pattern_operator.args): 82 | if arg.axis != pattern_arg.axis: 83 | return None 84 | # capture input 85 | arg_id = pattern_variables[pattern_arg.target] 86 | if arg.target not in variables: 87 | if arg_id >= nr_inputs: 88 | return None 89 | variables[arg.target] = arg_id 90 | id2variable[arg_id] = arg.target 91 | if variables[arg.target] != arg_id: 92 | return None 93 | if operator in users: 94 | users.remove(operator) 95 | for output, pattern_output in zip(operator.outputs, pattern_operator.outputs): 96 | variables[output] = variable_id 97 | id2variable[variable_id] = output 98 | pattern_variables[pattern_output] = variable_id 99 | if pattern_output not in pattern_outputs: 100 | for use in output.uses: 101 | user = use().user 102 | if user is not None: 103 | assert user != operator 104 | users.add(user) 105 | variable_id += 1 106 | if len(users) != 0: 107 | for user in users: 108 | assert user.valid 109 | return None 110 | inputs = [] 111 | outputs = [] 112 | for pattern_input in pattern.inputs: 113 | inputs.append(id2variable[pattern_variables[pattern_input]].use_value(None)) 114 | for pattern_output in pattern.outputs: 115 | if isinstance(pattern_output, ShapeUse): 116 | outputs.append(id2variable[pattern_variables[pattern_output.target]].use_shape(None, pattern_output.axis)) 117 | elif isinstance(pattern_output, ValueUse): 118 | outputs.append(id2variable[pattern_variables[pattern_output.target]].use_value(None)) 119 | else: 120 | assert False 121 | return SubGraph(graph, inputs, LinkedListView(operators_head.current, operators.current), outputs) 122 | 123 | 124 | class TraceGraphOptimizer: 125 | graph: SubGraph 126 | 127 | def __init__(self, graph: SubGraph) -> None: 128 | self.graph = graph 129 | 130 | def apply(self, optimization_pass: OptimizationPass) -> int: 131 | pattern = optimization_pass.pattern 132 | matcher = Matcher(pattern) 133 | operators = iter(self.graph.operators) 134 | count = 0 135 | while True: 136 | result = matcher.match(self.graph, operators) 137 | if result is not None: 138 | count += 1 139 | rewrite = optimization_pass.rewrite 140 | old_operators = result.operators.tolist() 141 | for operator in old_operators: 142 | operator.destroy() 143 | result.operators.clear() 144 | builder = GraphBuilder(self.graph.graph, result.inputs, result.operators) 145 | rewrite_outputs = trace_function_on_graph(rewrite, optimization_pass.make_inputs(), builder, result.inputs) 146 | for output, rewrite_output in zip(result.outputs, rewrite_outputs, strict=True): 147 | assert output.axis == rewrite_output.axis 148 | output.target.replace(rewrite_output.target) 149 | else: 150 | try: 151 | next(operators) 152 | except StopIteration: 153 | break 154 | return count 155 | -------------------------------------------------------------------------------- /minit/lazy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GetUpEarlier/minit/48d227e638c0316cf998295f7638f909fbc1b9f6/minit/lazy/__init__.py -------------------------------------------------------------------------------- /minit/lazy/dispatch.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from ..core.meta import MetaTensor 4 | from ..core.dispatch import dispatch, register_dispatch 5 | from .tensor import Expression, LazyTensor 6 | from ..core.scalar import ScalarTensor 7 | from ..core.operator import Operator 8 | from ..core.device_operator import DeviceOperator 9 | from ..core.object import get_origin_or_self 10 | 11 | 12 | def no_device_any_lazy(op, *tys): 13 | return get_origin_or_self(op) != DeviceOperator and any(ty in (LazyTensor, ScalarTensor) for ty in tys) 14 | 15 | 16 | def no_device_any_lazy_but_some(op, *tys): 17 | return get_origin_or_self(op) != DeviceOperator and any(ty is LazyTensor for ty in tys) and any(ty not in (ScalarTensor, LazyTensor) for ty in tys) 18 | 19 | 20 | @register_dispatch(predicate=no_device_any_lazy, priority=-2) 21 | def dispatch_lazy(op: Operator, *args: Union[ScalarTensor, LazyTensor]): 22 | if any(isinstance(arg, LazyTensor) for arg in args): 23 | meta_outputs = dispatch(op, *[MetaTensor(arg.shape, arg.dtype) if isinstance(arg, LazyTensor) else arg for arg in args]) 24 | else: 25 | meta_outputs = dispatch(DeviceOperator(op, "meta"), *[MetaTensor(arg.shape, arg.dtype) if isinstance(arg, LazyTensor) else arg for arg in args]) 26 | expression = Expression(op, args, meta_outputs) 27 | outputs = tuple(LazyTensor(expression, i) for i in range(len(meta_outputs))) 28 | return outputs 29 | 30 | 31 | @register_dispatch(predicate=no_device_any_lazy_but_some, priority=-2) 32 | def dispatch_lazy_decay(op: Operator, *args: Union[ScalarTensor, LazyTensor]): 33 | meta_outputs = dispatch(op, *[MetaTensor(arg.shape, arg.dtype) for arg in args]) 34 | expression = Expression(op, args, meta_outputs) 35 | outputs = tuple(LazyTensor(expression, i) for i in range(len(meta_outputs))) 36 | return outputs 37 | -------------------------------------------------------------------------------- /minit/lazy/tensor.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | 4 | from ..core.meta import MetaTensor 5 | from ..core.device_operator import DeviceOperator 6 | from ..core.dispatch import dispatch 7 | from ..core.operator import Operator 8 | from ..core.tensor import Tensor 9 | 10 | 11 | @dataclass 12 | class Expression: 13 | op: Operator 14 | args: Tuple["Tensor", ...] 15 | outputs: Tuple["Tensor", ...] 16 | 17 | def evaluate(self, device: str) -> Tuple["Tensor", ...]: 18 | any_lazy = any(isinstance(arg, LazyTensor) for arg in self.args) 19 | if not any_lazy: 20 | outputs = dispatch(DeviceOperator(self.op, device), *self.args) 21 | else: 22 | args = tuple([arg.instantiate(device) if isinstance(arg, LazyTensor) else arg for arg in self.args]) 23 | outputs = dispatch(self.op, *args) 24 | for output in outputs: 25 | assert not isinstance(output, MetaTensor) 26 | return outputs 27 | 28 | 29 | class LazyTensor(Tensor): 30 | def __init__(self, expression: Expression, index: int) -> None: 31 | super().__init__() 32 | self._expression = expression 33 | self._index = index 34 | 35 | @property 36 | def dtype(self): 37 | return self._expression.outputs[self._index].dtype 38 | 39 | @property 40 | def shape(self): 41 | return self._expression.outputs[self._index].shape 42 | 43 | @property 44 | def device(self): 45 | return "lazy" 46 | 47 | def instantiate(self, device: str): 48 | result = self._expression.evaluate(device)[self._index] 49 | if device != "meta": 50 | assert not isinstance(result, MetaTensor) 51 | return result 52 | 53 | def __repr__(self): 54 | return f"LazyTensor{{expression={self._expression}}}" 55 | -------------------------------------------------------------------------------- /minit/module/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import Module -------------------------------------------------------------------------------- /minit/module/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable, Dict, List 3 | import json 4 | 5 | from ..core.torch import TorchTensor 6 | from ..core.tensor import Tensor 7 | from ..core.shape import to_immediate_shape 8 | from .module import Module 9 | 10 | 11 | def load_from_torch(model: Module, path: str): 12 | import torch 13 | print(f"loading checkpoint from {path}") 14 | checkpoint: Dict[str, torch.Tensor] = torch.load(path) 15 | for name, array in checkpoint.items(): 16 | print(f"loading checkpoint {name}") 17 | dtype = model.get_buffer(name).dtype 18 | array = array.to(getattr(torch, dtype)) 19 | model.update_buffer(name, TorchTensor(array)) 20 | return model 21 | 22 | 23 | def load_from_safetensors(model: Module, paths: List[str], epilogue: Callable[[str, Tensor], Tensor] = lambda x: x): 24 | import torch 25 | import safetensors 26 | for path in paths: 27 | print(f"loading safetensors from {path}") 28 | with safetensors.safe_open(path, framework="pt", device="cpu") as f: 29 | for key in f.keys(): 30 | array = f.get_tensor(key) 31 | dtype = model.get_buffer(key).dtype 32 | print(f"loading safetensor {key} {array.shape} {array.dtype} -> {dtype}") 33 | shape = to_immediate_shape(model.get_buffer(key).shape) 34 | array = array.to(getattr(torch, dtype)) 35 | assert array.shape == shape, f"{array.shape} vs {shape}" 36 | model.update_buffer(key, epilogue(key, TorchTensor(array))) 37 | return model 38 | 39 | 40 | def load_from_safetensors_index(model: Module, path: str, epilogue: Callable[[str, Tensor], Tensor] = lambda k, v: v): 41 | with open(path) as f: 42 | index = json.load(f) 43 | parts = list(map(lambda part: os.path.join(os.path.dirname(path), part), dict.fromkeys(index["weight_map"].values()))) 44 | load_from_safetensors(model, parts, epilogue) 45 | -------------------------------------------------------------------------------- /minit/module/list.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, Iterator, List, TypeVar 2 | from .module import Module 3 | 4 | _Module = TypeVar("_Module", bound=Module) 5 | 6 | class ModuleList(Module, Generic[_Module]): 7 | _module_list: List[_Module] 8 | 9 | def __init__(self) -> None: 10 | super().__init__() 11 | self._module_list = [] 12 | 13 | def append(self, module: _Module): 14 | index = len(self._module_list) 15 | self._module_list.append(self.register_module(str(index), module)) 16 | 17 | def __getitem__(self, *args, **kwargs) -> _Module: 18 | return self._module_list.__getitem__(*args, **kwargs) 19 | 20 | def __iter__(self) -> Iterator[_Module]: 21 | return self._module_list.__iter__() 22 | 23 | def __len__(self) -> int: 24 | return self._module_list.__len__() 25 | -------------------------------------------------------------------------------- /minit/module/module.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Generic, Optional, Set, Tuple, TypeVar 2 | from typing_extensions import Self 3 | 4 | from ..core.shape import to_symbolic_shape 5 | from ..core.tensor import Tensor 6 | from ..core.meta import MetaTensor 7 | 8 | _Child = TypeVar("_Child") 9 | _Parent = TypeVar("_Parent") 10 | 11 | class Module(Generic[_Parent]): 12 | parent: _Parent 13 | children: Dict[str, Any] 14 | buffers: Dict[str, Tuple[()]] 15 | 16 | def __init__(self) -> None: 17 | super().__init__() 18 | self.children = {} 19 | self.buffers = {} 20 | 21 | def register_module(self, name: str, module: _Child) -> _Child: 22 | self.children[name] = () 23 | return module 24 | 25 | def register_buffer(self, name: str, shape: Tuple[int, ...], dtype: str, buffer: Optional[Tensor] = None): 26 | if buffer is None: 27 | buffer = MetaTensor(to_symbolic_shape(shape), dtype) 28 | self.buffers[name] = () 29 | return buffer 30 | 31 | def named_buffers(self): 32 | for name, buffer in self.buffers.items(): 33 | yield name, buffer 34 | for module_name, child in self.children.items(): 35 | for name, buffer in child.named_buffers(): 36 | yield module_name + "." + name, buffer 37 | 38 | def get_child(self, name: str) -> "Module": 39 | if "." in name: 40 | prefix, suffix = name.split(".", maxsplit=1) 41 | return self.get_child(prefix).get_child(suffix) 42 | else: 43 | assert name in self.children 44 | if name[0].isdigit(): 45 | return self[int(name)] 46 | else: 47 | return getattr(self, name) 48 | 49 | def get_buffer(self, name: str) -> Tensor: 50 | if "." in name: 51 | prefix, suffix = name.split(".", maxsplit=1) 52 | return self.get_child(prefix).get_buffer(suffix) 53 | else: 54 | assert name in self.buffers 55 | if name[0].isdigit(): 56 | return self[int(name)] 57 | else: 58 | return getattr(self, name) 59 | 60 | def update_buffer(self, name: str, value: Tensor) -> Tensor: 61 | if "." in name: 62 | prefix, suffix = name.split(".", maxsplit=1) 63 | self.get_child(prefix).update_buffer(suffix, value) 64 | else: 65 | assert name in self.buffers 66 | if name[0].isdigit(): 67 | self[int(name)] = value 68 | else: 69 | setattr(self, name, value) 70 | 71 | def __call__(self: Self, *args, **kwargs): 72 | return self.forward(*args, **kwargs) 73 | -------------------------------------------------------------------------------- /minit/nccl/dispatch.py: -------------------------------------------------------------------------------- 1 | import nvtx 2 | 3 | from .tensor import NCCLTensor 4 | from ..distributed.group import get_world 5 | from ..cuda.tensor import CUDATensor 6 | from ..core.tensor import Tensor 7 | from ..core.dispatch import register_dispatch 8 | from ..distributed.operator import DistributedBroadcast, DistributedAllReduce, DistributedAllGather 9 | from .kernel import _generate_nccl_primitives 10 | 11 | @register_dispatch() 12 | def dispatch_all_gather(op: DistributedAllGather, version: NCCLTensor, x: CUDATensor): 13 | if op.axis != 0: 14 | x = x.transpose(0, op.axis) 15 | shape = x._shape 16 | size = 1 17 | for dim in shape: 18 | size *= dim 19 | shape = ((shape[0] * get_world().size),) + shape[1:] 20 | z = CUDATensor.allocate(shape, x._dtype) 21 | all_gather, _, _ = _generate_nccl_primitives(x._dtype) 22 | with nvtx.annotate(f"all_gather_{size}_{x._dtype}"): 23 | all_gather(None, version.comm, x.data_ptr, z.data_ptr, size) 24 | if op.axis != 0: 25 | z = z.transpose(0, op.axis) 26 | return (version, z,) 27 | 28 | @register_dispatch() 29 | def dispatch_all_reduce(op: DistributedAllReduce, version: NCCLTensor, x: CUDATensor): 30 | shape = x._shape 31 | size = 1 32 | for dim in shape: 33 | size *= dim 34 | z = CUDATensor.allocate(shape, x._dtype) 35 | _, all_reduce, _ = _generate_nccl_primitives(x._dtype) 36 | with nvtx.annotate(f"all_reduce_{size}_{x._dtype}"): 37 | all_reduce(None, version.comm, x.data_ptr, z.data_ptr, size) 38 | return (version, z,) 39 | 40 | @register_dispatch() 41 | def dispatch_broadcast(op: DistributedBroadcast, version: NCCLTensor, x: CUDATensor): 42 | shape = x._shape 43 | size = 1 44 | for dim in shape: 45 | size *= dim 46 | z = CUDATensor.allocate(shape, x._dtype) 47 | _, _, broadcast = _generate_nccl_primitives(x._dtype) 48 | with nvtx.annotate(f"broadcast_{size}_{x._dtype}"): 49 | broadcast(None, version.comm, x.data_ptr, z.data_ptr, size, op.source) 50 | return (version, z,) 51 | -------------------------------------------------------------------------------- /minit/nccl/kernel.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | import os 4 | import nvtx 5 | import nvidia.nccl 6 | 7 | from ..compiler.template import substitude 8 | from ..core.cache import cached 9 | from ..cuda.toolkit import find_cuda_include_directory, find_cuda_libraries 10 | from ..compiler.cxx import CXXLibrary, CXXUnit, import_symbol 11 | from ..compiler.gcc import gcc 12 | 13 | 14 | NCCL_DTYPE_MAPPING = { 15 | "bfloat16": "ncclBfloat16", 16 | "float16": "ncclFloat16", 17 | "float32": "ncclFloat32", 18 | "float64": "ncclFloat64", 19 | "int8": "ncclInt8", 20 | "int32": "ncclInt32", 21 | "int64": "ncclInt64", 22 | "uint8": "ncclUint8", 23 | "uint32": "ncclUint32", 24 | "uint64": "ncclUint64", 25 | } 26 | 27 | 28 | @cached() 29 | def _generate_nccl_primitives(dtype: str): 30 | kernel_template =\ 31 | """ 32 | #include 33 | #include 34 | #include 35 | #include 36 | 37 | 38 | #define NCCL_ASSERT(expr) \\ 39 | do { \\ 40 | auto _err = (expr); \\ 41 | if (_err != ncclSuccess) { \\ 42 | throw std::runtime_error(ncclGetErrorString(_err)); \\ 43 | } \\ 44 | } while (0) 45 | 46 | 47 | extern "C" void nccl_all_gather(cudaStream_t stream, ncclComm_t comm, const void* sendbuff, void* recvbuff, size_t size) { 48 | NCCL_ASSERT(ncclAllGather(sendbuff, recvbuff, size, ${NCCL_DATA_TYPE}, comm, stream)); 49 | } 50 | 51 | extern "C" void nccl_all_reduce(cudaStream_t stream, ncclComm_t comm, const void* sendbuff, void* recvbuff, size_t size) { 52 | NCCL_ASSERT(ncclAllReduce(sendbuff, recvbuff, size, ${NCCL_DATA_TYPE}, ncclSum, comm, stream)); 53 | } 54 | 55 | extern "C" void nccl_broadcast(cudaStream_t stream, ncclComm_t comm, const void* sendbuff, void* recvbuff, size_t size, int root) { 56 | NCCL_ASSERT(ncclBroadcast(sendbuff, recvbuff, size, ${NCCL_DATA_TYPE}, root, comm, stream)); 57 | } 58 | """ 59 | kernel_source = substitude(kernel_template, { 60 | "NCCL_DATA_TYPE": NCCL_DTYPE_MAPPING[dtype], 61 | }) 62 | library = gcc.compile(CXXUnit( 63 | source=kernel_source, 64 | libraries=[ 65 | *find_cuda_libraries(), 66 | os.path.join(nvidia.nccl.__path__[0], "lib", "libnccl.so.2"), 67 | ], includes=[ 68 | find_cuda_include_directory(), 69 | os.path.join(nvidia.nccl.__path__[0], "include"), 70 | ] 71 | )) 72 | @import_symbol(library, "nccl_all_gather") 73 | def nccl_all_gather( 74 | stream: ctypes.c_void_p, 75 | comm: ctypes.c_void_p, 76 | sendbuff: ctypes.c_void_p, 77 | recvbuff: ctypes.c_void_p, 78 | size: ctypes.c_size_t, 79 | ): 80 | ... 81 | @import_symbol(library, "nccl_all_reduce") 82 | def nccl_all_reduce( 83 | stream: ctypes.c_void_p, 84 | comm: ctypes.c_void_p, 85 | sendbuff: ctypes.c_void_p, 86 | recvbuff: ctypes.c_void_p, 87 | size: ctypes.c_size_t, 88 | ): 89 | ... 90 | @import_symbol(library, "nccl_broadcast") 91 | def nccl_broadcast( 92 | stream: ctypes.c_void_p, 93 | comm: ctypes.c_void_p, 94 | sendbuff: ctypes.c_void_p, 95 | recvbuff: ctypes.c_void_p, 96 | size: ctypes.c_size_t, 97 | root: ctypes.c_int, 98 | ): 99 | ... 100 | return nccl_all_gather, nccl_all_reduce, nccl_broadcast 101 | -------------------------------------------------------------------------------- /minit/nccl/library.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import functools 3 | import os 4 | import nvtx 5 | import nvidia.nccl 6 | 7 | from ..compiler.template import substitude 8 | from ..core.cache import cached 9 | from ..cuda.toolkit import find_cuda_include_directory, find_cuda_libraries 10 | from ..compiler.cxx import CXXLibrary, CXXUnit, import_symbol 11 | from ..compiler.gcc import gcc 12 | 13 | 14 | @cached() 15 | def _generate_nccl_library(): 16 | kernel_template =\ 17 | """ 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | 24 | #define CUDA_ASSERT(expr) \\ 25 | do { \\ 26 | auto _err = (expr); \\ 27 | if (_err != cudaSuccess) { \\ 28 | throw std::runtime_error(cudaGetErrorString(_err)); \\ 29 | } \\ 30 | } while (0) 31 | 32 | 33 | #define NCCL_ASSERT(expr) \\ 34 | do { \\ 35 | auto _err = (expr); \\ 36 | if (_err != ncclSuccess) { \\ 37 | throw std::runtime_error(ncclGetErrorString(_err)); \\ 38 | } \\ 39 | } while (0) 40 | 41 | extern "C" size_t nccl_unique_id_size() { 42 | return sizeof(ncclUniqueId); 43 | } 44 | 45 | extern "C" void nccl_create_unique_id(char* bytes) { 46 | ncclUniqueId unique_id; 47 | NCCL_ASSERT(ncclGetUniqueId(&unique_id)); 48 | std::memcpy(bytes, &unique_id, sizeof(unique_id)); 49 | } 50 | 51 | extern "C" ncclComm_t nccl_init_rank(int nr_ranks, int rank, char* bytes) { 52 | ncclUniqueId unique_id; 53 | std::memcpy(&unique_id, bytes, sizeof(unique_id)); 54 | CUDA_ASSERT(cudaSetDevice(rank)); 55 | ncclComm_t comm; 56 | NCCL_ASSERT(ncclCommInitRank(&comm, nr_ranks, unique_id, rank)); 57 | return comm; 58 | } 59 | """ 60 | kernel_source = substitude(kernel_template, {}) 61 | library = gcc.compile(CXXUnit( 62 | source=kernel_source, 63 | libraries=[ 64 | *find_cuda_libraries(), 65 | os.path.join(nvidia.nccl.__path__[0], "lib", "libnccl.so.2"), 66 | ], includes=[ 67 | find_cuda_include_directory(), 68 | os.path.join(nvidia.nccl.__path__[0], "include"), 69 | ] 70 | )) 71 | return library 72 | 73 | 74 | _library = _generate_nccl_library() 75 | 76 | @import_symbol(_library, "nccl_unique_id_size") 77 | def nccl_unique_id_size() -> ctypes.c_size_t: 78 | ... 79 | 80 | @import_symbol(_library, "nccl_create_unique_id") 81 | def nccl_create_unique_id( 82 | bytes: ctypes.c_void_p, 83 | ): 84 | ... 85 | 86 | @import_symbol(_library, "nccl_init_rank") 87 | def nccl_init_rank( 88 | nr_ranks: ctypes.c_int, 89 | rank: ctypes.c_int, 90 | bytes: ctypes.c_void_p, 91 | ) -> ctypes.c_void_p: 92 | ... 93 | 94 | def launch_server() -> bytes: 95 | size = nccl_unique_id_size() 96 | id = bytearray(size) 97 | nccl_create_unique_id((ctypes.c_char * size).from_buffer(id)) 98 | return bytes(id) 99 | 100 | def connect_server(nr_ranks: int, rank: int, id: bytes) -> int: 101 | return nccl_init_rank(nr_ranks, rank, id) 102 | -------------------------------------------------------------------------------- /minit/nccl/server.py: -------------------------------------------------------------------------------- 1 | from .library import launch_server 2 | import base64 3 | 4 | def main(): 5 | id = launch_server() 6 | print(f"server launched at: {id}") 7 | base64_id = base64.b64encode(id) 8 | with open(".sync", "wb") as f: 9 | f.write(id) 10 | print(f"base64: {base64_id}") 11 | print("Press enter to exit") 12 | input() 13 | 14 | 15 | if __name__ == '__main__': 16 | main() -------------------------------------------------------------------------------- /minit/nccl/tensor.py: -------------------------------------------------------------------------------- 1 | from ..core.tensor import Tensor 2 | from .library import connect_server 3 | 4 | 5 | class NCCLTensor(Tensor): 6 | comm: int 7 | 8 | def __init__(self, comm: int) -> None: 9 | super().__init__() 10 | self.comm = comm 11 | 12 | @property 13 | def shape(self): 14 | return () 15 | 16 | @property 17 | def dtype(self): 18 | return "int64" 19 | 20 | @property 21 | def device(self): 22 | return "nccl" 23 | 24 | @staticmethod 25 | def connect(unique_id: bytes, rank: int, size: int) -> "NCCLTensor": 26 | comm = connect_server(size, rank, unique_id) 27 | return NCCLTensor(comm) 28 | -------------------------------------------------------------------------------- /minit/operator/__init__.py: -------------------------------------------------------------------------------- 1 | from . import arith 2 | from . import control_flow 3 | from . import generate 4 | from . import index 5 | from . import linalg 6 | from . import memory 7 | from . import random 8 | from . import reduce 9 | from . import shape 10 | from . import special 11 | -------------------------------------------------------------------------------- /minit/operator/arith.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import math 3 | from typing import Any, Union 4 | 5 | from ..core.tensor import Tensor 6 | from ..core.scalar import ScalarTensor 7 | from ..core.meta import MetaTensor 8 | from ..core.dispatch import register_dispatch 9 | from ..core.operator import Operator 10 | 11 | 12 | @dataclass 13 | class Add(Operator): 14 | ... 15 | 16 | @dataclass 17 | class Subtract(Operator): 18 | ... 19 | 20 | @dataclass 21 | class Multiply(Operator): 22 | ... 23 | 24 | @dataclass 25 | class Divide(Operator): 26 | ... 27 | 28 | @dataclass 29 | class FloorDivide(Operator): 30 | ... 31 | 32 | @dataclass 33 | class Modulo(Operator): 34 | ... 35 | 36 | @dataclass 37 | class Power(Operator): 38 | ... 39 | 40 | @dataclass 41 | class Exponential(Operator): 42 | ... 43 | 44 | @dataclass 45 | class Logarithm(Operator): 46 | ... 47 | 48 | @dataclass 49 | class Cosine(Operator): 50 | ... 51 | 52 | @dataclass 53 | class Sine(Operator): 54 | ... 55 | 56 | @dataclass 57 | class Constant(Operator): 58 | value: Any 59 | dtype: str 60 | 61 | @dataclass 62 | class Cast(Operator): 63 | dtype: str 64 | 65 | @dataclass 66 | class GreaterThan(Operator): 67 | ... 68 | 69 | @dataclass 70 | class Equal(Operator): 71 | ... 72 | 73 | @dataclass 74 | class And(Operator): 75 | ... 76 | 77 | @dataclass 78 | class Not(Operator): 79 | ... 80 | 81 | @dataclass 82 | class SelectMax(Operator): 83 | ... 84 | 85 | @dataclass 86 | class Select(Operator): 87 | ... 88 | 89 | @register_dispatch() 90 | def register_constant(op: Constant, *sizes: Tensor): 91 | assert not isinstance(op.value, Tensor), "cannot use tensor as constant" 92 | c = ScalarTensor(op.value, tuple(sizes), op.dtype) 93 | return (c,) 94 | 95 | def register_elemwise_operator(op_type, op_py, dtype=None): 96 | @register_dispatch() 97 | def _register_elemwise_scalar(op: op_type, *args: ScalarTensor): # type: ignore 98 | items = [arg.item() for arg in args] 99 | c_item = op_py(*items) 100 | output_dtype = args[0].dtype if dtype is None else dtype(op, *(arg.dtype for arg in args)) 101 | c = ScalarTensor(c_item, args[0].shape, output_dtype) 102 | return (c,) 103 | 104 | @register_dispatch(priority=-10, predicate=lambda *tys: any(ty is MetaTensor for ty in tys)) 105 | def _register_elemwise_meta(op: op_type, *args: Tensor): # type: ignore 106 | output_dtype = args[0].dtype if dtype is None else dtype(op, *(arg.dtype for arg in args)) 107 | c = MetaTensor(args[0].shape, output_dtype) 108 | return (c,) 109 | 110 | 111 | def same_dtypes(op: Operator, *args: str) -> str: 112 | for arg in args[1:]: 113 | assert arg == args[0] 114 | return args[0] 115 | 116 | 117 | def same_dtypes_return_bool(op: Operator, *args: str) -> str: 118 | for arg in args[1:]: 119 | assert arg == args[0] 120 | return "bool" 121 | 122 | 123 | def same_dtypes_except_first(op: Operator, *args: str) -> str: 124 | for arg in args[2:]: 125 | assert arg == args[1] 126 | return args[1] 127 | 128 | 129 | def dtype_from_operator(op: Operator, *args: Tensor) -> str: 130 | (arg,) = args 131 | return op.dtype 132 | 133 | 134 | def register_elemwise_operators(): 135 | for op_type, op_py, dtype in [ 136 | (Add, lambda x, y: x + y, same_dtypes), 137 | (Subtract, lambda x, y: x - y, same_dtypes), 138 | (Multiply, lambda x, y: x * y, same_dtypes), 139 | (Divide, lambda x, y: x / y, same_dtypes), 140 | (FloorDivide, lambda x, y: x // y, same_dtypes), 141 | (Modulo, lambda x, y: x % y, same_dtypes), 142 | (Power, lambda x, y: pow(x, y), same_dtypes), 143 | (SelectMax, lambda x, y: max(x, y), same_dtypes), 144 | (Select, lambda condition, *args: args[condition], same_dtypes_except_first), 145 | (Sine, lambda x: math.sin(x), same_dtypes), 146 | (Cosine, lambda x: math.cos(x), same_dtypes), 147 | (Exponential, lambda x: math.exp(x), same_dtypes), 148 | (Logarithm, lambda x: math.log(x), same_dtypes), 149 | (GreaterThan, lambda x, y: x > y, same_dtypes_return_bool), 150 | (Equal, lambda x, y: x == y, same_dtypes_return_bool), 151 | (And, lambda x, y: bool(x) and bool(y), same_dtypes_return_bool), 152 | (Not, lambda x: not x, same_dtypes_return_bool), 153 | (Cast, lambda x: x, dtype_from_operator), 154 | ]: 155 | register_elemwise_operator(op_type, op_py, dtype) 156 | 157 | 158 | register_elemwise_operators() 159 | -------------------------------------------------------------------------------- /minit/operator/control_flow.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Sequence 3 | 4 | from ..graph import SubGraph 5 | from ..core.operator import Operator 6 | from ..core.tensor import Tensor 7 | 8 | 9 | class Block: 10 | def __call__(self, *args: Tensor) -> Sequence[Tensor]: 11 | ... 12 | 13 | 14 | @dataclass 15 | class ForLoop(Operator): 16 | body: SubGraph 17 | 18 | 19 | @dataclass 20 | class WhileLoop(Operator): 21 | body: SubGraph 22 | 23 | 24 | @dataclass 25 | class IfBlock(Operator): 26 | true_body: SubGraph 27 | false_body: SubGraph 28 | -------------------------------------------------------------------------------- /minit/operator/generate.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from numbers import Number 3 | from typing import Literal 4 | 5 | from ..core.tensor import Tensor 6 | from ..core.device_operator import DeviceOperator 7 | from ..core.dispatch import register_dispatch 8 | from ..core.meta import MetaTensor 9 | from ..core.scalar import ScalarTensor 10 | from ..core.operator import Operator 11 | 12 | 13 | @dataclass 14 | class GenerateInterval(Operator): 15 | pass 16 | 17 | 18 | @dataclass 19 | class GenerateSequence(Operator): 20 | pass 21 | 22 | 23 | @dataclass 24 | class Fill(Operator): 25 | pass 26 | 27 | 28 | @register_dispatch(priority=0) 29 | def dispatch_generate_interval(op: DeviceOperator[GenerateInterval, Literal["meta"]], start: ScalarTensor, stop: ScalarTensor, step: ScalarTensor): 30 | size = stop.value() - start.value() / step.value() 31 | z = MetaTensor((ScalarTensor(size, (), "int32"),), start.dtype) 32 | return (z,) 33 | 34 | @register_dispatch() 35 | def register_fill(op: DeviceOperator[Fill, Literal["meta"]], value: Tensor, *sizes: Tensor): 36 | z = MetaTensor(sizes, value.dtype) 37 | return (z,) 38 | 39 | @register_dispatch() 40 | def register_fill(op: Fill, value: MetaTensor, *sizes: Tensor): 41 | z = MetaTensor(sizes, value.dtype) 42 | return (z,) 43 | 44 | @register_dispatch() 45 | def register_generate_sequence(op: DeviceOperator[GenerateSequence, Literal["meta"]], start: Tensor, size: Tensor, stop: Tensor): 46 | z = MetaTensor((size,), start.dtype) 47 | return (z,) 48 | 49 | @register_dispatch() 50 | def register_generate_sequence(op: GenerateSequence, start: MetaTensor, size: Tensor, stop: Tensor): 51 | z = MetaTensor((size,), start.dtype) 52 | return (z,) 53 | -------------------------------------------------------------------------------- /minit/operator/index.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from ..core.meta import MetaTensor 4 | from ..core.dispatch import register_dispatch 5 | from ..core.operator import Operator 6 | 7 | 8 | @dataclass 9 | class Slice(Operator): 10 | axis: int 11 | 12 | class SliceSet(Operator): 13 | ... 14 | 15 | @dataclass 16 | class Index(Operator): 17 | axis: int 18 | 19 | class IndexSet(Operator): 20 | ... 21 | 22 | @dataclass 23 | class Split(Operator): 24 | axis: int 25 | 26 | @dataclass 27 | class Tie(Operator): 28 | axis: int 29 | 30 | @register_dispatch() 31 | def register_tie(op: Tie, *args: MetaTensor): 32 | for arg in args[1:]: 33 | assert arg.dtype == args[0].dtype 34 | assert len(arg.shape) == len(args[0].shape) 35 | b = args[0].shape[op.axis] 36 | for arg in args[1:]: 37 | b = b + arg.shape[op.axis] 38 | z = MetaTensor(args[0]._shape[:op.axis] + (b,) + args[0].shape[op.axis+1:], args[0].dtype) 39 | return (z,) 40 | -------------------------------------------------------------------------------- /minit/operator/linalg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from ..core.tensor import Tensor 4 | from ..core.meta import MetaTensor 5 | from ..core.dispatch import register_dispatch 6 | from ..core.operator import Operator 7 | 8 | 9 | @dataclass 10 | class MatrixMultiply(Operator): 11 | ... 12 | 13 | 14 | @dataclass 15 | class BatchMatrixMultiply(Operator): 16 | ... 17 | 18 | 19 | @dataclass 20 | class TriangleUpper(Operator): 21 | pass 22 | 23 | 24 | @dataclass 25 | class TriangleLower(Operator): 26 | diagonal: int 27 | 28 | 29 | @register_dispatch() 30 | def dispatch_matrix_multiply(op: MatrixMultiply, x: MetaTensor, y: MetaTensor): 31 | assert x.dtype == y.dtype 32 | m, k0 = x.shape 33 | n, k1 = y.shape 34 | z = MetaTensor((m, n), x.dtype) 35 | return (z,) 36 | 37 | 38 | @register_dispatch() 39 | def dispatch_triu(op: TriangleUpper, x: MetaTensor, y: Tensor): 40 | return (x,) 41 | -------------------------------------------------------------------------------- /minit/operator/memory.py: -------------------------------------------------------------------------------- 1 | from ..core.operator import Operator 2 | 3 | 4 | class Assign(Operator): 5 | ... 6 | 7 | class Copy(Operator): 8 | ... 9 | -------------------------------------------------------------------------------- /minit/operator/random.py: -------------------------------------------------------------------------------- 1 | from ..core.operator import Operator 2 | 3 | 4 | class Uniform(Operator): 5 | ... 6 | 7 | class Normal(Operator): 8 | ... 9 | -------------------------------------------------------------------------------- /minit/operator/reduce.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from ..core.operator import Operator 4 | 5 | 6 | @dataclass 7 | class Sum(Operator): 8 | axis: int 9 | 10 | 11 | @dataclass 12 | class Max(Operator): 13 | axis: int 14 | -------------------------------------------------------------------------------- /minit/operator/shape.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from ..core.meta import MetaTensor 4 | from ..core.tensor import Tensor 5 | from ..core.dispatch import register_dispatch 6 | from ..core.operator import Operator 7 | 8 | 9 | @dataclass 10 | class Fold(Operator): 11 | start: int 12 | stop: int 13 | 14 | @dataclass 15 | class Expand(Operator): 16 | axis: int 17 | 18 | @dataclass 19 | class AddAxis(Operator): 20 | axis: int 21 | 22 | @dataclass 23 | class RemoveAxis(Operator): 24 | axis: int 25 | 26 | @dataclass 27 | class Broadcast(Operator): 28 | axis: int 29 | 30 | @dataclass 31 | class Transpose(Operator): 32 | axis_a: int 33 | axis_b: int 34 | 35 | @dataclass 36 | class Reinterpret(Operator): 37 | target: str 38 | 39 | @register_dispatch(priority=-1) 40 | def dispatch_add_axis(op: AddAxis, x: MetaTensor): 41 | from ..functional.arith import constant 42 | z = MetaTensor(x.shape[:op.axis] + (constant(1, "int32"),) + x.shape[op.axis:], x.dtype) 43 | return (z,) 44 | 45 | @register_dispatch(priority=-1) 46 | def dispatch_add_axis(op: RemoveAxis, x: MetaTensor): 47 | z = MetaTensor(x.shape[:op.axis] + x.shape[op.axis+1:], x.dtype) 48 | return (z,) 49 | 50 | @register_dispatch(priority=-1) 51 | def dispatch_add_axis(op: Transpose, x: MetaTensor): 52 | shape = x.shape 53 | z = MetaTensor(tuple([*shape[:op.axis_a], shape[op.axis_b], *shape[op.axis_a+1:op.axis_b], shape[op.axis_a], *shape[op.axis_b+1:]]), x.dtype) 54 | return (z,) 55 | 56 | @register_dispatch(priority=-1) 57 | def dispatch_broadcast(op: Broadcast, x: MetaTensor, size: Tensor): 58 | z = MetaTensor(x.shape[:op.axis] + (size,) + x.shape[op.axis+1:], x.dtype) 59 | return (z,) 60 | 61 | @register_dispatch(priority=-1) 62 | def dispatch_expand(op: Expand, x: Tensor, *sizes: Tensor): 63 | z = MetaTensor(sizes, x.dtype) 64 | return (z,) 65 | 66 | @register_dispatch(priority=-1) 67 | def dispatch_fold(op: Fold, x: MetaTensor): 68 | from ..functional.arith import constant 69 | size = constant(1, "int32") 70 | shape = x.shape 71 | for dim in shape[op.start:op.stop]: 72 | size = size * dim 73 | z = MetaTensor(shape[:op.start] + (size,) + shape[op.stop:], x.dtype) 74 | return (z,) 75 | -------------------------------------------------------------------------------- /minit/operator/special.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from ..core.operator import Operator 4 | 5 | 6 | class Sigmoid(Operator): 7 | ... 8 | 9 | 10 | @dataclass 11 | class RMSNorm(Operator): 12 | axis: int 13 | eps: float 14 | 15 | 16 | @dataclass 17 | class RoPE(Operator): 18 | ... 19 | 20 | 21 | @dataclass 22 | class Softmax(Operator): 23 | axis: int 24 | -------------------------------------------------------------------------------- /minit/quantize/dispatch.py: -------------------------------------------------------------------------------- 1 | from .operator import Dequantize 2 | from ..core.tensor import Tensor 3 | from .tensor import QuantizedTensor 4 | from ..operator.linalg import MatrixMultiply 5 | from ..core.dispatch import dispatch, register_dispatch 6 | 7 | 8 | @register_dispatch() 9 | def dispatch_quantized_matrix_multiply(op: MatrixMultiply, a: Tensor, b: QuantizedTensor): 10 | b = b.dequantize() 11 | return dispatch(op, a, b) 12 | 13 | 14 | @register_dispatch() 15 | def dispatch_dequantize(op: Dequantize, data: Tensor, group: Tensor, zero: Tensor, scale: Tensor): 16 | scale = scale.index(group, axis=op.axis) 17 | zero = zero.index(group, axis=op.axis) 18 | dtype = scale.dtype 19 | output = (data.cast(dtype) - (zero.cast(dtype)+1)) * scale 20 | return (output,) 21 | -------------------------------------------------------------------------------- /minit/quantize/functional.py: -------------------------------------------------------------------------------- 1 | from ..core.dispatch import dispatch 2 | from .operator import Dequantize 3 | from ..core.tensor import Tensor 4 | 5 | 6 | def dequantize(x: Tensor, group: Tensor, zero: Tensor, scale: Tensor, axis: int) -> Tensor: 7 | (z,) = dispatch(Dequantize(axis), x, group, zero, scale) 8 | return z 9 | -------------------------------------------------------------------------------- /minit/quantize/operator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ..core.operator import Operator 3 | 4 | 5 | @dataclass 6 | class Dequantize(Operator): 7 | axis: int 8 | -------------------------------------------------------------------------------- /minit/quantize/tensor.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar 2 | from .functional import dequantize 3 | from ..core.tensor import Tensor 4 | 5 | 6 | _Data = TypeVar("_Data", bound=Tensor) 7 | 8 | 9 | class QuantizedTensor(Tensor, Generic[_Data]): 10 | def __init__(self, data: _Data, group: Tensor, zero: Tensor, scale: Tensor) -> None: 11 | super().__init__() 12 | self._data = data 13 | self._group = group 14 | self._zero = zero 15 | self._scale = scale 16 | 17 | @property 18 | def shape(self): 19 | return self._data.shape 20 | 21 | @property 22 | def dtype(self): 23 | return f"q{self._data.dtype}" 24 | 25 | @property 26 | def device(self): 27 | return self._data.device 28 | 29 | def dequantize(self) -> Tensor: 30 | return dequantize(self._data, self._group, self._zero, self._scale, 0) 31 | -------------------------------------------------------------------------------- /minit/remote/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GetUpEarlier/minit/48d227e638c0316cf998295f7638f909fbc1b9f6/minit/remote/__init__.py -------------------------------------------------------------------------------- /minit/remote/actor.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import argparse 3 | import importlib 4 | from typing import Any, List, Union 5 | from multiprocessing import Process 6 | 7 | from .channel import Channel 8 | from .registry import create_object, get_function, get_object 9 | from .object import ObjectRef 10 | 11 | class Actor: 12 | id: int 13 | 14 | def connect(self, address: str, port: int): 15 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 16 | self.socket.connect((address, port)) 17 | self.channel = Channel(self.socket) 18 | self.id = self.channel.recv().get() 19 | print(f"serving as actor {self.id}") 20 | return self 21 | 22 | def act_forever(self): 23 | while True: 24 | [fn, *args] = self.channel.recv().get() 25 | print(f"acting {fn} at {self.id}") 26 | fn: str 27 | args: List[Union[ObjectRef, Any]] 28 | for arg in args: 29 | if isinstance(arg, ObjectRef): 30 | assert arg.location == self.id, f"{arg.location} vs {self.id}" 31 | local_args = [ 32 | get_object(arg.id) if isinstance(arg, ObjectRef) else arg for arg in args 33 | ] 34 | local_fn, is_constructor = get_function(fn) 35 | local_result = local_fn(*local_args) 36 | if is_constructor: 37 | result = ObjectRef(self.id, create_object(local_result)) 38 | else: 39 | result = local_result 40 | self.channel.send(result) 41 | 42 | 43 | def act_forever(address: str, port: int, module: str): 44 | importlib.__import__(module) 45 | actor = Actor() 46 | actor.connect(address, port) 47 | actor.act_forever() 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--host", default="localhost", type=str) 53 | parser.add_argument("--port", "-p", type=int) 54 | parser.add_argument("--module", "-m", type=str) 55 | parser.add_argument("--size", "-n", type=int) 56 | args = parser.parse_args() 57 | processes: List[Process] = [] 58 | try: 59 | for _i in range(args.size): 60 | process = Process(target=act_forever, args=(args.host, args.port, args.module)) 61 | processes.append(process) 62 | process.start() 63 | while True: 64 | for process in processes: 65 | process.join(timeout=0.5) 66 | if all(process.exitcode is not None for process in processes): 67 | break 68 | finally: 69 | for process in processes: 70 | process.kill() 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /minit/remote/channel.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from typing import Any 3 | from .utils import send_object, recv_object 4 | 5 | 6 | class Channel: 7 | class Future: 8 | _value = None 9 | 10 | def __init__(self, socket: "socket.socket") -> None: 11 | self._socket = socket 12 | 13 | def fetch(self): 14 | if self._value is None: 15 | self._value = recv_object(self._socket) 16 | assert self._value is not None 17 | 18 | def get(self): 19 | self.fetch() 20 | return self._value 21 | 22 | def __init__(self, socket: "socket.socket") -> None: 23 | self._socket = socket 24 | self._last_future = None 25 | 26 | def send(self, object: Any) -> None: 27 | if self._last_future is not None: 28 | self._last_future.fetch() 29 | send_object(self._socket, object) 30 | 31 | def recv(self) -> Future: 32 | if self._last_future is not None: 33 | self._last_future.fetch() 34 | self._last_future = Channel.Future(self._socket) 35 | return self._last_future 36 | -------------------------------------------------------------------------------- /minit/remote/controller.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from types import FunctionType 3 | from typing import Any, List, Union 4 | 5 | from .channel import Channel 6 | from .object import ObjectRef 7 | 8 | class Controller: 9 | peers: List[Channel] 10 | 11 | def __init__(self, address: str, port: int) -> None: 12 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 13 | self.address = address 14 | self.port = port 15 | self.socket.bind((address, port)) 16 | self.socket.listen(64) 17 | self.port = self.socket.getsockname()[1] 18 | self.peers = [] 19 | 20 | def wait_for_connect(self, nr_actors: int): 21 | peers = [] 22 | while len(peers) < nr_actors: 23 | connection, _address = self.socket.accept() 24 | id = len(peers) + len(self.peers) 25 | channel = Channel(connection) 26 | channel.send(id) 27 | peers.append(channel) 28 | self.peers += peers 29 | 30 | def invoke_function(self, actor: int, function: FunctionType, *args: Union[ObjectRef, Any]): 31 | """blocking invoke""" 32 | peer = self.peers[actor] 33 | print(f"invoking {function.__qualname__} at {actor}") 34 | peer.send((function.__qualname__, *args)) 35 | result = peer.recv() 36 | return result 37 | 38 | def __len__(self): 39 | return len(self.peers) 40 | -------------------------------------------------------------------------------- /minit/remote/function.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass(frozen=True) 5 | class FunctionRef: 6 | location: int 7 | name: str 8 | -------------------------------------------------------------------------------- /minit/remote/object.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union 3 | 4 | from .value import Value 5 | 6 | 7 | @dataclass(frozen=True) 8 | class ObjectRef: 9 | location: int 10 | id: int 11 | -------------------------------------------------------------------------------- /minit/remote/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | from types import FunctionType 4 | 5 | 6 | FUNCTIONS = {} 7 | CONSTRUCTORS = {} 8 | OBJECTS = {} 9 | 10 | 11 | def register_function(fn: FunctionType): 12 | FUNCTIONS[fn.__qualname__] = fn 13 | return fn 14 | 15 | 16 | def register_constructor(fn: FunctionType): 17 | FUNCTIONS[fn.__qualname__] = fn 18 | CONSTRUCTORS[fn.__qualname__] = fn 19 | return fn 20 | 21 | 22 | def get_function(name: str) -> Tuple[FunctionType, bool]: 23 | return (FUNCTIONS[name],name in CONSTRUCTORS) 24 | 25 | 26 | def create_object(obj: Any) -> int: 27 | id = len(OBJECTS) 28 | OBJECTS[id] = obj 29 | return id 30 | 31 | 32 | def get_object(id: int): 33 | return OBJECTS[id] 34 | 35 | 36 | def register_method(fn: FunctionType): 37 | return register_function(fn) 38 | 39 | 40 | register_method(list.append) 41 | register_method(list.__getitem__) 42 | register_method(list.__len__) 43 | -------------------------------------------------------------------------------- /minit/remote/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import socket 3 | from typing import Any 4 | 5 | 6 | def recv_bytes(channel: socket.socket, size: int) -> bytes: 7 | received_bytes = bytearray() 8 | while len(received_bytes) < size: 9 | received_bytes += channel.recv(size - len(received_bytes)) 10 | return bytes(received_bytes) 11 | 12 | 13 | def send_object(channel: socket.socket, object: Any): 14 | packet = pickle.dumps(object) 15 | channel.sendall(len(packet).to_bytes(8, "little")) 16 | channel.sendall(packet) 17 | 18 | 19 | def recv_object(channel: socket.socket) -> Any: 20 | size_bytes = recv_bytes(channel, 8) 21 | size = int.from_bytes(size_bytes, "little") 22 | packet = recv_bytes(channel, size) 23 | return pickle.loads(packet) 24 | -------------------------------------------------------------------------------- /minit/remote/value.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | 5 | @dataclass(frozen=True) 6 | class Value: 7 | value: Any 8 | -------------------------------------------------------------------------------- /minit/trace/dispatch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, Union 2 | 3 | from ..core.meta import MetaTensor 4 | from ..lazy.tensor import Expression, LazyTensor 5 | from ..core.tensor import Tensor 6 | from .tensor import TraceTensor 7 | from ..graph import GraphBuilder, Use 8 | from ..core.operator import Operator 9 | from ..core.device_operator import DeviceOperator 10 | from ..core.dispatch import dispatch, register_dispatch 11 | from ..core.object import match_pattern 12 | 13 | 14 | def any_trace_tensor(*args): 15 | for arg in args: 16 | if match_pattern(TraceTensor, arg): 17 | return True 18 | return False 19 | 20 | 21 | def _dispatch_trace(op: Operator, *args: Union[TraceTensor, Tensor]): 22 | arg_values = [] 23 | arg_nodes: List[Use] = [] 24 | builder = None 25 | for arg in args: 26 | if isinstance(arg, TraceTensor): 27 | if builder is None: 28 | builder = arg._builder 29 | else: 30 | assert builder == arg._builder 31 | assert builder is not None 32 | for arg in args: 33 | if isinstance(arg, LazyTensor): 34 | arg = trace_evaluate(builder, arg._expression)[arg._index] 35 | assert not isinstance(arg, LazyTensor) 36 | if isinstance(arg, TraceTensor): 37 | arg_values.append(arg._value) 38 | arg_nodes.append(arg._node) 39 | else: 40 | arg_values.append(arg) 41 | arg_nodes.append(builder.create_constant(arg)) 42 | output_values = dispatch(op, *arg_values) 43 | output_metas = tuple(MetaTensor(output_value.shape, output_value.dtype) for output_value in output_values) 44 | output_uses = builder.create_operator(op, arg_nodes, output_metas) 45 | return tuple(TraceTensor(builder, output_use, output_value) for output_use, output_value in zip(output_uses, output_values)) 46 | 47 | 48 | def trace_evaluate(builder: GraphBuilder, expression: Expression): 49 | args = tuple([trace_evaluate(builder, arg._expression)[arg._index] if isinstance(arg, LazyTensor) else TraceTensor(builder, builder.create_constant(arg), arg) for arg in expression.args]) 50 | return dispatch(expression.op, *args) 51 | 52 | 53 | @register_dispatch(predicate=any_trace_tensor, priority=1) 54 | def dispatch_any(op: Operator, *args: Union[TraceTensor, Tensor]): 55 | return _dispatch_trace(op, *args) 56 | 57 | 58 | @register_dispatch(priority=1) 59 | def register_device_operator(op: DeviceOperator[Operator, Literal["trace"]], *args: Union[TraceTensor, Tensor]): 60 | return _dispatch_trace(op.operator, *args) 61 | -------------------------------------------------------------------------------- /minit/trace/executor.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from typing import Callable, DefaultDict, Dict, List, Optional, Tuple 4 | import nvtx 5 | 6 | from ..core.operator import Operator 7 | 8 | from ..core.dispatch import dispatch 9 | 10 | from ..core.tensor import Tensor 11 | from ..graph import ConstantNode, OperatorNode, SubGraph, TensorNode, Use, ValueNode 12 | 13 | 14 | class ValueAndRefCount: 15 | __slots__ = [ 16 | "value", "ref_count", "init_ref_count", "require_shape", "shape", 17 | ] 18 | 19 | value: Optional[Tensor] 20 | ref_count: int 21 | init_ref_count: int 22 | require_shape: bool 23 | shape: Optional[Tuple[Tensor, ...]] 24 | 25 | def __init__(self, init_ref_count: int, require_shape: bool) -> None: 26 | self.value = None 27 | self.ref_count = 0 28 | self.init_ref_count = init_ref_count 29 | self.require_shape = require_shape 30 | self.shape = None 31 | 32 | def consume(self) -> Tensor: 33 | value = self.value 34 | assert value is not None 35 | self.ref_count -= 1 36 | assert self.ref_count >= 0 37 | if self.ref_count == 0: 38 | self.value = None 39 | return value 40 | 41 | def produce(self, value: Tensor): 42 | self.ref_count = self.init_ref_count 43 | if self.init_ref_count != 0: 44 | self.value = value 45 | if self.require_shape: 46 | self.shape = value.shape 47 | 48 | 49 | class TraceGraphExecutor: 50 | graph: SubGraph 51 | value_and_refs: Dict[ValueNode, ValueAndRefCount] 52 | sequence: List[Tuple[List[Callable[[], Tensor]], Operator, List[Callable[[Tensor], None]]]] 53 | 54 | def make_consumer(self, node_use: Use) -> Callable[[], Tensor]: 55 | if node_use.axis is not None: 56 | node = node_use.target 57 | axis = node_use.axis 58 | source = self.value_and_refs[node] 59 | return lambda: source.shape[axis] 60 | else: 61 | if isinstance(node_use.target, ConstantNode): 62 | constant = node_use.target.value 63 | return lambda: constant 64 | else: 65 | return self.value_and_refs[node_use.target].consume 66 | 67 | def make_producer(self, node: ValueNode) -> Callable[[Tensor], None]: 68 | value_and_ref = self.value_and_refs.get(node, None) 69 | if value_and_ref is not None: 70 | return value_and_ref.produce 71 | else: 72 | return lambda x: None 73 | 74 | def __init__(self, graph: SubGraph) -> None: 75 | self.graph = graph 76 | counts: DefaultDict[ValueNode, int] = defaultdict(lambda: 0) 77 | shapes: Dict[ValueNode, Tuple[ValueNode, ...]] = {} 78 | 79 | def record(node_use: Use): 80 | assert isinstance(node_use, Use) 81 | node = node_use() 82 | counts[node] += 1 83 | if node_use.axis is not None: 84 | shapes[node] = () 85 | 86 | for operator in graph.operators: 87 | for arg in operator.args: 88 | record(arg) 89 | for output in graph.outputs: 90 | record(output) 91 | self.value_and_refs: Dict[ValueNode, ValueAndRefCount] = { 92 | node: ValueAndRefCount(ref_count, (node in shapes)) for node, ref_count in counts.items() 93 | } 94 | self.sequence = [] 95 | for operator in graph.operators: 96 | self.sequence.append(( 97 | [self.make_consumer(arg) for arg in operator.args], 98 | operator.operator, 99 | [self.make_producer(node) for node in operator.outputs], 100 | )) 101 | 102 | def execute(self, *args: Tensor): 103 | assert len(args) == len(self.graph.inputs) 104 | 105 | for node_use, value in zip(self.graph.inputs, args): 106 | assert node_use.axis is None 107 | self.make_producer(node_use.target)(value) 108 | 109 | for operator_inputs, operator, operator_outputs in self.sequence: 110 | operator_input_values = [operator_input() for operator_input in operator_inputs] 111 | operator_output_values = dispatch(operator, *operator_input_values) 112 | for operator_output, operator_output_value in zip(operator_outputs, operator_output_values): 113 | operator_output(operator_output_value) 114 | 115 | outputs = tuple([self.make_consumer(output)() for output in self.graph.outputs]) 116 | return outputs 117 | 118 | def __call__(self, *args: Tensor): 119 | return self.execute(*args) 120 | -------------------------------------------------------------------------------- /minit/trace/function.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, Sequence, Tuple 2 | 3 | from ..core.meta import MetaTensor 4 | from ..graph import GraphBuilder, SubGraph, Graph, Use 5 | from .tensor import TraceTensor 6 | from ..core.tensor import Tensor 7 | 8 | 9 | class TraceableFunction(Protocol): 10 | def __call__(self, *args: Tensor) -> Optional[Tuple[Tensor, ...]]: 11 | ... 12 | 13 | 14 | def trace_function(func: TraceableFunction, args: Sequence[Tensor]) -> SubGraph: 15 | graph = Graph() 16 | input_nodes = tuple(graph.create_input(MetaTensor(arg.shape, arg.dtype)) for arg in args) 17 | builder = GraphBuilder(graph, input_nodes, graph.operators.view()) 18 | output_nodes = trace_function_on_graph(func, args, builder, [input_node.use_value(None) for input_node in input_nodes]) 19 | if output_nodes is None: 20 | output_nodes = () 21 | return SubGraph(graph, tuple(input_node.use_value(None) for input_node in input_nodes), graph.operators.view(), tuple(output_nodes)) 22 | 23 | 24 | def trace_function_on_graph(func: TraceableFunction, args: Tuple[Tensor, ...], builder: GraphBuilder, uses: Sequence[Use]) -> Tuple[Use, ...]: 25 | inputs = tuple(TraceTensor(builder, use, arg) for i, (arg, use) in enumerate(zip(args, uses))) 26 | outputs = func(*inputs) 27 | if outputs is None: 28 | outputs = () 29 | output_nodes = [] 30 | for output in outputs: 31 | if not isinstance(output, TraceTensor): 32 | output = TraceTensor(builder, builder.create_constant(output), output) 33 | output_nodes.append(output._node) 34 | return builder.build(*output_nodes) 35 | -------------------------------------------------------------------------------- /minit/trace/tensor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from ..core.tensor import Tensor 3 | 4 | 5 | from ..graph import GraphBuilder, Use, ShapeUse, ValueUse 6 | 7 | class TraceTensor(Tensor): 8 | _value: Tensor 9 | _builder: GraphBuilder 10 | _node: Use 11 | 12 | def __init__(self, builder: GraphBuilder, node: Use, value: Tensor) -> None: 13 | super().__init__() 14 | self._builder = builder 15 | self._node = node 16 | self._value = value 17 | assert isinstance(node, (ShapeUse, ValueUse)) 18 | 19 | @property 20 | def shape(self) -> Tuple[Tensor, ...]: 21 | return tuple(TraceTensor(self._builder, self._node.shape(i), dim) for i, dim in enumerate(self._value.shape)) 22 | 23 | @property 24 | def dtype(self) -> str: 25 | return self._value.dtype 26 | 27 | @property 28 | def device(self): 29 | return "trace" 30 | --------------------------------------------------------------------------------