├── sst
│   ├── __init__.py
│   └── configs
│       ├── __init__.py
│       └── defaults.py
├── requirements.txt
├── imgs
│   └── sstvos-architecture.png
├── gridattention
│   ├── __init__.py
│   ├── setup.py
│   ├── src
│   │   ├── gridattn.h
│   │   ├── lib_cffi.cpp
│   │   └── gridattn.cu
│   └── functions.py
├── README.md
├── setup.py
├── tools
│   └── test_gridattention.py
└── .gitignore

--------------------------------------------------------------------------------
/sst/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
click
yacs

--------------------------------------------------------------------------------
/sst/configs/__init__.py:
--------------------------------------------------------------------------------
from .defaults import _C as cfg

--------------------------------------------------------------------------------
/imgs/sstvos-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dukebw/SSTVOS/HEAD/imgs/sstvos-architecture.png

--------------------------------------------------------------------------------
/gridattention/__init__.py:
--------------------------------------------------------------------------------
from .functions import (
    GridAttention,
    GridAttnMapFunc,
    GridAttnWeightFunc,
    GridAttentionMap,
    GridAttentionWeight,
    gridattn_weight,
    gridattn_map,
    init_gridattn,
)

--------------------------------------------------------------------------------
/gridattention/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="gridattention",
    ext_modules=[
        CUDAExtension(
            "gridattention",
            ["src/lib_cffi.cpp", "src/gridattn.cu"],
            extra_compile_args=["-std=c++11"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
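
Besides the ahead-of-time build above, the package JIT-compiles the same sources at runtime through `gridattention.init_gridattn` (defined in `functions.py` below). A minimal sketch of that JIT path, assuming a CUDA-enabled PyTorch install with `nvcc` on the PATH and paths relative to the repository root:

```python
# Minimal sketch: JIT-compile the extension the same way init_gridattn does.
import os

from torch.utils import cpp_extension

src = os.path.join("gridattention", "src")
gridattn = cpp_extension.load(
    name="gridattn",
    sources=[os.path.join(src, f) for f in ("lib_cffi.cpp", "gridattn.cu")],
    extra_cflags=["-O3"],
    extra_cuda_cflags=["--expt-extended-lambda", "-O3", "--use_fast_math"],
    verbose=True,
)
```
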
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SSTVOS

Coming soon!

ArXiv: https://arxiv.org/abs/2101.08833


![SSTVOS](imgs/sstvos-architecture.png "SSTVOS")


## Citation

```
@inproceedings{duke2021sstvos,
  title={SSTVOS: Sparse Spatiotemporal Transformers for Video Object Segmentation},
  author={Brendan Duke and Abdalla Ahmed and Christian Wolf and Parham Aarabi and Graham W. Taylor},
  booktitle={CVPR},
  year={2021}
}
```

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""SSTVOS setup.py"""
import setuptools

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name="sstvos",
    version="0.0.1",
    author="Brendan Duke",
    author_email="brendanw.duke@gmail.com",
    description="SSTVOS",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/dukebw/sstvos",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.8",
)

--------------------------------------------------------------------------------
/gridattention/src/gridattn.h:
--------------------------------------------------------------------------------
#ifndef __GRIDATTN__
#define __GRIDATTN__

#include "cuda_runtime.h"

/*
 * Exported functions
 */
extern "C" int
_gridattn_map_forward_cuda(int N, int C, int T, int H, int W,
                           const float *attn, const float *val, float *out,
                           cudaStream_t stream);

extern "C" int
_gridattn_map_backward_cuda(int N, int C, int T, int H, int W,
                            const float *dout, const float *attn,
                            const float *val, float *dattn, float *dval,
                            cudaStream_t stream);

extern "C" int
_gridattn_forward_cuda(int N, int C, int T, int H, int W,
                       const float *query_data, const float *key_data,
                       float *attnscores_data, cudaStream_t stream);

extern "C" int
_gridattn_backward_cuda(int N, int C, int T, int H, int W,
                        const float *dattn_data, const float *query_data,
                        const float *key_data, float *dquery_data,
                        float *dkey_data, cudaStream_t stream);

#endif /* __GRIDATTN__ */
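
The `W + H + T - 2` attention axis that recurs in the signatures above is the sparse footprint of grid attention: each query location (x, y, t) scores its entire row (W positions, itself included), its column excluding itself (H - 1 positions), and its temporal fiber excluding itself (T - 1 frames). A worked check with the sizes used by `_test_grad` in the unit test below:

```python
# Footprint of one query location (x, y, t): full row, column minus self,
# temporal fiber minus self.
T, H, W = 5, 6, 6
row, column, fiber = W, H - 1, T - 1
assert row + column + fiber == W + H + T - 2 == 15
```
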
--------------------------------------------------------------------------------
/tools/test_gridattention.py:
--------------------------------------------------------------------------------
"""Unit tests for grid attention"""
import argparse

import torch
import torch.nn.functional as F

from gridattention import gridattn_weight, gridattn_map, init_gridattn
from sst.configs import cfg


def _testgrad_gridattnfunc(gridattnfunc, variables):
    # NOTE: the CUDA kernels compute in float32 only, so gradcheck is run with
    # loose tolerances even though the inputs are promoted to double here.
    for i, var in enumerate(variables):
        if var.dtype == torch.float32:
            var = var.double().cuda()
            var.requires_grad = True
            variables[i] = var
        elif var.dtype == torch.int32:
            var = var.cuda()
            var.requires_grad = False
            variables[i] = var

    if torch.autograd.gradcheck(gridattnfunc, variables, eps=1e-4, atol=1e-2):
        print("Ok")
    else:
        print("Not ok")


def _test_grad():
    N = 1
    C = 4
    T = 5
    H = 6
    W = 6
    query = torch.randn((N, C, T, H, W))
    key = torch.randn((N, C, T, H, W))

    variables = [query, key]
    _testgrad_gridattnfunc(gridattn_weight, variables)

    attention = torch.randn((N, W + H + T - 2, T, H, W))
    value = torch.randn((N, C, T, H, W))

    variables = [attention, value]
    _testgrad_gridattnfunc(gridattn_map, variables)


def test_gridattention():
    parser = argparse.ArgumentParser(description="Test GridAttention")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    init_gridattn(cfg)

    N = 2
    C = 8
    H = 10
    W = 10
    T = 6

    query = torch.zeros(N, C, T, H, W).cuda() + 1.1
    key = torch.zeros(N, C, T, H, W).cuda() + 2.0
    val = torch.zeros(N, C, T, H, W).cuda() + 3.0

    attnscores = gridattn_weight(query, key)
    attn = F.softmax(attnscores, dim=1)
    out = gridattn_map(attn, val)
    assert out.shape == val.shape

    _test_grad()


if __name__ == "__main__":
    test_gridattention()
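
For cross-checking the CUDA kernels on small inputs, here is a slow pure-PyTorch reference that follows the same indexing convention as the kernels in `gridattn.cu` (z sweeps the row first, then the column excluding the query's own row, then the temporal fiber excluding its own frame). The `ref_*` names are ours, not part of the package; outputs can be compared to the CUDA results with `torch.allclose`:

```python
import torch


def _grid_pos(z, x, y, t, w, h):
    """Map grid index z at query (x, y, t) to the attended position (i, j, k)."""
    if z < w:
        return z, y, t  # same row, the query position itself included
    if z < w + h - 1:
        j = z - w
        return x, (j if j < y else j + 1), t  # same column, row y excluded
    k = z - (w + h - 1)
    return x, y, (k if k < t else k + 1)  # same fiber, frame t excluded


def ref_gridattn_weight(query, key):
    """Reference for gridattn_weight; returns (N, W + H + T - 2, T, H, W)."""
    n, c, t, h, w = query.shape
    scores = query.new_zeros(n, w + h + t - 2, t, h, w)
    for z in range(w + h + t - 2):
        for tt in range(t):
            for y in range(h):
                for x in range(w):
                    i, j, k = _grid_pos(z, x, y, tt, w, h)
                    scores[:, z, tt, y, x] = (
                        query[:, :, tt, y, x] * key[:, :, k, j, i]
                    ).sum(dim=1)
    return scores


def ref_gridattn_map(attn, value):
    """Reference for gridattn_map; aggregates values along the grid."""
    n, c, t, h, w = value.shape
    out = torch.zeros_like(value)
    for z in range(w + h + t - 2):
        for tt in range(t):
            for y in range(h):
                for x in range(w):
                    i, j, k = _grid_pos(z, x, y, tt, w, h)
                    out[:, :, tt, y, x] += (
                        attn[:, z, tt, y, x, None] * value[:, :, k, j, i]
                    )
    return out
```
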
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

--------------------------------------------------------------------------------
/gridattention/src/lib_cffi.cpp:
--------------------------------------------------------------------------------
// All functions assume that input and output tensors are already initialized
// and have the correct dimensions
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include "gridattn.h"

int
gridattn_map_forward_cuda(const at::Tensor& attn,
                          const at::Tensor& val,
                          at::Tensor& out)
{
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    int N = val.size(0);
    int C = val.size(1);
    int T = val.size(2);
    int H = val.size(3);
    int W = val.size(4);

    const float *attn_data = attn.data_ptr<float>();
    const float *val_data = val.data_ptr<float>();
    float *out_data = out.data_ptr<float>();

    return _gridattn_map_forward_cuda(N, C, T, H, W, attn_data, val_data,
                                      out_data, stream);
}

int
gridattn_map_backward_cuda(const at::Tensor& dout,
                           const at::Tensor& attn,
                           const at::Tensor& val,
                           at::Tensor& dattn,
                           at::Tensor& dval)
{
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

    int N = dout.size(0);
    int C = dout.size(1);
    int T = dout.size(2);
    int H = dout.size(3);
    int W = dout.size(4);

    const float *dout_data = dout.data_ptr<float>();
    const float *attn_data = attn.data_ptr<float>();
    const float *val_data = val.data_ptr<float>();
    float *dattn_data = dattn.data_ptr<float>();
    float *dval_data = dval.data_ptr<float>();

    return _gridattn_map_backward_cuda(N, C, T, H, W, dout_data, attn_data,
                                       val_data, dattn_data, dval_data,
                                       stream);
}

int
gridattn_forward_cuda(const at::Tensor& query,
                      const at::Tensor& key,
                      at::Tensor& attnscores)
{
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    int N = query.size(0);
    int C = query.size(1);
    int T = query.size(2);
    int H = query.size(3);
    int W = query.size(4);

    return _gridattn_forward_cuda(N, C, T, H, W,
                                  query.data_ptr<float>(),
                                  key.data_ptr<float>(),
                                  attnscores.data_ptr<float>(),
                                  stream);
}

int
gridattn_backward_cuda(const at::Tensor& dattn,
                       const at::Tensor& query,
                       const at::Tensor& key,
                       at::Tensor& dquery,
                       at::Tensor& dkey)
{
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
    int N = query.size(0);
    int C = query.size(1);
    int T = query.size(2);
    int H = query.size(3);
    int W = query.size(4);

    return _gridattn_backward_cuda(N, C, T, H, W,
                                   dattn.data_ptr<float>(),
                                   query.data_ptr<float>(),
                                   key.data_ptr<float>(),
                                   dquery.data_ptr<float>(),
                                   dkey.data_ptr<float>(),
                                   stream);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("gridattn_forward_cuda", &gridattn_forward_cuda,
          "Grid attention forward CUDA");
    m.def("gridattn_backward_cuda", &gridattn_backward_cuda,
          "Grid attention backward CUDA");
    m.def("gridattn_map_forward_cuda", &gridattn_map_forward_cuda,
          "Grid attention map forward CUDA");
    m.def("gridattn_map_backward_cuda", &gridattn_map_backward_cuda,
          "Grid attention map backward CUDA");
}
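
The comment at the top of this file is load-bearing: the bindings neither allocate outputs nor raise on failure, returning 1 on success and 0 if a CUDA error was recorded. A hypothetical defensive wrapper (the function name and error message are ours), assuming the extension has been loaded as `_gridattn` and inputs are contiguous float32 CUDA tensors:

```python
def gridattn_scores(_gridattn, query, key):
    """Allocate the output and surface CUDA errors as Python exceptions."""
    n, c, t, h, w = query.shape
    attnscores = query.new_zeros(n, w + h + t - 2, t, h, w)
    if _gridattn.gridattn_forward_cuda(query, key, attnscores) != 1:
        raise RuntimeError("gridattn_forward_cuda reported a CUDA error")
    return attnscores
```
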
--------------------------------------------------------------------------------
/sst/configs/defaults.py:
--------------------------------------------------------------------------------
import os

from yacs.config import CfgNode as CN


# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by _TRAIN for a training parameter,
# or _TEST for a test-specific parameter.
# For example, the maximum image side during training will be
# INPUT.MAX_SIZE_TRAIN, while for testing it will be INPUT.MAX_SIZE_TEST.

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()

_C.EXP_NAME = "resnet101_cfbi"

_C.DIR_ROOT = "./"
_C.DIR_DATA = os.path.join(_C.DIR_ROOT, "datasets")
_C.DIR_DAVIS = os.path.join(_C.DIR_DATA, "DAVIS")
_C.DIR_YTB = os.path.join(_C.DIR_DATA, "YTB/train")
_C.DIR_YTB_EVAL = os.path.join(_C.DIR_DATA, "YTB/valid")
_C.DIR_RESULT = os.path.join(_C.DIR_ROOT, "result", _C.EXP_NAME)
_C.DIR_CKPT = os.path.join(_C.DIR_RESULT, "ckpt")
_C.DIR_LOG = os.path.join(_C.DIR_RESULT, "log")
_C.DIR_IMG_LOG = os.path.join(_C.DIR_RESULT, "log", "img")
_C.DIR_TB_LOG = os.path.join(_C.DIR_RESULT, "log", "tensorboard")
_C.DIR_EVALUATION = os.path.join(_C.DIR_RESULT, "eval")

_C.DATASETS = ["davis2017"]

_C.DATA = CN()
_C.DATA.WORKERS = 4
_C.DATA.RANDOMCROP = (465, 465)
_C.DATA.RANDOMFLIP = 0.5
_C.DATA.MAX_CROP_STEPS = 5
_C.DATA.MIN_SCALE_FACTOR = 1.0
_C.DATA.MAX_SCALE_FACTOR = 1.3
_C.DATA.SHORT_EDGE_LEN = 480
_C.DATA.RANDOM_REVERSE_SEQ = True
_C.DATA.DAVIS_REPEAT = 30
_C.DATA.CURR_SEQ_LEN = 3
_C.DATA.RANDOM_GAP_DAVIS = 3
_C.DATA.RANDOM_GAP_YTB = 3

_C.PRETRAIN = True
_C.PRETRAIN_FULL = False
_C.PRETRAIN_MODEL = "./pretrain_models/resnet101-deeplabv3p.pth.tar"

_C.MODEL = CN()
_C.MODEL.BACKBONE = "resnet"
_C.MODEL.MODULE = "networks.cfbi.cfbi"
_C.MODEL.OUTPUT_STRIDE = 16
_C.MODEL.ASPP_OUTDIM = 256
_C.MODEL.SHORTCUT_DIM = 48
_C.MODEL.SEMANTIC_EMBEDDING_DIM = 100
_C.MODEL.HEAD_EMBEDDING_DIM = 256
_C.MODEL.PRE_HEAD_EMBEDDING_DIM = 64
_C.MODEL.GN_GROUPS = 32
_C.MODEL.GN_EMB_GROUPS = 25
_C.MODEL.MULTI_LOCAL_DISTANCE = [2, 4, 6, 8, 10, 12]
_C.MODEL.LOCAL_DOWNSAMPLE = True
_C.MODEL.REFINE_CHANNELS = 64  # n * 32
_C.MODEL.LOW_LEVEL_INPLANES = 256 if _C.MODEL.BACKBONE == "resnet" else 24
_C.MODEL.RELATED_CHANNELS = 64
_C.MODEL.EPSILON = 1e-5
_C.MODEL.MATCHING_BACKGROUND = True
_C.MODEL.GCT_BETA_WD = True
_C.MODEL.FLOAT16_MATCHING = False
_C.MODEL.FREEZE_BN = True
_C.MODEL.FREEZE_BACKBONE = False
_C.MODEL.IS_TRANSFORMER = False
_C.MODEL.USE_PREV_EMBEDDING_AND_LABEL_AS_FEATURES = True
_C.MODEL.USE_LOCAL_MATCHING = True

_C.MODEL.TRANSFORMER = CN()
_C.MODEL.TRANSFORMER.ATTENTION_FEATURES = "No"
_C.MODEL.TRANSFORMER.ATTENTION_NORM_TYPE = "FilterResponseNorm"
_C.MODEL.TRANSFORMER.ATTENTION_NORMALIZATION = "Distance"
_C.MODEL.TRANSFORMER.ATTENTION_STRIDE_SPATIAL = [-1]
_C.MODEL.TRANSFORMER.ATTENTION_STRIDE_TEMPORAL = "Dense"
_C.MODEL.TRANSFORMER.ATTENTION_TYPE = "Full"
_C.MODEL.TRANSFORMER.ATTENTION_SPARSITY = "Local"
_C.MODEL.TRANSFORMER.HISTORY_SIZE = 1
_C.MODEL.TRANSFORMER.NUM_ATTN_HEADS = 8
_C.MODEL.TRANSFORMER.NUM_ATTN_LAYERS = 1
_C.MODEL.TRANSFORMER.POSITIONAL_ENCODING = "NonPositional"
_C.MODEL.TRANSFORMER.POSITIONAL_ENCODING_DIM = 64
_C.MODEL.TRANSFORMER.POSITIONAL_ENCODING_MAX_LEN_T = 300
_C.MODEL.TRANSFORMER.POSITIONWISE_FEEDFORWARD_NONLINEARITY = "FRN"
_C.MODEL.TRANSFORMER.USE_RELATIVE_POSITIONAL_EMBEDDINGS = False
_C.MODEL.TRANSFORMER.USE_WARPED_FGBG_FEATURES = True
_C.MODEL.TRANSFORMER.ARE_FEATURES_DOWNSAMPLED = False

_C.TRAIN = CN()
_C.TRAIN.TOTAL_STEPS = 25000
_C.TRAIN.START_STEP = 0
_C.TRAIN.LR = 0.01
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.COSINE_DECAY = False
_C.TRAIN.WARM_UP_STEPS = 1000
_C.TRAIN.WEIGHT_DECAY = 15e-5
_C.TRAIN.POWER = 0.9
_C.TRAIN.GPUS = 4
_C.TRAIN.BATCH_SIZE = 8
_C.TRAIN.START_SEQ_TRAINING_STEPS = _C.TRAIN.TOTAL_STEPS // 2
_C.TRAIN.TBLOG = False
_C.TRAIN.TBLOG_STEP = 60
_C.TRAIN.LOG_STEP = 100
_C.TRAIN.IMG_LOG = False
_C.TRAIN.TOP_K_PERCENT_PIXELS = 0.15
_C.TRAIN.HARD_MINING_STEP = _C.TRAIN.TOTAL_STEPS // 2
_C.TRAIN.CLIP_GRAD_NORM = 5.0
_C.TRAIN.SAVE_STEP = 1000
_C.TRAIN.MAX_KEEP_CKPT = 8
_C.TRAIN.RESUME = False
_C.TRAIN.RESUME_CKPT = None
_C.TRAIN.RESUME_STEP = 0
_C.TRAIN.AUTO_RESUME = True
_C.TRAIN.GLOBAL_ATROUS_RATE = 1
_C.TRAIN.LOCAL_ATROUS_RATE = 1
_C.TRAIN.LOCAL_PARALLEL = True
_C.TRAIN.GLOBAL_CHUNKS = 20
_C.TRAIN.DATASET_FULL_RESOLUTION = True
_C.TRAIN.DO_EVAL_DURING_TRAINING = True

_C.TEST = CN()
_C.TEST.GPU_ID = 0
_C.TEST.DATASET = "davis2017"
_C.TEST.DATASET_FULL_RESOLUTION = False
_C.TEST.DATASET_SPLIT = ["val"]
_C.TEST.CKPT_PATH = None
_C.TEST.CKPT_STEP = None  # if None, evaluate the latest checkpoint.
_C.TEST.FLIP = False
_C.TEST.MULTISCALE = [1]
_C.TEST.MIN_SIZE = None
_C.TEST.MAX_SIZE = 800 * 1.3 if _C.TEST.MULTISCALE == [1.0] else 800
_C.TEST.WORKERS = 4
_C.TEST.GLOBAL_CHUNKS = 4
_C.TEST.GLOBAL_ATROUS_RATE = 1
_C.TEST.LOCAL_ATROUS_RATE = 1
_C.TEST.LOCAL_PARALLEL = True

# dist
_C.DIST = CN()
_C.DIST.ENABLE = True
_C.DIST.BACKEND = "gloo"
_C.DIST.URL = "file:///home/ubuntu/work/CFBI/sharefile"
_C.DIST.START_GPU = 0
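
One caveat when overriding this file: the `DIR_*` paths are computed from `EXP_NAME` when the module is imported, and `TRAIN.START_SEQ_TRAINING_STEPS` and `TRAIN.HARD_MINING_STEP` are likewise frozen at `TOTAL_STEPS // 2`, so later yacs merges change only the key you set, not the values derived from it:

```python
# Demonstration: overrides do not recompute values derived at import time.
from sst.configs import cfg

c = cfg.clone()
c.merge_from_list(["EXP_NAME", "my_experiment"])
assert c.EXP_NAME == "my_experiment"
# DIR_RESULT was built from the default EXP_NAME when defaults.py ran.
assert c.DIR_RESULT.endswith("resnet101_cfbi")
```
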
--------------------------------------------------------------------------------
/gridattention/functions.py:
--------------------------------------------------------------------------------
import math
import os

import torch
import torch.nn.functional as F
from torch.autograd.function import once_differentiable
from torch.utils import cpp_extension

_gridattn = None


def init_gridattn(cfg):
    """JIT-compile and load the grid attention CUDA extension."""
    global _gridattn

    curr_dir = os.path.dirname(os.path.abspath(__file__))
    _src_path = os.path.join(curr_dir, "src")
    _build_path = os.path.join(curr_dir, "build")
    os.makedirs(_build_path, exist_ok=True)

    _gridattn = cpp_extension.load(
        name="gridattn",
        extra_cflags=["-O3"],
        build_directory=_build_path,
        verbose=True,
        sources=[os.path.join(_src_path, f) for f in ["lib_cffi.cpp", "gridattn.cu"]],
        extra_cuda_cflags=["--expt-extended-lambda", "-O3", "--use_fast_math"],
    )


def _check_contiguous(*args):
    if not all([mod is None or mod.is_contiguous() for mod in args]):
        raise ValueError("Non-contiguous input")


class GridAttnWeightFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, query, key, strides=None):
        # The CUDA kernels operate on float32 only.
        query = query.float()
        key = key.float()

        n, c, t, h, w = query.size()
        size = (n, h + w + t - 2, t, h, w)
        attnscores = torch.zeros(
            size,
            dtype=query.dtype,
            layout=query.layout,
            device=query.device,
        )

        _gridattn.gridattn_forward_cuda(query, key, attnscores)

        # Save context
        ctx.save_for_backward(query, key)

        return attnscores

    @staticmethod
    @once_differentiable
    def backward(ctx, dattn):
        query, key = ctx.saved_tensors

        dquery = torch.zeros_like(query)
        dkey = torch.zeros_like(key)

        _gridattn.gridattn_backward_cuda(dattn.contiguous(), query, key, dquery, dkey)

        _check_contiguous(dquery, dkey)

        return dquery, dkey, None


class GridAttnMapFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, attention, value, strides=None):
        assert attention.shape[1] == (sum(value.shape[2:]) - 2)
        attention = attention.float()
        value = value.float()

        out = torch.zeros_like(value)
        _gridattn.gridattn_map_forward_cuda(attention, value, out)

        # Save context
        ctx.save_for_backward(attention, value)

        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, dout):
        attention, value = ctx.saved_tensors

        dattn = torch.zeros_like(attention)
        dval = torch.zeros_like(value)

        _gridattn.gridattn_map_backward_cuda(
            dout.contiguous(), attention, value, dattn, dval
        )

        _check_contiguous(dattn, dval)

        return dattn, dval, None


gridattn_weight = GridAttnWeightFunc.apply
gridattn_map = GridAttnMapFunc.apply


class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        torch.nn.Module.__init__(self)
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(100.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)


class GridAttentionWeight(torch.nn.Module):
    """NOTE(brendan): W and H should be the max. size in the x and y
    dimensions, respectively, of the tensor input to attention.
    """

    def __init__(self, in_dim, cfg):
        torch.nn.Module.__init__(self)

        if cfg.MODEL.RESNETS.STAGE_WITH_DCN[-1]:
            stride = 8
        else:
            stride = 16
        H = cfg.INPUT.MAX_SIZE_TRAIN // stride
        W = cfg.INPUT.MAX_SIZE_TRAIN // stride
        self.max_frame_count = cfg.INPUT.MAX_FRAME_COUNT
        self.num_heads = cfg.MODEL.GRID_ATTENTION.DIM_REDUCE_FACTOR
        self.embinit = cfg.MODEL.GRID_ATTENTION.EMBEDDING_INITIALIZATION
        self.is_position_sincos = (
            cfg.MODEL.GRID_ATTENTION.POSITIONAL_ENCODING == "Sinusoidal"
        )

        projdim = in_dim // self.num_heads
        self.query_conv = torch.nn.Conv3d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1
        )
        self.key_conv = torch.nn.Conv3d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1
        )

        self.x_embeddings = torch.nn.Parameter(torch.Tensor((2 * W) - 1, projdim))
        self.y_embeddings = torch.nn.Parameter(torch.Tensor((2 * H) - 1, projdim))
        if cfg.MODEL.GRID_ATTENTION.POSITIONAL_ENCODING == "LearnedRelative":
            self.t_embeddings = torch.nn.Parameter(
                torch.Tensor((2 * self.max_frame_count) - 1, projdim)
            )
        elif cfg.MODEL.GRID_ATTENTION.POSITIONAL_ENCODING == "Sinusoidal":
            self.t_embeddings = PositionalEncoding(
                projdim, max_len=(2 * self.max_frame_count) - 1
            )
        else:
            assert False

        self.reset_parameters()

    def reset_parameters(self):
        if self.embinit == "Normal":
            torch.nn.init.normal_(self.x_embeddings)
            torch.nn.init.normal_(self.y_embeddings)
            if not self.is_position_sincos:
                torch.nn.init.normal_(self.t_embeddings)
        elif self.embinit == "Zeros":
            torch.nn.init.zeros_(self.x_embeddings)
            torch.nn.init.zeros_(self.y_embeddings)
            if not self.is_position_sincos:
                torch.nn.init.zeros_(self.t_embeddings)
        else:
            assert False

    def forward(self, query, key, frm_indices):
        projquery = self.query_conv(query)
        n, c, t, h, w = projquery.shape
        assert (frm_indices.shape[0] == n) and (
            frm_indices.shape[1] == t
        ), f"{frm_indices.shape}, {n}, {t}"

        projquery = projquery.view(n * self.num_heads, c // self.num_heads, t, h, w)

        projkey = self.key_conv(key)
        projkey = projkey.view(n * self.num_heads, c // self.num_heads, t, h, w)

        if self.is_position_sincos:
            # TODO(brendan): try dropout
            # t_embeddings = self.t_embeddings.dropout(self.t_embeddings.pe)
            t_embeddings = self.t_embeddings.pe
        else:
            t_embeddings = self.t_embeddings
        energy = gridattn_weight(
            projquery,
            projkey,
            self.x_embeddings,
            self.y_embeddings,
            t_embeddings,
            frm_indices.repeat(self.num_heads, 1),
        )

        attn = F.softmax(energy, dim=1)

        return attn.view(n, self.num_heads, t + h + w - 2, t, h, w)


class GridAttentionMap(torch.nn.Module):
    def __init__(self, in_dim, dim_reduce_factor, cfg):
        torch.nn.Module.__init__(self)

        self.num_heads = dim_reduce_factor

        self.value_conv = torch.nn.Conv3d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1
        )

    def forward(self, attention, value):
        projvalue = self.value_conv(value)
        n, c, t, h, w = projvalue.shape
        projvalue = projvalue.view(n * self.num_heads, c // self.num_heads, t, h, w)

        attended_val = gridattn_map(
            attention.view(n * self.num_heads, t + h + w - 2, t, h, w), projvalue
        )

        return attended_val.view(n, self.num_heads, c // self.num_heads, t, h, w)


class GridAttention(torch.nn.Module):
    """NOTE(brendan): See GridAttentionWeight for W, H definition."""

    def __init__(self, in_dim, cfg):
        torch.nn.Module.__init__(self)

        self.num_heads = cfg.MODEL.GRID_ATTENTION.DIM_REDUCE_FACTOR

        self.attnweight = GridAttentionWeight(in_dim, cfg)
        self.attnmap = GridAttentionMap(in_dim, self.num_heads, cfg)

        self.outconv = torch.nn.Conv3d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1
        )

    def forward(self, query, key, value, frm_indices):
        n, c, t, h, w = value.shape
        attention = self.attnweight(query, key, frm_indices)

        attended_val = self.attnmap(attention, value)
        attended_val = attended_val.view(n, c, t, h, w)

        return self.outconv(attended_val)
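
Note that `GridAttentionWeight.forward` passes relative position embeddings and frame indices into `gridattn_weight`, whereas `GridAttnWeightFunc` above (and the kernels in `gridattn.cu` below) accept only `(query, key, strides)`; the module wrappers also read config keys (`MODEL.GRID_ATTENTION.*`, `INPUT.*`) that `sst/configs/defaults.py` does not define, so they target an extended variant of the kernels. The functional path is self-consistent with the released sources; a minimal sketch mirroring the unit test, assuming a CUDA device and that `init_gridattn(cfg)` has already loaded the extension:

```python
import torch
import torch.nn.functional as F

from gridattention import gridattn_weight, gridattn_map

n, c, t, h, w = 1, 8, 4, 10, 10
query = torch.randn(n, c, t, h, w, device="cuda")
key = torch.randn(n, c, t, h, w, device="cuda")
value = torch.randn(n, c, t, h, w, device="cuda")

scores = gridattn_weight(query, key)  # (n, w + h + t - 2, t, h, w)
attn = F.softmax(scores, dim=1)  # normalize over the grid axis
out = gridattn_map(attn, value)  # same shape as value
assert out.shape == value.shape
```
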
--------------------------------------------------------------------------------
/gridattention/src/gridattn.cu:
--------------------------------------------------------------------------------
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>

#include "gridattn.h"

__global__ void
gridattn_map_forward_kernel(const float *attn, const float *val, float *out,
                            int N, int C, int T, int H, int W)
{
    int x = blockIdx.x*blockDim.x + threadIdx.x;
    int thread_yidx = blockIdx.y*blockDim.y + threadIdx.y;
    int t = thread_yidx % T;
    int y = thread_yidx / T;
    int c = blockIdx.z;
    int HW = H * W;
    int HWT = HW * T;
    int zdim = H + W + T - 2;
    int cHWT = c * HWT;
    int CHWT = C * HWT;

    if ((x >= W) || (y >= H) || (t >= T) || (c >= C))
        return;

    int xyt_offset = t*HW + y*W + x;
    for (int Nidx = 0; Nidx < N; ++Nidx) {
        float accum = 0.0f;

        int channel_offset = Nidx*CHWT + cHWT;
        int attn_Nxyt_offset = Nidx*zdim*HWT + xyt_offset;
        /* Row: z in [0, W) indexes position (i=z, y, t), self included. */
        for (int z = 0; z < W; ++z) {
            float attn_zxyt = attn[attn_Nxyt_offset + z*HWT];
            float val_iyt = val[channel_offset + t*HW + y*W + z];

            accum += attn_zxyt * val_iyt;
        }

        /* Column: z in [W, W + H - 1), skipping the query row y. */
        for (int z = W; z < (W + H - 1); ++z) {
            int j = z - W;
            j = (j < y) ? j : (j + 1);

            float attn_zxyt = attn[attn_Nxyt_offset + z*HWT];
            float val_xjt = val[channel_offset + t*HW + j*W + x];

            accum += attn_zxyt * val_xjt;
        }

        /* Time: z in [W + H - 1, W + H + T - 2), skipping the query frame t. */
        for (int z = (W + H - 1); z < (W + H + T - 2); ++z) {
            int k = z - (W + H - 1);
            k = (k < t) ? k : (k + 1);

            float attn_zxyt = attn[attn_Nxyt_offset + z*HWT];
            float val_xyk = val[channel_offset + k*HW + y*W + x];

            accum += attn_zxyt * val_xyk;
        }

        out[channel_offset + xyt_offset] = accum;
    }
}

__global__ void
gridattn_map_backward_kernel_dattn(const float *dout, const float *val,
                                   float *dattn, int N, int C, int T, int H,
                                   int W)
{
    int x = blockIdx.x*blockDim.x + threadIdx.x;
    int thread_yidx = blockIdx.y*blockDim.y + threadIdx.y;
    int t = thread_yidx % T;
    int y = thread_yidx / T;
    int z = blockIdx.z;
    int HW = H * W;
    int HWT = HW * T;
    int zdim = H + W + T - 2;
    int zHWT = z * HWT;
    int zdimHWT = zdim * HWT;
    /* TODO(brendan): divide by sqrt(C), in backward as well. */

    if ((x >= W) || (y >= H) || (t >= T) || (z >= (H + W + T - 2)))
        return;

    /* Decode z into the attended position (i, j, k): row, column, or time. */
    int i;
    int j;
    int k;
    if (z < W) {
        i = z;
        j = y;
        k = t;
    } else if (z < (H + W - 1)) {
        i = x;
        j = z - W;
        j = j < y ? j : j + 1;
        k = t;
    } else {
        i = x;
        j = y;
        k = z - (H + W - 1);
        k = k < t ? k : k + 1;
    }
    int ijk_offset = k*HW + j*W + i;
    int xyt_offset = t*HW + y*W + x;
    int zxyt_offset = zHWT + xyt_offset;

    for (int Nidx = 0; Nidx < N; ++Nidx) {
        float accum = 0.0f;

        for (int Cidx = 0; Cidx < C; ++Cidx) {
            int channel_offset = (Nidx*C + Cidx)*HWT;
            float dout_xyt = dout[channel_offset + xyt_offset];
            float val_ijk = val[channel_offset + ijk_offset];

            accum += dout_xyt * val_ijk;
        }

        dattn[Nidx*zdimHWT + zxyt_offset] = accum;
    }
}
__global__ void
gridattn_map_backward_kernel_dval(const float *attn, const float *dout,
                                  float *dval, int N, int C, int T, int H,
                                  int W)
{
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    int thread_yidx = blockIdx.y*blockDim.y + threadIdx.y;
    int k = thread_yidx % T;
    int j = thread_yidx / T;
    int c = blockIdx.z;
    int HW = H * W;
    int HWT = HW * T;
    int zdim = H + W + T - 2;
    int cHWT = c * HWT;
    int CHWT = C * HWT;

    if ((i >= W) || (j >= H) || (k >= T) || (c >= C))
        return;

    int ijk_offset = k*HW + j*W + i;
    for (int Nidx = 0; Nidx < N; ++Nidx) {
        float accum = 0.0f;

        int channel_offset = Nidx*CHWT + cHWT;
        int attn_N_offset = Nidx * zdim * HWT;
        /* Queries in the same row (y = j, t = k) attended (i, j, k) at z = i. */
        for (int x = 0; x < W; ++x) {
            int xjk_offset = k*HW + j*W + x;
            float attn_zxjk = attn[attn_N_offset + i*HWT + xjk_offset];
            float dout_xjk = dout[channel_offset + xjk_offset];

            accum += attn_zxjk * dout_xjk;
        }

        /* Queries in the same column (x = i, t = k), excluding row j itself. */
        for (int y = 0; y < H; ++y) {
            if (y == j)
                continue;

            int z = (y > j) ? j : (j - 1);
            z += W;

            int iyk_offset = k*HW + y*W + i;
            float attn_ziyk = attn[attn_N_offset + z*HWT + iyk_offset];
            float dout_iyk = dout[channel_offset + iyk_offset];

            accum += attn_ziyk * dout_iyk;
        }

        /* Queries on the same temporal fiber (x = i, y = j), excluding k. */
        for (int t = 0; t < T; ++t) {
            if (t == k)
                continue;

            int z = (t > k) ? k : (k - 1);
            z += W + H - 1;

            int ijt_offset = t*HW + j*W + i;
            float attn_zxyt = attn[attn_N_offset + z*HWT + ijt_offset];
            float dout_ijt = dout[channel_offset + ijt_offset];

            accum += attn_zxyt * dout_ijt;
        }

        dval[channel_offset + ijk_offset] = accum;
    }
}

__global__ void
gridattn_forward_kernel(const float *query, const float *key,
                        float *attnscores, int N, int C, int T, int H, int W)
{
    int x = blockIdx.x*blockDim.x + threadIdx.x;
    int thread_yidx = blockIdx.y*blockDim.y + threadIdx.y;
    int t = thread_yidx % T;
    int y = thread_yidx / T;
    int z = blockIdx.z;
    int HW = H * W;
    int HWT = HW * T;
    int zdim = H + W + T - 2;
    int zHWT = z * HWT;
    int zdimHWT = zdim * HWT;

    if ((x >= W) || (y >= H) || (t >= T) || (z >= (H + W + T - 2)))
        return;

    int i;
    int j;
    int k;
    if (z < W) {
        i = z;
        j = y;
        k = t;
    } else if (z < (H + W - 1)) {
        i = x;
        j = z - W;
        j = j < y ? j : j + 1;
        k = t;
    } else {
        i = x;
        j = y;
        k = z - (H + W - 1);
        k = k < t ? k : k + 1;
    }
    int ijk_offset = k*HW + j*W + i;
    int xyt_offset = t*HW + y*W + x;
    int zxyt_offset = zHWT + xyt_offset;

    for (int Nidx = 0; Nidx < N; ++Nidx) {
        float accum = 0.0f;
        for (int Cidx = 0; Cidx < C; ++Cidx) {
            int channel_offset = (Nidx*C + Cidx)*HWT;
            float query_xyt = query[channel_offset + xyt_offset];
            float key_ijk = key[channel_offset + ijk_offset];

            accum += query_xyt * key_ijk;
        }

        attnscores[Nidx*zdimHWT + zxyt_offset] = accum;
    }
}
__global__ void
gridattn_backward_kernel_dquery(const float *dattn, const float *key,
                                float *dquery, int N, int C, int T, int H,
                                int W)
{
    int x = blockIdx.x*blockDim.x + threadIdx.x;
    int thread_yidx = blockIdx.y*blockDim.y + threadIdx.y;
    int t = thread_yidx % T;
    int y = thread_yidx / T;
    int c = blockIdx.z;
    int HW = H * W;
    int HWT = HW * T;
    int zdim = H + W + T - 2;
    int cHWT = c * HWT;
    int CHWT = C * HWT;

    if ((x >= W) || (y >= H) || (t >= T) || (c >= C))
        return;

    int xyt_offset = t*HW + y*W + x;
    for (int Nidx = 0; Nidx < N; ++Nidx) {
        float accum = 0.0f;

        int channel_offset = Nidx*CHWT + cHWT;
        int attn_Nxyt_offset = Nidx*zdim*HWT + xyt_offset;
        for (int z = 0; z < W; ++z) {
            float dattn_zxyt = dattn[attn_Nxyt_offset + z*HWT];
            float key_iyt = key[channel_offset + t*HW + y*W + z];

            accum += dattn_zxyt * key_iyt;
        }

        for (int z = W; z < (W + H - 1); ++z) {
            int j = z - W;
            j = (j < y) ? j : (j + 1);

            float dattn_zxyt = dattn[attn_Nxyt_offset + z*HWT];
            float key_xjt = key[channel_offset + t*HW + j*W + x];

            accum += dattn_zxyt * key_xjt;
        }

        for (int z = (W + H - 1); z < (W + H + T - 2); ++z) {
            int k = z - (W + H - 1);
            k = (k < t) ? k : (k + 1);

            float dattn_zxyt = dattn[attn_Nxyt_offset + z*HWT];
            float key_xyk = key[channel_offset + k*HW + y*W + x];

            accum += dattn_zxyt * key_xyk;
        }

        dquery[channel_offset + xyt_offset] = accum;
    }
}

__global__ void
gridattn_backward_kernel_dkey(const float *dattn, const float *query,
                              float *dkey, int N, int C, int T, int H, int W)
{
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    int thread_yidx = blockIdx.y*blockDim.y + threadIdx.y;
    int k = thread_yidx % T;
    int j = thread_yidx / T;
    int c = blockIdx.z;
    int HW = H * W;
    int HWT = HW * T;
    int zdim = H + W + T - 2;
    int cHWT = c * HWT;
    int CHWT = C * HWT;

    if ((i >= W) || (j >= H) || (k >= T) || (c >= C))
        return;

    int ijk_offset = k*HW + j*W + i;
    for (int Nidx = 0; Nidx < N; ++Nidx) {
        float dkey_accum = 0.0f;

        int channel_offset = Nidx*CHWT + cHWT;
        int attn_N_offset = Nidx * zdim * HWT;
        for (int x = 0; x < W; ++x) {
            int xjk_offset = k*HW + j*W + x;
            float dattn_zxjk = dattn[attn_N_offset + i*HWT + xjk_offset];
            float query_xjk = query[channel_offset + xjk_offset];

            dkey_accum += dattn_zxjk * query_xjk;
        }

        for (int y = 0; y < H; ++y) {
            if (y == j)
                continue;

            int z = (y > j) ? j : (j - 1);
            z += W;

            int iyk_offset = k*HW + y*W + i;
            float dattn_ziyk = dattn[attn_N_offset + z*HWT + iyk_offset];
            float query_iyk = query[channel_offset + iyk_offset];
            dkey_accum += dattn_ziyk * query_iyk;
        }

        for (int t = 0; t < T; ++t) {
            if (t == k)
                continue;
#ifdef MASKED_GRIDATTENTION
            if (k > t)
                continue;
#endif /* MASKED_GRIDATTENTION */

            int z = (t > k) ? k : (k - 1);
            z += W + H - 1;

            int ijt_offset = t*HW + j*W + i;
            float dattn_zxyt = dattn[attn_N_offset + z*HWT + ijt_offset];
            float query_ijt = query[channel_offset + ijt_offset];
            dkey_accum += dattn_zxyt * query_ijt;
        }

        dkey[channel_offset + ijk_offset] = dkey_accum;
    }
}
/*
 * Implementations
 */
extern "C" int
_gridattn_map_forward_cuda(int N, int C, int T, int H, int W,
                           const float *attn, const float *val, float *out,
                           cudaStream_t stream)
{
    // Run kernel
    dim3 threads_per_block{32, 32};
    uint32_t d1 = (W + threads_per_block.x - 1)/threads_per_block.x;
    uint32_t d2 = (T*H + threads_per_block.y - 1)/threads_per_block.y;
    uint32_t d3 = C;
    dim3 blocks{d1, d2, d3};
    gridattn_map_forward_kernel<<<blocks, threads_per_block, 0, stream>>>(
        attn, val, out, N, C, T, H, W);

    // Check for errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return 0;

    return 1;
}

extern "C" int
_gridattn_map_backward_cuda(int N, int C, int T, int H, int W,
                            const float *dout, const float *attn,
                            const float *val, float *dattn, float *dval,
                            cudaStream_t stream)
{
    // Run kernel
    dim3 threads_per_block{32, 32};
    uint32_t d1 = (W + threads_per_block.x - 1) / threads_per_block.x;
    uint32_t d2 = (T*H + threads_per_block.y - 1) / threads_per_block.y;
    uint32_t d3 = H + W + T;
    dim3 blocks{d1, d2, d3};
    gridattn_map_backward_kernel_dattn<<<blocks, threads_per_block, 0, stream>>>(
        dout, val, dattn, N, C, T, H, W);

    d3 = C;
    blocks = dim3{d1, d2, d3};
    gridattn_map_backward_kernel_dval<<<blocks, threads_per_block, 0, stream>>>(
        attn, dout, dval, N, C, T, H, W);

    // Check for errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return 0;

    return 1;
}

extern "C" int
_gridattn_forward_cuda(int N, int C, int T, int H, int W,
                       const float *query_data, const float *key_data,
                       float *attnscores_data, cudaStream_t stream)
{
    dim3 threads_per_block{32, 32};
    uint32_t d1 = (W + threads_per_block.x - 1) / threads_per_block.x;
    uint32_t d2 = (T*H + threads_per_block.y - 1) / threads_per_block.y;
    uint32_t d3 = H + W + T;
    dim3 blocks{d1, d2, d3};
    gridattn_forward_kernel<<<blocks, threads_per_block, 0, stream>>>(
        query_data, key_data, attnscores_data, N, C, T, H, W);

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return 0;

    return 1;
}

extern "C" int
_gridattn_backward_cuda(int N, int C, int T, int H, int W,
                        const float *dattn_data, const float *query_data,
                        const float *key_data, float *dquery_data,
                        float *dkey_data, cudaStream_t stream)
{
    dim3 threads_per_block{32, 32};
    uint32_t d1 = (W + threads_per_block.x - 1)/threads_per_block.x;
    uint32_t d2 = (T*H + threads_per_block.y - 1)/threads_per_block.y;
    uint32_t d3 = C;
    dim3 blocks{d1, d2, d3};
    gridattn_backward_kernel_dquery<<<blocks, threads_per_block, 0, stream>>>(
        dattn_data, key_data, dquery_data, N, C, T, H, W);
    gridattn_backward_kernel_dkey<<<blocks, threads_per_block, 0, stream>>>(
        dattn_data, query_data, dkey_data, N, C, T, H, W);

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return 0;

    return 1;
}
--------------------------------------------------------------------------------
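
For scale, these kernels compute (T·H·W)·(W + H + T - 2) attention scores where dense spatiotemporal attention would compute (T·H·W)²; a back-of-envelope comparison with illustrative feature-map sizes:

```python
# Score-count comparison: dense spatiotemporal attention vs. the grid pattern.
T, H, W = 6, 60, 60  # illustrative feature-map size
tokens = T * H * W
dense = tokens**2  # every query attends to every key
grid = tokens * (W + H + T - 2)
print(f"dense: {dense:,} scores; grid: {grid:,} scores; ratio: {dense / grid:.0f}x")
```
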