├── .gitmodules ├── README.md ├── assets ├── car_final.gif ├── car_relight.gif ├── demo.gif ├── domino_composite.gif ├── domino_sim.gif ├── method.png ├── pig_ball_composite.gif ├── pig_ball_relight.gif ├── pool_composite.gif └── pool_sim.gif ├── custom_data ├── balls_shelf │ └── original.png ├── boxes │ └── original.png ├── kitchen │ └── original.png ├── table │ └── original.png └── wall_toy │ └── original.png ├── data ├── balls │ ├── depth.npy │ ├── inpaint.png │ ├── intermediate │ │ ├── albedo_vis.png │ │ ├── depth_vis.png │ │ ├── mask.png │ │ ├── normal_vis.png │ │ ├── obj_movable.json │ │ ├── shading_vis.png │ │ └── vis_mask.jpg │ ├── mask.png │ ├── normal.npy │ ├── original.png │ ├── shading.npy │ ├── sim.yaml │ └── vis.png ├── car │ ├── inpaint.png │ ├── mask.png │ ├── normal.npy │ ├── original.png │ ├── shading.npy │ └── sim.yaml ├── domino │ ├── depth.npy │ ├── inpaint.png │ ├── intermediate │ │ ├── albedo_vis.png │ │ ├── depth_vis.png │ │ ├── normal_vis.png │ │ ├── obj_movable.json │ │ ├── shading_vis.png │ │ └── vis_mask.jpg │ ├── mask.png │ ├── normal.npy │ ├── original.png │ ├── shading.npy │ ├── sim.yaml │ └── vis.png ├── pig_ball │ ├── depth.npy │ ├── inpaint.png │ ├── intermediate │ │ ├── albedo_vis.png │ │ ├── bg_mask_vis.png │ │ ├── depth_vis.png │ │ ├── edge_vis.png │ │ ├── fg_mask_vis.png │ │ ├── normal_vis.png │ │ ├── obj_movable.json │ │ ├── shading_vis.png │ │ └── vis_mask.jpg │ ├── mask.png │ ├── normal.npy │ ├── original.png │ ├── shading.npy │ ├── sim.yaml │ └── vis.png └── pool │ ├── depth.npy │ ├── inpaint.png │ ├── intermediate │ ├── albedo_vis.png │ ├── depth_vis.png │ ├── normal_vis.png │ ├── obj_movable.json │ ├── shading_vis.png │ └── vis_mask.jpg │ ├── mask.png │ ├── normal.npy │ ├── original.png │ ├── shading.npy │ ├── sim.yaml │ └── vis.png ├── diffusion └── video_diffusion.py ├── perception ├── README.md ├── gpt │ ├── __init__.py │ ├── gpt_configs │ │ ├── my_apikey │ │ ├── physics │ │ │ └── user.txt │ │ └── ram │ │ │ └── user.txt │ └── gpt_utils.py ├── gpt_physic.py ├── gpt_ram.py ├── run_albedo_shading.py ├── run_depth_normal.py ├── run_fg_bg.py ├── run_gsam.py └── run_inpaint.py ├── relight ├── relight.py └── relight_utils.py ├── requirements.txt ├── scripts └── run_demo.sh └── simulation ├── animate.py ├── animate_utils.py └── sim_utils.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "diffusion/SEINE"] 2 | path = diffusion/SEINE 3 | url = https://github.com/Vchitect/SEINE.git 4 | [submodule "perception/Grounded-Segment-Anything"] 5 | path = perception/Grounded-Segment-Anything 6 | url = https://github.com/IDEA-Research/Grounded-Segment-Anything.git 7 | [submodule "perception/GeoWizard"] 8 | path = perception/GeoWizard 9 | url = https://github.com/fuxiao0719/GeoWizard 10 | [submodule "perception/Inpaint-Anything"] 11 | path = perception/Inpaint-Anything 12 | url = https://github.com/geekyutao/Inpaint-Anything.git 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | 4 |

PhysGen: Rigid-Body Physics-Grounded
Image-to-Video Generation

5 | 6 |

7 | ECCV, 2024 8 |
9 | Shaowei Liu 10 | · 11 | Zhongzheng Ren 12 | · 13 | Saurabh Gupta* 14 | · 15 | Shenlong Wang* 16 | · 17 |

18 | 19 |

20 | Demo GIF 21 |

22 | 23 |

24 | 25 | Paper PDF 26 | Arxiv 27 | 28 | Project Page 29 | Google Colab 30 | 31 | Youtube Video 32 |

33 | 34 |

35 |
36 | 37 | This repository contains the PyTorch implementation for the paper [PhysGen: Rigid-Body Physics-Grounded Image-to-Video Generation](https://stevenlsw.github.io/physgen/), ECCV 2024. In this paper, we present a novel training-free image-to-video generation pipeline that integrates physical simulation with a generative video diffusion prior. 38 | 39 | ## Overview 40 | ![overview](assets/method.png) 41 | 42 | ## 📄 Table of Contents 43 | 44 | - [Installation](#installation) 45 | - [Colab Notebook](#colab-notebook) 46 | - [Quick Demo](#quick-demo) 47 | - [Perception](#perception) 48 | - [Simulation](#simulation) 49 | - [Rendering](#rendering) 50 | - [All-in-One command](#all-in-one-command) 51 | - [Evaluation](#evaluation) 52 | - [Custom Image Video Generation](#custom-image-video-generation) 53 | - [Citation](#citation) 54 | 55 | 56 | ## Installation 57 | - Clone this repository: 58 | ```Shell 59 | git clone --recurse-submodules https://github.com/stevenlsw/physgen.git 60 | cd physgen 61 | ``` 62 | - Install the requirements with the following commands: 63 | ```Shell 64 | conda create -n physgen python=3.9 65 | conda activate physgen 66 | pip install -r requirements.txt 67 | ``` 68 | 69 | ## Colab Notebook 70 | Run our [Colab notebook](https://colab.research.google.com/drive/1imGIms3Y4RRtddA6IuxZ9bkP7N2gVVC_) for a quick start! 71 | 72 | 73 | 74 | ## Quick Demo 75 | 76 | - Run the image-space dynamics simulation in just **3** seconds, **without a GPU, display device, or any additional setup** required! 77 | ```Shell 78 | export PYTHONPATH=$(pwd) 79 | name="pool" 80 | python simulation/animate.py --data_root data --save_root outputs --config data/${name}/sim.yaml 81 | ``` 82 | - The output video should be saved in `outputs/${name}/composite.mp4`. Try setting `name` to `domino`, `balls`, `pig_ball`, or `car` to explore other scenes. The example outputs are shown below: 83 | 84 | | **Input Image** | **Simulation** | **Output Video** | 85 | |:---------------:|:--------------:|:----------------:| 86 | | Pool Original Image | Pool Simulation GIF | Pool Composite GIF | 87 | | Domino Original Image | Domino Simulation GIF | Domino Composite GIF | 88 | 89 | 90 | ## Perception 91 | - Please see [perception/README.md](perception/README.md) for details. 92 | 93 | | **Input** | **Segmentation** | **Normal** | **Albedo** | **Shading** | **Inpainting** | 94 | |:---------:|:----------------:|:----------:|:----------:|:-----------:|:--------------:| 95 | | input | segmentation | normal | albedo | shading | inpainting | 96 | 97 | 98 | 99 | ## Simulation 100 | - Simulation requires the following input for each image: 101 | ```Shell 102 | image folder/ 103 | ├── original.png 104 | ├── mask.png # segmentation mask 105 | ├── inpaint.png # background inpainting 106 | ├── sim.yaml # simulation configuration file 107 | ``` 108 | 109 | - `sim.yaml` specifies the physical properties of each object and the initial conditions (force and velocity applied to each object). Please see `data/pig_ball/sim.yaml` for an example. Set `display` to `true` to visualize the simulation process on a display device, and set `save_snapshot` to `true` to save simulation snapshots.
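- As a quick sanity check before launching a run, you can load a config and print the per-object physical properties. The snippet below is only a minimal inspection sketch, not part of the pipeline: it assumes `omegaconf` (already in `requirements.txt` and used by `diffusion/video_diffusion.py` to read `sim.yaml`) and the key names visible in `data/pig_ball/sim.yaml`.
```python
from omegaconf import OmegaConf

# Load an example simulation config (same loader video_diffusion.py uses for sim.yaml).
config = OmegaConf.load("data/pig_ball/sim.yaml")

print(f"scene: {config.cat}, gravity: {config.gravity}, frames: {config.animation_frames}")
print(f"physical boundaries (edge_list): {config.edge_list}")

# Per-object physical properties, keyed by segmentation id.
for seg_id, obj in config.obj_info.items():
    print(f"object {seg_id}: label={obj.label}, primitive={obj.primitive}, "
          f"mass={obj.mass}, density={obj.density}, "
          f"friction={obj.friction}, elasticity={obj.elasticity}")

# Initial conditions: velocity / acceleration applied to specific object ids.
print("init_velocity:", config.init_velocity)
print("init_acc:", config.init_acc)
```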
110 | - Run the simulation with the following command: 111 | ```Shell 112 | cd simulation 113 | python animate.py --data_root ../data --save_root ../outputs --config ../data/${name}/sim.yaml 114 | ``` 115 | - The outputs are saved in `outputs/${name}` as follows: 116 | ```Shell 117 | output folder/ 118 | ├── history.pkl # simulation history 119 | ├── composite.mp4 # composite video 120 | ├── composite.pt # composite video tensor 121 | ├── mask_video.pt # foreground masked video tensor 122 | ├── trans_list.pt # objects transformation list tensor 123 | ``` 124 | 125 | ## Rendering 126 | 127 | ### Relighting 128 | - Relighting requires the following input: 129 | ```Shell 130 | image folder/ # 131 | ├── normal.npy # normal map 132 | ├── shading.npy # shading map by intrinsic decomposition 133 | previous output folder/ 134 | ├── composite.pt # composite video 135 | ├── mask_video.pt # foreground masked video tensor 136 | ├── trans_list.pt # objects transformation list tensor 137 | 138 | ``` 139 | - The `perception_input` is the image folder that contains the perception results. The `previous_output` is the output folder from the previous simulation step. 140 | - Run the relighting with the following command: 141 | ```Shell 142 | cd relight 143 | python relight.py --perception_input ../data/${name} --previous_output ../outputs/${name} 144 | ``` 145 | - The outputs `relight.mp4` and `relight.pt` are the relighted video and the corresponding tensor. 146 | - Comparison between the composite video and the relighted video: 147 | | **Input Image** | **Composite Video** | **Relight Video** | 148 | |:---------------:|:-------------------:|:-----------------:| 149 | | Original Input Image | Pig Ball Composite GIF | Pig Ball Relight GIF | 150 | 151 | 152 | 153 | ### Video Diffusion Rendering 154 | - Download the [SEINE](https://github.com/Vchitect/SEINE/) model following the [instructions](https://github.com/Vchitect/SEINE/tree/main?tab=readme-ov-file#download-our-model-and-t2i-base-model) 155 | 156 | ```Shell 157 | # install git-lfs beforehand 158 | mkdir -p diffusion/SEINE/pretrained 159 | git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 diffusion/SEINE/pretrained/stable-diffusion-v1-4 160 | wget -P diffusion/SEINE/pretrained https://huggingface.co/Vchitect/SEINE/resolve/main/seine.pt 161 | ``` 162 | 163 | - The video diffusion rendering requires the following input: 164 | ```Shell 165 | image folder/ # 166 | ├── original.png # input image 167 | ├── sim.yaml # simulation configuration file (optional) 168 | previous output folder/ 169 | ├── relight.pt # relighted video tensor 170 | ├── mask_video.pt # foreground masked video tensor 171 | ``` 172 | - Run the video diffusion rendering with the following command: 173 | ```Shell 174 | cd diffusion 175 | python video_diffusion.py --perception_input ../data/${name} --previous_output ../outputs/${name} 176 | ``` 177 | `denoise_strength` and `prompt` can be adjusted in the above script. `denoise_strength` controls the amount of noise added: 0 means no denoising, while 1 means denoising from scratch, which allows large variation from the input image. `prompt` is the input prompt for the video diffusion model; by default we use the foreground object names from the perception model as the prompt. 178 | 179 | 180 | - The output `final_video.mp4` is the rendered video.
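- If you want to inspect the intermediate tensors produced by the simulation and relighting steps (`composite.pt`, `mask_video.pt`, `relight.pt`), a minimal sketch like the one below can help debug shapes and value ranges before the diffusion step. It only assumes the conventions visible in `diffusion/video_diffusion.py` (frame-first tensors, pixel values either in `[0, 1]` or `[0, 255]`); adjust the scene name and paths to your run.
```python
import torch
import torchvision

name = "pool"  # any scene under outputs/
video = torch.load(f"outputs/{name}/composite.pt")  # typically (f, 3, H, W) or (f, H, W, 3)
mask = torch.load(f"outputs/{name}/mask_video.pt")  # foreground mask video
print("composite:", video.shape, video.dtype, float(video.min()), float(video.max()))
print("mask:", mask.shape, mask.dtype)

# Convert to (f, H, W, 3) uint8 so it can be written out as an mp4 for quick viewing.
if video.ndim == 4 and video.shape[1] == 3:  # (f, 3, H, W) -> (f, H, W, 3)
    video = video.permute(0, 2, 3, 1)
if video.max() <= 1.0:                       # [0, 1] floats -> [0, 255]
    video = video * 255.0
video = video.clamp(0, 255).to(torch.uint8)
torchvision.io.write_video(f"outputs/{name}/composite_preview.mp4", video, fps=7)
```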
181 | 182 | - Compare between relight video and diffuson rendered video: 183 | | **Input Image** | **Relight Video** | **Final Video** | 184 | |:--------------------------------------:|:--------------------------------------------:|:--------------------------------------------:| 185 | | Original Input Image | Car Composite GIF | Car Relight GIF | 186 | 187 | 188 | 189 | ## All-in-One command 190 | We integrate the simulation, relighting and video diffusion rendering in one script. Please follow the [Video Diffusion Rendering](#video-diffusion-rendering) to download the SEINE model first. 191 | ```Shell 192 | bash scripts/run_demo.sh ${name} 193 | ``` 194 | 195 | ## Evaluation 196 | We compare ours against open-sourced img-to-video models [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [I2VGen-XL](https://github.com/ali-vilab/VGen), [SEINE](https://github.com/Vchitect/SEINE) and collected reference videos [GT]() in Sec. 4.3. 197 | 198 | - Install [pytorch-fid](https://github.com/mseitzer/pytorch-fid): 199 | 200 | ``` 201 | pip install pytorch-fid 202 | ``` 203 | 204 | - Download the evaluation data from [here](https://uofi.box.com/s/zl8au6w3jopke9sxb7v9sdyboglcozhl) for all comparisons and unzip to `evaluation` directory. Choose `${method name}` from `DynamiCrafter`, `I2VGen-XL`, `SEINE`, `ours`. 205 | 206 | 207 | - Evaluate image FID: 208 | ```Shell 209 | python -m pytorch_fid evaluation/${method name}/all evaluation/GT/all 210 | ``` 211 | 212 | - Evaluate motion FID: 213 | ```Shell 214 | python -m pytorch_fid evaluation/${method name}/all_flow evaluation/GT/all_flow 215 | ``` 216 | 217 | - For motion FID, we use [RAFT](https://github.com/princeton-vl/RAFT) to compute optical flow between neighbor frames. The video processing scripts can be found [here](https://drive.google.com/drive/folders/10KDXRGEdcYSJuxLp8v6u1N5EB8Ghs6Xk?usp=sharing). 218 | 219 | 220 | 221 | ## Custom Image Video Generation 222 | 223 | - Our method should generally work for side-view and top-down view images. For custom images, please follow the [perception](#perception), [simulation](#simulation), [rendering](#rendering) pipeline to generate the video. 224 | 225 | - Critical steps (assume proper environment installed) 226 | 227 | - Input: 228 | ```Shell 229 | image folder/ 230 | ├── original.png 231 | ``` 232 | - Perception: 233 | ```Shell 234 | cd perception/ 235 | python gpt_ram.py --img_path ${image folder} 236 | python run_gsam.py --input ${image folder} 237 | python run_depth_normal.py --input ${image folder} --vis 238 | python run_fg_bg.py --input ${image folder} --vis_edge 239 | python run_inpaint.py --input ${image folder} --dilate_kernel_size 20 240 | python run_albedo_shading.py --input ${image folder} --vis 241 | ``` 242 | 243 | - After perception step, you should get 244 | ```Shell 245 | image folder/ 246 | ├── original.png 247 | ├── mask.png # foreground segmentation mask 248 | ├── inpaint.png # background inpainting 249 | ├── normal.npy # normal map 250 | ├── shading.npy # shading map by intrinsic decomposition 251 | ├── edges.json # edges 252 | ├── physics.yaml # physics properties of foreground objects 253 | ``` 254 | 255 | - Compose `${image folder}/sim.yaml` for simulation by specifying the object init conditions (you could check foreground objects ids in `${image folder}/intermediate/fg_mask_vis.png`), please see example in `data/pig_ball/sim.yaml`, copy the content in `physics.yaml` to `sim.yaml` and edges information from `edges.json`. 
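- Exactly how you merge `physics.yaml` and `edges.json` into `sim.yaml` is up to you; below is a minimal, hypothetical sketch of that composition step using `omegaconf`. The global keys mirror `data/pig_ball/sim.yaml`, while the internal layout of `physics.yaml` and `edges.json` is an assumption here, so adjust the two marked lines to match what your perception run actually produced.
```python
import json
from omegaconf import OmegaConf

image_folder = "custom_data/kitchen"  # example custom image folder

# Global settings mirroring the example configs (see data/pig_ball/sim.yaml).
sim = OmegaConf.create({
    "cat": "kitchen",
    "animation_frames": 16,
    "size": [512, 512],
    "num_steps": 200,
    "display": False,
    "save_snapshot": False,
    "gravity": 980,
    "init_velocity": {1: [-100, 0]},  # edit: initial velocity per foreground object id
    "init_acc": {},                   # edit: initial acceleration per foreground object id
})

# Assumption: physics.yaml stores per-object properties compatible with the `obj_info`
# block of sim.yaml, and edges.json stores the line segments used as `edge_list`.
physics = OmegaConf.load(f"{image_folder}/physics.yaml")
with open(f"{image_folder}/edges.json") as f:
    edges = json.load(f)

sim.obj_info = physics if "obj_info" not in physics else physics.obj_info
sim.edge_list = edges

OmegaConf.save(sim, f"{image_folder}/sim.yaml")
print(OmegaConf.to_yaml(sim))
```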
256 | 257 | - Run simulation: 258 | ```Shell 259 | cd simulation/ 260 | python animate.py --data_root ${image_folder} --save_root ${image_folder} --config ${image_folder}/sim.yaml 261 | ``` 262 | 263 | - Run rendering: 264 | ```Shell 265 | cd relight/ 266 | python relight.py --perception_input ${image_folder} --previous_output ${image_folder} 267 | cd ../diffusion/ 268 | python video_diffusion.py --perception_input ${image_folder} --previous_output ${image_folder} --denoise_strength ${denoise_strength} 269 | ``` 270 | 271 | - We put some custom images under `custom_data` folder. You could play with each image by running the above steps and see different physical simulations. 272 | 273 | | **Balls Shelf** | **Boxes** | **Kitchen** | **Table** | **Toy** 274 | | :---------------: | :-------: | :---------: | :-------: | :----------: | 275 | | Balls Shelf | Boxes | Kitchen | Table | Wall Toy | 276 | 277 | ## Citation 278 | 279 | If you find our work useful in your research, please cite: 280 | 281 | ```BiBTeX 282 | @inproceedings{liu2024physgen, 283 | title={PhysGen: Rigid-Body Physics-Grounded Image-to-Video Generation}, 284 | author={Liu, Shaowei and Ren, Zhongzheng and Gupta, Saurabh and Wang, Shenlong}, 285 | booktitle={European Conference on Computer Vision ECCV}, 286 | year={2024} 287 | } 288 | ``` 289 | 290 | 291 | ## Acknowledgement 292 | * [Grounded-Segment-Anything 293 | ](https://github.com/IDEA-Research/Grounded-Segment-Anything) for segmentation in [perception](#perception) 294 | * [GeoWizard 295 | ](https://github.com/fuxiao0719/GeoWizard) for depth and normal estimation in [perception](#perception) 296 | * [Intrinsic](https://github.com/compphoto/Intrinsic/) for intrinsic image decomposition in [perception](#perception) 297 | * [Inpaint-Anything](https://github.com/geekyutao/Inpaint-Anything) for image inpainting in [perception](#perception) 298 | * [Pymunk](https://github.com/viblo/pymunk) for physics simulation in [simulation](#simulation) 299 | * [SEINE](https://github.com/Vchitect/SEINE/) for video diffusion in [rendering](#rendering) 300 | -------------------------------------------------------------------------------- /assets/car_final.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/car_final.gif -------------------------------------------------------------------------------- /assets/car_relight.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/car_relight.gif -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/demo.gif -------------------------------------------------------------------------------- /assets/domino_composite.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/domino_composite.gif -------------------------------------------------------------------------------- /assets/domino_sim.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/domino_sim.gif -------------------------------------------------------------------------------- /assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/method.png -------------------------------------------------------------------------------- /assets/pig_ball_composite.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/pig_ball_composite.gif -------------------------------------------------------------------------------- /assets/pig_ball_relight.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/pig_ball_relight.gif -------------------------------------------------------------------------------- /assets/pool_composite.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/pool_composite.gif -------------------------------------------------------------------------------- /assets/pool_sim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/pool_sim.gif -------------------------------------------------------------------------------- /custom_data/balls_shelf/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/balls_shelf/original.png -------------------------------------------------------------------------------- /custom_data/boxes/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/boxes/original.png -------------------------------------------------------------------------------- /custom_data/kitchen/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/kitchen/original.png -------------------------------------------------------------------------------- /custom_data/table/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/table/original.png -------------------------------------------------------------------------------- /custom_data/wall_toy/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/wall_toy/original.png -------------------------------------------------------------------------------- /data/balls/depth.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/depth.npy -------------------------------------------------------------------------------- /data/balls/inpaint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/inpaint.png -------------------------------------------------------------------------------- /data/balls/intermediate/albedo_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/albedo_vis.png -------------------------------------------------------------------------------- /data/balls/intermediate/depth_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/depth_vis.png -------------------------------------------------------------------------------- /data/balls/intermediate/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/mask.png -------------------------------------------------------------------------------- /data/balls/intermediate/normal_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/normal_vis.png -------------------------------------------------------------------------------- /data/balls/intermediate/obj_movable.json: -------------------------------------------------------------------------------- 1 | {"background": false, "puck": true} -------------------------------------------------------------------------------- /data/balls/intermediate/shading_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/shading_vis.png -------------------------------------------------------------------------------- /data/balls/intermediate/vis_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/vis_mask.jpg -------------------------------------------------------------------------------- /data/balls/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/mask.png -------------------------------------------------------------------------------- /data/balls/normal.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/normal.npy -------------------------------------------------------------------------------- /data/balls/original.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/original.png -------------------------------------------------------------------------------- /data/balls/shading.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/shading.npy -------------------------------------------------------------------------------- /data/balls/sim.yaml: -------------------------------------------------------------------------------- 1 | # general setting 2 | cat: balls 3 | # simulation setting 4 | animation_frames: 16 5 | size: [512, 512] 6 | num_steps: 200 7 | # enable for simulation visualization 8 | display: false 9 | save_snapshot: false 10 | gravity: 0 11 | init_velocity: 12 | 6: [40, -150] 13 | init_acc: 14 | # perception setting 15 | edge_list: [[[1, 1], [1, 511]], [[1, 511], [511, 511]], [[511, 511], [511, 1]], [ 16 | [511, 1], [1, 1]]] 17 | obj_info: 18 | 1: 19 | label: puck 20 | mass: 1900 21 | density: 22 | friction: 0.1 23 | elasticity: 0.05 24 | primitive: circle 25 | 2: 26 | label: puck 27 | mass: 1900 28 | density: 29 | friction: 0.1 30 | elasticity: 0.05 31 | primitive: circle 32 | 3: 33 | label: puck 34 | mass: 1900 35 | density: 36 | friction: 0.1 37 | elasticity: 0.05 38 | primitive: circle 39 | 4: 40 | label: puck 41 | mass: 1900 42 | density: 43 | friction: 0.1 44 | elasticity: 0.05 45 | primitive: circle 46 | 5: 47 | label: puck 48 | mass: 1900 49 | density: 50 | friction: 0.1 51 | elasticity: 0.05 52 | primitive: circle 53 | 6: 54 | label: puck 55 | mass: 1900 56 | density: 57 | friction: 0.1 58 | elasticity: 0.05 59 | primitive: circle 60 | # diffusion setting (optional) 61 | denoise_strength: 0.5 62 | -------------------------------------------------------------------------------- /data/balls/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/vis.png -------------------------------------------------------------------------------- /data/car/inpaint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/inpaint.png -------------------------------------------------------------------------------- /data/car/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/mask.png -------------------------------------------------------------------------------- /data/car/normal.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/normal.npy -------------------------------------------------------------------------------- /data/car/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/original.png -------------------------------------------------------------------------------- /data/car/shading.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/shading.npy -------------------------------------------------------------------------------- /data/car/sim.yaml: -------------------------------------------------------------------------------- 1 | # general setting 2 | cat: car 3 | # simulation setting 4 | animation_frames: 16 5 | size: [512, 512] 6 | num_steps: 100 7 | # enable for simulation visualization 8 | display: false 9 | save_snapshot: false 10 | gravity: 980 11 | ground_friction: 0.9 12 | ground_elasticity: 1.0 13 | init_velocity: 14 | 1: [-400, 0] 15 | init_acc: 16 | # perception setting 17 | edge_list: [[[0, 303], [512, 303]]] 18 | obj_info: 19 | 1: 20 | label: toy car 21 | mass: 22 | density: 0.003 23 | friction: 0.2 24 | elasticity: 0.05 25 | primitive: polygon 26 | # diffusion setting (optional) 27 | denoise_strength: 0.65 28 | -------------------------------------------------------------------------------- /data/domino/depth.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/depth.npy -------------------------------------------------------------------------------- /data/domino/inpaint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/inpaint.png -------------------------------------------------------------------------------- /data/domino/intermediate/albedo_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/albedo_vis.png -------------------------------------------------------------------------------- /data/domino/intermediate/depth_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/depth_vis.png -------------------------------------------------------------------------------- /data/domino/intermediate/normal_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/normal_vis.png -------------------------------------------------------------------------------- /data/domino/intermediate/obj_movable.json: -------------------------------------------------------------------------------- 1 | {"candle": true, "table": false, "background": false} -------------------------------------------------------------------------------- /data/domino/intermediate/shading_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/shading_vis.png -------------------------------------------------------------------------------- /data/domino/intermediate/vis_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/vis_mask.jpg -------------------------------------------------------------------------------- 
/data/domino/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/mask.png -------------------------------------------------------------------------------- /data/domino/normal.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/normal.npy -------------------------------------------------------------------------------- /data/domino/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/original.png -------------------------------------------------------------------------------- /data/domino/shading.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/shading.npy -------------------------------------------------------------------------------- /data/domino/sim.yaml: -------------------------------------------------------------------------------- 1 | # general setting 2 | cat: domino 3 | # simulation setting 4 | animation_frames: 16 5 | size: [512, 512] 6 | num_steps: 160 7 | # enable for simulation visualization 8 | display: false 9 | save_snapshot: false 10 | gravity: 980 11 | init_velocity: 12 | 6: [-80, 0] 13 | init_acc: 14 | # perception setting 15 | edge_list: [[[0, 457], [512, 457]]] 16 | obj_info: 17 | 1: 18 | label: white cuboid 19 | mass: 20 | density: 2.5 21 | friction: 0.5 22 | elasticity: 0.01 23 | primitive: polygon 24 | 2: 25 | label: white cuboid 26 | mass: 27 | density: 2.5 28 | friction: 0.5 29 | elasticity: 0.01 30 | primitive: polygon 31 | 3: 32 | label: white cuboid 33 | mass: 34 | density: 2.5 35 | friction: 0.5 36 | elasticity: 0.01 37 | primitive: polygon 38 | 4: 39 | label: white cuboid 40 | mass: 41 | density: 2.5 42 | friction: 0.5 43 | elasticity: 0.01 44 | primitive: polygon 45 | 5: 46 | label: white cuboid 47 | mass: 48 | density: 2.5 49 | friction: 0.5 50 | elasticity: 0.01 51 | primitive: polygon 52 | 6: 53 | label: white cuboid 54 | mass: 55 | density: 2.5 56 | friction: 0.5 57 | elasticity: 0.01 58 | primitive: polygon 59 | # diffusion setting (optional) 60 | denoise_strength: 0.45 61 | -------------------------------------------------------------------------------- /data/domino/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/vis.png -------------------------------------------------------------------------------- /data/pig_ball/depth.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/depth.npy -------------------------------------------------------------------------------- /data/pig_ball/inpaint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/inpaint.png -------------------------------------------------------------------------------- /data/pig_ball/intermediate/albedo_vis.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/albedo_vis.png -------------------------------------------------------------------------------- /data/pig_ball/intermediate/bg_mask_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/bg_mask_vis.png -------------------------------------------------------------------------------- /data/pig_ball/intermediate/depth_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/depth_vis.png -------------------------------------------------------------------------------- /data/pig_ball/intermediate/edge_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/edge_vis.png -------------------------------------------------------------------------------- /data/pig_ball/intermediate/fg_mask_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/fg_mask_vis.png -------------------------------------------------------------------------------- /data/pig_ball/intermediate/normal_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/normal_vis.png -------------------------------------------------------------------------------- /data/pig_ball/intermediate/obj_movable.json: -------------------------------------------------------------------------------- 1 | {"ball": true, "shelf": false, "piggybank": true, "book": true, "wall": false, "floor": false, "baseboard": false} -------------------------------------------------------------------------------- /data/pig_ball/intermediate/shading_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/shading_vis.png -------------------------------------------------------------------------------- /data/pig_ball/intermediate/vis_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/vis_mask.jpg -------------------------------------------------------------------------------- /data/pig_ball/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/mask.png -------------------------------------------------------------------------------- /data/pig_ball/normal.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/normal.npy 
-------------------------------------------------------------------------------- /data/pig_ball/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/original.png -------------------------------------------------------------------------------- /data/pig_ball/shading.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/shading.npy -------------------------------------------------------------------------------- /data/pig_ball/sim.yaml: -------------------------------------------------------------------------------- 1 | # general setting 2 | cat: pig_ball 3 | # simulation setting 4 | animation_frames: 16 5 | size: [512, 512] 6 | num_steps: 220 7 | # enable for simulation visualization 8 | display: false 9 | save_snapshot: false 10 | gravity: 980 11 | init_velocity: {} 12 | init_acc: 13 | 3: [-100, 0] 14 | # perception setting 15 | edge_list: [[[409, 154], [512, 154]], [[182, 186], [512, 186]], [[0, 512], [512, 512]]] 16 | obj_info: 17 | 1: 18 | label: ball 19 | mass: 100 20 | density: 21 | friction: 0.7 22 | elasticity: 0.8 23 | primitive: circle 24 | 3: 25 | label: toy 26 | mass: 27 | density: 1.0 28 | friction: 0.5 29 | elasticity: 0.3 30 | primitive: polygon 31 | # diffusion setting (optional) 32 | denoise_strength: 0.3 33 | -------------------------------------------------------------------------------- /data/pig_ball/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/vis.png -------------------------------------------------------------------------------- /data/pool/depth.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/depth.npy -------------------------------------------------------------------------------- /data/pool/inpaint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/inpaint.png -------------------------------------------------------------------------------- /data/pool/intermediate/albedo_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/albedo_vis.png -------------------------------------------------------------------------------- /data/pool/intermediate/depth_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/depth_vis.png -------------------------------------------------------------------------------- /data/pool/intermediate/normal_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/normal_vis.png -------------------------------------------------------------------------------- /data/pool/intermediate/obj_movable.json: 
-------------------------------------------------------------------------------- 1 | {"billiardball": true, "table": false} -------------------------------------------------------------------------------- /data/pool/intermediate/shading_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/shading_vis.png -------------------------------------------------------------------------------- /data/pool/intermediate/vis_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/vis_mask.jpg -------------------------------------------------------------------------------- /data/pool/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/mask.png -------------------------------------------------------------------------------- /data/pool/normal.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/normal.npy -------------------------------------------------------------------------------- /data/pool/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/original.png -------------------------------------------------------------------------------- /data/pool/shading.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/shading.npy -------------------------------------------------------------------------------- /data/pool/sim.yaml: -------------------------------------------------------------------------------- 1 | # general setting 2 | cat: pool 3 | # simulation setting 4 | animation_frames: 16 5 | size: [512, 512] 6 | num_steps: 100 7 | # enable for simulation visualization 8 | display: false 9 | save_snapshot: false 10 | gravity: 0 11 | init_velocity: 12 | 1: [-100, 0] 13 | init_acc: 14 | 1: [-5882.35, 0.0] 15 | 16 | # perception setting 17 | edge_list: [[[1, 1], [1, 511]], [[1, 511], [511, 511]], [[511, 511], [511, 1]], [ 18 | [511, 1], [1, 1]]] 19 | obj_info: 20 | 1: 21 | label: billiard 22 | mass: 170 23 | density: 24 | friction: 0.3 25 | elasticity: 0.7 26 | primitive: circle 27 | 10: 28 | label: billiard 29 | mass: 170 30 | density: 31 | friction: 0.3 32 | elasticity: 0.7 33 | primitive: circle 34 | 11: 35 | label: billiard 36 | mass: 170 37 | density: 38 | friction: 0.3 39 | elasticity: 0.7 40 | primitive: circle 41 | 12: 42 | label: billiard 43 | mass: 170 44 | density: 45 | friction: 0.3 46 | elasticity: 0.7 47 | primitive: circle 48 | 13: 49 | label: billiard 50 | mass: 170 51 | density: 52 | friction: 0.3 53 | elasticity: 0.7 54 | primitive: circle 55 | 14: 56 | label: billiard 57 | mass: 170 58 | density: 59 | friction: 0.3 60 | elasticity: 0.7 61 | primitive: circle 62 | 15: 63 | label: billiard 64 | mass: 170 65 | density: 66 | friction: 0.3 67 | elasticity: 0.7 68 | primitive: circle 69 | 16: 70 | label: billiard 71 | mass: 170 72 | density: 73 | 
friction: 0.3 74 | elasticity: 0.7 75 | primitive: circle 76 | 2: 77 | label: billiard 78 | mass: 170 79 | density: 80 | friction: 0.3 81 | elasticity: 0.7 82 | primitive: circle 83 | 3: 84 | label: billiard 85 | mass: 170 86 | density: 87 | friction: 0.3 88 | elasticity: 0.7 89 | primitive: circle 90 | 4: 91 | label: billiard 92 | mass: 170 93 | density: 94 | friction: 0.3 95 | elasticity: 0.7 96 | primitive: circle 97 | 5: 98 | label: billiard 99 | mass: 170 100 | density: 101 | friction: 0.3 102 | elasticity: 0.7 103 | primitive: circle 104 | 6: 105 | label: billiard 106 | mass: 170 107 | density: 108 | friction: 0.3 109 | elasticity: 0.7 110 | primitive: circle 111 | 7: 112 | label: billiard 113 | mass: 170 114 | density: 115 | friction: 0.3 116 | elasticity: 0.7 117 | primitive: circle 118 | 8: 119 | label: billiard 120 | mass: 170 121 | density: 122 | friction: 0.3 123 | elasticity: 0.7 124 | primitive: circle 125 | 9: 126 | label: billiard 127 | mass: 170 128 | density: 129 | friction: 0.3 130 | elasticity: 0.7 131 | primitive: circle 132 | # diffusion setting (optional) 133 | denoise_strength: 0.2 134 | -------------------------------------------------------------------------------- /data/pool/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/vis.png -------------------------------------------------------------------------------- /diffusion/video_diffusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from omegaconf import OmegaConf 5 | import argparse 6 | from PIL import Image 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torchvision import transforms 11 | torch.backends.cuda.matmul.allow_tf32 = True 12 | torch.backends.cudnn.allow_tf32 = True 13 | from einops import rearrange 14 | from tqdm.auto import tqdm 15 | from diffusers.models import AutoencoderKL 16 | from diffusers.utils.import_utils import is_xformers_available 17 | 18 | 19 | sys.path.insert(0, "SEINE") 20 | from datasets import video_transforms 21 | from diffusion import create_diffusion 22 | from models.clip import TextEmbedder 23 | from models import get_models 24 | from utils import mask_generation_before 25 | 26 | 27 | def prepare_input(relight_video_path, mask_video_path, 28 | image_h, image_w, latent_h, latent_w, device, use_fp16): 29 | relight_video = torch.load(relight_video_path).to(device) # (f, 3, H, W) 30 | if relight_video.max() > 1.: 31 | relight_video = relight_video / 255.0 32 | if relight_video.shape[1] !=3: # (f, H, W, 3) -> (f, 3, H, W) 33 | relight_video = relight_video.permute(0, 3, 1, 2) # (f, 3, H, W) 34 | if relight_video.shape[-2:] != (image_h, image_w): 35 | relight_video = F.interpolate(relight_video, size=(image_h, image_w)) 36 | relight_video = 2 * relight_video - 1 # [-1, 1] 37 | if use_fp16: 38 | relight_video = relight_video.to(dtype=torch.float16) 39 | 40 | mask_video = torch.load(mask_video_path).to(device).float() # (f, 1, H, W) 41 | if mask_video.ndim == 3: # (f, H, W) -> (f, 4, H, W) 42 | mask_video = mask_video.unsqueeze(1).repeat(1, 4, 1, 1) 43 | elif mask_video.ndim == 4 and mask_video.shape[1] == 1: # (f, 1, H, W) -> (f, 4, H, W) 44 | mask_video = mask_video.repeat(1, 4, 1, 1) 45 | elif mask_video.ndim == 4 and mask_video.shape[0] == 1: # (1, f, H, W) -> (f, 4, H, W) 46 | mask_video = mask_video.repeat(4, 
1, 1, 1).permute(1, 0, 2, 3) # (f, 4, H, W) 47 | if mask_video.shape[-2:] != (latent_h, latent_w): 48 | mask_video = F.interpolate(mask_video, size=(latent_h, latent_w)) # (f, 4, h, w) 49 | mask_video = mask_video > 0.5 50 | mask_video = rearrange(mask_video, 'f c h w -> c f h w').contiguous() 51 | mask_video = mask_video.unsqueeze(0) # (1, 4, f, h, w) 52 | return relight_video, mask_video 53 | 54 | 55 | def get_input(image_path, image_h, image_w, mask_type="first1", num_frames=16): 56 | transform_video = transforms.Compose([ 57 | video_transforms.ToTensorVideo(), 58 | video_transforms.ResizeVideo((image_h, image_w)), 59 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]) 60 | print(f'loading video from {image_path}') 61 | _, full_file_name = os.path.split(image_path) 62 | file_name, extension = os.path.splitext(full_file_name) 63 | if extension == '.jpg' or extension == '.jpeg' or extension == '.png': 64 | print("loading the input image") 65 | video_frames = [] 66 | num = int(mask_type.split('first')[-1]) 67 | first_frame = torch.as_tensor(np.array(Image.open(image_path), dtype=np.uint8, copy=True)).unsqueeze(0) 68 | for i in range(num): 69 | video_frames.append(first_frame) 70 | num_zeros = num_frames-num 71 | for i in range(num_zeros): 72 | zeros = torch.zeros_like(first_frame) 73 | video_frames.append(zeros) 74 | n = 0 75 | video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w 76 | video_frames = transform_video(video_frames) 77 | return video_frames, n 78 | 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = argparse.ArgumentParser() 83 | # Model checkpoint 84 | parser.add_argument("--ckpt", default="SEINE/pretrained/seine.pt") 85 | parser.add_argument("--pretrained_model_path", default="SEINE/pretrained/stable-diffusion-v1-4") 86 | # Model config 87 | parser.add_argument('--model', type=str, default='UNet', help='Model architecture to use.') 88 | parser.add_argument('--num_frames', type=int, default=16, help='Number of frames to process.') 89 | parser.add_argument('--image_size', type=int, nargs=2, default=[512, 512], help='Resolution of the input images.') 90 | 91 | # Model speedup config 92 | parser.add_argument('--use_fp16', type=bool, default=True, help='Use FP16 for faster inference. 
Set to False if debugging with video loss.') 93 | parser.add_argument('--enable_xformers_memory_efficient_attention', type=bool, default=True, help='Enable xformers memory efficient attention.') 94 | 95 | # Sample config 96 | parser.add_argument('--seed', type=int, default=None, help='Random seed for reproducibility.') 97 | parser.add_argument('--run_time', type=int, default=13, help='Run time of the model.') 98 | parser.add_argument('--cfg_scale', type=float, default=8.0, help='Configuration scale factor.') 99 | parser.add_argument('--sample_method', type=str, default='ddim', choices=['ddim', 'ddpm'], help='Sampling method to use.') 100 | parser.add_argument('--do_classifier_free_guidance', type=bool, default=True, help='Enable classifier-free guidance.') 101 | parser.add_argument('--mask_type', type=str, default="first1", help='Type of mask to use.') 102 | parser.add_argument('--use_mask', type=bool, default=True, help='Whether to use a mask.') 103 | parser.add_argument('--num_sampling_steps', type=int, default=50, help='Number of sampling steps to perform.') 104 | parser.add_argument('--prompt', type=str, default=None, help='input prompt') 105 | parser.add_argument('--negative_prompt', type=str, default="", help='Negative prompt to use.') 106 | 107 | # Video diffusion config 108 | parser.add_argument('--denoise_strength', type=float, default=0.4, help='Denoise strength parameter.') 109 | parser.add_argument('--stop_idx', type=int, default=5, help='Stop index for the diffusion process.') 110 | parser.add_argument('--perception_input', type=str, default="../data/pool", help='input dir') 111 | parser.add_argument('--previous_output', type=str, default="../outputs/pool", help='previous output dir') 112 | 113 | # Parse the arguments 114 | args = parser.parse_args() 115 | 116 | output = args.previous_output 117 | os.makedirs(output, exist_ok=True) 118 | 119 | if args.seed: 120 | torch.manual_seed(args.seed) 121 | torch.set_grad_enabled(False) 122 | device = "cuda" if torch.cuda.is_available() else "cpu" 123 | if args.ckpt is None: 124 | raise ValueError("Please specify a checkpoint path using --ckpt ") 125 | 126 | latent_h = args.image_size[0] // 8 127 | latent_w = args.image_size[1] // 8 128 | image_h = args.image_size[0] 129 | image_w = args.image_size[1] 130 | latent_h = latent_h 131 | latent_w = latent_w 132 | print('loading model') 133 | model = get_models(args).to(device) 134 | 135 | if args.enable_xformers_memory_efficient_attention: 136 | if is_xformers_available(): 137 | model.enable_xformers_memory_efficient_attention() 138 | else: 139 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 140 | 141 | ckpt_path = args.ckpt 142 | state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema'] 143 | model.load_state_dict(state_dict) 144 | print('loading succeed') 145 | model.eval() 146 | 147 | pretrained_model_path = args.pretrained_model_path 148 | diffusion = create_diffusion(str(args.num_sampling_steps)) 149 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device) 150 | text_encoder = TextEmbedder(pretrained_model_path).to(device) 151 | 152 | sim_config_path = os.path.join(args.perception_input, "sim.yaml") 153 | config = OmegaConf.load(sim_config_path) 154 | objects = config.obj_info 155 | object_names = [] 156 | for seg_id in objects: 157 | name = objects[seg_id]['label'] 158 | object_names.append(name) 159 | object_names = list(set(object_names)) 160 | if args.prompt is None: 161 | prompt = ", ".join(object_names) 162 | else: 163 | prompt = args.prompt 164 | print(f"input prompt: {prompt}") 165 | denoise_strength = getattr(config, 'denoise_strength', args.denoise_strength) 166 | 167 | if args.use_fp16: 168 | print('Warnning: using half percision for inferencing!') 169 | vae.to(dtype=torch.float16) 170 | model.to(dtype=torch.float16) 171 | text_encoder.to(dtype=torch.float16) 172 | 173 | 174 | relight_video_path = os.path.join(args.previous_output, "relight.pt") 175 | mask_video_path = os.path.join(args.previous_output, "mask_video.pt") 176 | relight_video, mask_video = prepare_input(relight_video_path, mask_video_path, image_h, image_w, latent_h, latent_w, device, args.use_fp16) 177 | 178 | with torch.no_grad(): 179 | ref_latent = vae.encode(relight_video).latent_dist.sample().mul_(0.18215) # (f, 4, h, w) 180 | ref_latent = ref_latent.permute(1, 0, 2, 3).contiguous().unsqueeze(0) # (1, 4, f, h, w) 181 | 182 | image_path = os.path.join(args.perception_input, "original.png") 183 | video, reserve_frames = get_input(image_path, image_h, image_w, args.mask_type, args.num_frames) 184 | video_input = video.unsqueeze(0).to(device) # b,f,c,h,w 185 | b,f,c,h,w=video_input.shape 186 | mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w [1, 16, 3, 512, 512]) 187 | masked_video = video_input * (mask == 0) 188 | 189 | if args.use_fp16: 190 | masked_video = masked_video.to(dtype=torch.float16) 191 | mask = mask.to(dtype=torch.float16) 192 | 193 | masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous() 194 | masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215) 195 | masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous() 196 | mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1) 197 | 198 | if args.do_classifier_free_guidance: 199 | masked_video = torch.cat([masked_video] * 2) 200 | mask = torch.cat([mask] * 2) 201 | prompt_all = [prompt] + [args.negative_prompt] 202 | else: 203 | masked_video = masked_video 204 | mask = mask 205 | prompt_all = [prompt] 206 | 207 | text_prompt = text_encoder(text_prompts=prompt_all, train=False) 208 | model_kwargs = dict(encoder_hidden_states=text_prompt, 209 | class_labels=None, 210 | cfg_scale=args.cfg_scale, 211 | use_fp16=args.use_fp16) # tav unet 212 | 213 | indices = list(range(diffusion.num_timesteps))[::-1] 214 | noise_level = int(denoise_strength * diffusion.num_timesteps) 215 | indices = indices[-noise_level:] 216 | 217 | latent = diffusion.q_sample(ref_latent, torch.tensor([indices[0]], 
device=device)) 218 | latent = torch.cat([latent] * 2) 219 | stop_idx = args.stop_idx 220 | 221 | for idx, i in tqdm(enumerate(indices)): 222 | t = torch.tensor([indices[idx]] * masked_video.shape[0], device=device) 223 | with torch.no_grad(): 224 | out = diffusion.ddim_sample( 225 | model.forward_with_cfg, 226 | latent, 227 | t, 228 | clip_denoised=False, 229 | denoised_fn=None, 230 | cond_fn=None, 231 | model_kwargs=model_kwargs, 232 | eta=0.0, 233 | mask=mask, 234 | x_start=masked_video, 235 | use_concat=args.use_mask, 236 | ) 237 | 238 | # update latent 239 | latent = out["sample"] 240 | if idx < len(indices)-stop_idx: 241 | x = diffusion.q_sample(ref_latent, torch.tensor([indices[idx+1]], device=device)) 242 | pred_xstart = out["pred_xstart"] 243 | 244 | weight = min(idx / len(indices), 1.0) 245 | latent = (1 - mask_video.float()) * latent + mask_video.float() * ((1 - weight) * x + weight * latent) 246 | 247 | if args.use_fp16: 248 | latent = latent.to(dtype=torch.float16) 249 | latent = latent[0].permute(1, 0, 2, 3).contiguous() # (f, 4, h, w) 250 | video_output = vae.decode(latent / 0.18215).sample 251 | video_ = ((video_output * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) 252 | save_video_path = os.path.join(output, f'final_video.mp4') 253 | torchvision.io.write_video(save_video_path, video_, fps=7) 254 | print("all done!") -------------------------------------------------------------------------------- /perception/README.md: -------------------------------------------------------------------------------- 1 | # Perception 2 | 3 | ## Segmentation 4 | 5 | ### GPT-based Recognize Anything 6 | - We use GPT-4V to recognize objects in the image as input prompt for [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). Alternatively, you could use any other object recognition model (e.g. [RAM](https://github.com/xinyu1205/recognize-anything)) to get the objects in the given image. 7 | 8 | - We put pre-computed GPT-4V result under each `data/${name}/obj_movable.json`. You could skip below and run [segmentation](#grounded-segment-anything) if you don't want to re-run GPT-4V. 9 | 10 | - Copy the OpenAI API key into `gpt/gpt_configs/my_apikey`. 11 | 12 | - Install requirements 13 | ```bash 14 | pip install inflect openai==0.28 15 | ``` 16 | - Run GPT-4V based RAM 17 | ```Shell 18 | python gpt_ram.py --img_path ../data/${name} 19 | ``` 20 | - The default `save_path` is saved under same folder as input `../data/${name}/intermediate/obj_movable.json`. The output is a list in json format. 21 | 22 | ```shell 23 | [ 24 | {"obj_1": True # True if the object is movable or False if not}, 25 | {"obj_2": True}, 26 | ... 27 | ] 28 | ``` 29 | 30 | ### Grounded-Segment-Anything 31 | We use [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/753dd6675ea7935af401a983e88d159629ad4d5b) to segment the input image given the prompts input. We use the earlier checkout version for the paper. You could adapt the code to the latest version of [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) or [Grounded-SAM-2](https://github.com/IDEA-Research/Grounded-SAM-2). 
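A minimal sketch of how the movability file above is consumed at this step (illustration only, not part of the repository; it assumes `obj_movable.json` maps each category name to a movability flag, which is how `gpt_ram.py` writes it and `run_gsam.py` reads it, and the `pool` scene path is just an example):

```python
import json

# Example scene; adjust the name to your data folder.
prompts_path = "../data/pool/intermediate/obj_movable.json"

with open(prompts_path, "r") as f:
    movable_dict = json.load(f)        # e.g. {"ball": True, "table": False}

# GroundingDINO takes a single text prompt; run_gsam.py joins the category
# names with ". " to build it.
text_prompt = ". ".join(movable_dict)  # e.g. "ball. table"

# Movable categories later become the simulated foreground objects;
# non-movable ones are treated as static scene context.
movable = [name for name, flag in movable_dict.items() if flag]
print(text_prompt, movable)
```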
32 |
33 | - Follow the [Grounded-SAM setup](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/753dd6675ea7935af401a983e88d159629ad4d5b?tab=readme-ov-file#install-without-docker)
34 | ```Shell
35 | cd Grounded-Segment-Anything/
36 | git checkout 753dd6675ea7935af401a983e88d159629ad4d5b
37 |
38 | # Follow Grounded-SAM readme to install requirements
39 |
40 | # Download pretrained weights to current folder
41 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
42 | wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
43 |
44 | ```
45 | - Segmentation requires the input image (`input`) and a prompt file (`prompts_path`) covering each object in the image. The default prompt path is `../data/${name}/intermediate/obj_movable.json`.
46 |
47 | ```Shell
48 | python run_gsam.py --input ../data/${name}
49 | ```
50 | - The default `output` is saved under the same folder as the input `../data/${name}`, with visualizations under `../data/${name}/intermediate`, as follows:
51 | ```Shell
52 | image folder/
53 | ├── intermediate/
54 | ├── mask.png # FG and BG segmentation mask
55 | ├── mask.json # segmentation id, object name, and movability
56 | ├── vis_mask.jpg # segmentation visualization
57 | ```
58 | | **Pool** | **Domino** | **Pig Ball** | **Balls** |
59 | |:---------:|:----------------:|:----------:| :----------:|
60 | | pool | domino | pig_ball | balls |
61 | | pool | domino | pig_ball | balls |
62 |
63 |
64 | ## Depth and Normal Estimation
65 | - We use [GeoWizard](https://github.com/fuxiao0719/GeoWizard) to estimate the depth and normal of the input image. Follow [GeoWizard setup](https://github.com/fuxiao0719/GeoWizard/blob/main/README.md#%EF%B8%8F-setup) to install requirements. We recommend creating a new conda environment.
66 |
67 | - Run GeoWizard on the input image
68 | ```Shell
69 | python run_depth_normal.py --input ../data/${name} --output ../outputs/${name} --vis
70 | ```
71 | - Depth and normal are saved in `outputs/${name}`. Visualizations are saved in `outputs/${name}/intermediate`.
72 | ```Shell
73 | image folder/
74 | ├── depth.npy
75 | ├── normal.npy
76 | ├── intermediate/
77 | ├── depth_vis.png
78 | ├── normal_vis.png
79 | ```
80 |
81 | | **Input** | **Normal** | **Depth** |
82 | |:---------:|:----------------:|:----------:|
83 | | input | normal | depth |
84 |
85 |
86 | ## Foreground / Background & Edge Detection
87 | - We separate the foreground and background using the segmentation mask. Foreground objects with complete masks are used for physics reasoning and simulation, while truncated objects are treated as static. We use edges from static objects and the background as physical boundaries for simulation.
88 | - The module requires the following input for each image:
89 | ```Shell
90 | image folder/
91 | ├── depth.npy
92 | ├── normal.npy
93 | ├── original.png # optional: for visualization only
94 | ├── intermediate/
95 | ├── mask.png # complete image segmentation mask
96 | ├── mask.json # segmentation id, object name, and movability
97 | ```
98 | - Run foreground/background separation and edge detection
99 | ```Shell
100 | python run_fg_bg.py --input ../data/${name} --vis_edge
101 | ```
102 | - The default output is saved under the same folder as the input `../data/${name}` (or under `outputs/${name}` when `--output` is given); it contains the final foreground object mask `mask.png` and the edge list `edges.json`. A reading example and the output folder layout are shown below.
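A minimal sketch of consuming `edges.json` (illustration only; the `pool` scene path is an assumption):

```python
import json

with open("../data/pool/edges.json") as f:   # produced by run_fg_bg.py
    edges = json.load(f)

# Each entry is a short polyline in pixel coordinates, typically two endpoints
# [[x0, y0], [x1, y1]]; run_fg_bg.py derives them from bounding-box sides of
# truncated or static regions, or from the upper contour of ground-like planes,
# and the simulation uses them as static physical boundaries.
for edge in edges:
    (x0, y0), (x1, y1) = edge[0], edge[-1]
    print(f"static boundary from ({x0}, {y0}) to ({x1}, {y1})")
```

The output folder layout: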
103 | ```Shell
104 | image folder/
105 | ├── mask.png # final mask
106 | ├── edges.json
107 | ├── intermediate/
108 | ├── edge_vis.png # red line for edges
109 | ├── fg_mask_vis.png # text is the segmentation id
110 | ├── bg_mask_vis.png
111 | ├── bg_mask.png
112 | ```
113 |
114 | | **Input** | **Foreground** | **Background** | **Edges** |
115 | |:---------:|:----------------:|:----------:| :----------:|
116 | | input | mask | edges | edges |
117 | - We can simulate all foreground objects by specifying their velocity and acceleration using their segmentation ids in the simulation.
118 |
119 |
120 |
121 | ## Inpainting
122 | We use [Inpaint-Anything](https://github.com/geekyutao/Inpaint-Anything) to inpaint the background of the input image. You could adapt the code to any other recent inpainting model.
123 |
124 | - Follow [Inpaint-Anything setup](https://github.com/geekyutao/Inpaint-Anything?tab=readme-ov-file#installation) to install requirements and download the [pretrained model](https://drive.google.com/drive/folders/1wpY-upCo4GIW4wVPnlMh_ym779lLIG2A). We recommend creating a new conda environment.
125 |
126 | ```Shell
127 | python -m pip install torch torchvision torchaudio
128 | python -m pip install -e segment_anything
129 | python -m pip install -r lama/requirements.txt
130 | # Download pretrained model under Inpaint-Anything/pretrained_models/
131 | ```
132 |
133 | - Inpainting requires the input image `../data/${name}/original.png` and the foreground mask `../data/${name}/mask.png` under the same folder.
134 | ```Shell
135 | python run_inpaint.py --input ../data/${name} --output ../outputs/${name} --dilate_kernel_size 20
136 | ```
137 | `dilate_kernel_size` can be adjusted in the above command. For images with heavy shadows, increase `dilate_kernel_size` to get better inpainting results.
138 |
139 | - The output `inpaint.png` is saved in `outputs/${name}`.
140 |
141 | | **Input** | **Inpainting** |
142 | |:---------:|:----------------:|
143 | | input | inpainting |
144 |
145 |
146 | ## Physics Reasoning
147 |
148 | - Install requirements
149 | ```bash
150 | pip install openai==0.28 ruamel.yaml
151 | ```
152 | - Copy the OpenAI API key into `gpt/gpt_configs/my_apikey`.
153 |
154 | - Physics reasoning requires the following input for each image:
155 | ```Shell
156 | image folder/
157 | ├── original.png
158 | ├── mask.png # movable segmentation mask
159 | ```
160 | - Run GPT-4V physical property reasoning with the following command:
161 | ```Shell
162 | python gpt_physic.py --input ../data/${name} --output ../outputs/${name}
163 | ```
164 |
165 | - The output `physics.yaml` contains the physical properties and primitive shape of each object segment in the image. Note that GPT-4V outputs may vary across runs and may differ from the original settings in `data/${name}/sim.yaml`; users can adjust the values of each run's output accordingly.
166 |
167 |
168 | ## Albedo and Shading Estimation
169 | - We use [Intrinsic](https://github.com/compphoto/Intrinsic/tree/d9741e99b2997e679c4055e7e1f773498b791288) to infer the albedo and shading of the input image. Follow [Intrinsic setup](https://github.com/compphoto/Intrinsic/tree/d9741e99b2997e679c4055e7e1f773498b791288?tab=readme-ov-file#setup) to install requirements. We recommend creating a new conda environment.
170 | ```
171 | git clone https://github.com/compphoto/Intrinsic
172 | cd Intrinsic/
173 | git checkout d9741e99b2997e679c4055e7e1f773498b791288
174 | pip install .
175 | ``` 176 | 177 | - Run Intrinsic decomposition on input image 178 | ```Shell 179 | python run_albedo_shading.py --input ../data/${name} --output ../outputs/${name} --vis 180 | ``` 181 | 182 | - `shading.npy` are saved in `outputs/${name}`. Visualization of albedo and shading are saved in `outputs/${name}/intermediate`. 183 | 184 | | **Input** | **Albedo** | **Shading** 185 | |:---------:|:----------------:|:----------:| 186 | | input | albedo | shading | 187 | 188 | - [Intrinsic](https://github.com/compphoto/Intrinsic) has released updated trained model with better results. Feel free to use the updated model or any other model for better performance. 189 | -------------------------------------------------------------------------------- /perception/gpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/perception/gpt/__init__.py -------------------------------------------------------------------------------- /perception/gpt/gpt_configs/my_apikey: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/perception/gpt/gpt_configs/my_apikey -------------------------------------------------------------------------------- /perception/gpt/gpt_configs/physics/user.txt: -------------------------------------------------------------------------------- 1 | You will be given an image and a binary mask specifying an object on the image, analyze and provide your final answer of the object physical property. The query object will be enclosed in white mask. The physical property includes the mass, the friction and elasticity. The mass is in grams. The friction uses the Coulomb friction model, a value of 0.0 is frictionless. The elasticity value of 0.0 gives no bounce, while a value of 1.0 will give a perfect bounce. 2 | 3 | Format Requirement: 4 | You must provide your answer in the following JSON format, as it will be parsed by a code script later. Your answer must look like: 5 | { 6 | "mass": number, 7 | "friction": number, 8 | "elasticity": number 9 | } 10 | The answer should be one exact number for each property, do not include any other text in your answer, as it will be parsed by a code script later. -------------------------------------------------------------------------------- /perception/gpt/gpt_configs/ram/user.txt: -------------------------------------------------------------------------------- 1 | Describe all unique object categories in the given image, ensuring all pixels are included and assigned to one of the categories, do not miss any movable or static object appeared in the image, each category name is a single word and in singular noun format, do not include '-' in the name. Different categories should not be repeated or overlapped with each other in the image. For each category, judge if the instances in the image is movable, the answer is True or False. If there are multiple instances of the same category in the image, the judgement is True only if the object category satisfies the following requirements: 1. The object category is things (objects with a well-defined shape, e.g. car, person) and not stuff (amorphous background regions, e.g. grass, sky, largest segmentation component). 2. All instances in the image of this category are movable with complete shape and fully-visible. 
2 | 3 | Format Requirement: 4 | You must provide your answer in the following JSON format, as it will be parsed by a code script later. Your answer must look like: 5 | { 6 | "category-1": False, 7 | "category-2": True 8 | 9 | } 10 | Do not include any other text in your answer. Do not include unnecessary words besides the category name and True/False values. -------------------------------------------------------------------------------- /perception/gpt/gpt_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from re import DOTALL, finditer 4 | import base64 5 | from io import BytesIO 6 | from PIL import Image 7 | from imantics import Mask 8 | 9 | 10 | def fit_polygon_from_mask(mask): 11 | # return polygon vertices 12 | polygons = Mask(mask).polygons() 13 | if len(polygons.points) > 1: # find the largest polygon 14 | areas = [] 15 | for points in polygons.points: 16 | points = points.reshape((-1, 1, 2)).astype(np.int32) 17 | img = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) 18 | img = cv2.fillPoly(img, [points], color=[0, 255, 0]) 19 | mask = img[:, :, 1] > 0 20 | area = np.count_nonzero(mask) 21 | areas.append(area) 22 | areas = np.array(areas) 23 | largest_idx = np.argmax(areas) 24 | points = polygons.points[largest_idx] 25 | else: 26 | points = polygons.points[0] 27 | return points 28 | 29 | 30 | # same function in simulation/sim_utils.py 31 | def fit_circle_from_mask(mask_image): 32 | contours, _ = cv2.findContours(mask_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 33 | 34 | if len(contours) == 0: 35 | print("No contours found in the mask image.") 36 | return None 37 | max_contour = max(contours, key=cv2.contourArea) 38 | (x, y), radius = cv2.minEnclosingCircle(max_contour) 39 | 40 | center = (x, y) 41 | return center, radius 42 | 43 | 44 | def is_mask_truncated(mask): 45 | if np.any(mask[0, :] == 1) or np.any(mask[-1, :] == 1): # Top or bottom rows 46 | return True 47 | if np.any(mask[:, 0] == 1) or np.any(mask[:, -1] == 1): # Left or right columns 48 | return True 49 | return False 50 | 51 | 52 | def compute_iou(mask1, mask2): 53 | intersection = np.logical_and(mask1, mask2).sum() 54 | union = np.logical_or(mask1, mask2).sum() 55 | iou = intersection / union if union > 0 else 0.0 56 | 57 | return iou 58 | 59 | 60 | def find_json_response(full_response): 61 | extracted_responses = list( 62 | finditer(r"({[^}]*$|{.*})", full_response, flags=DOTALL) 63 | ) 64 | 65 | if not extracted_responses: 66 | print( 67 | f"Unable to find any responses of the matching type dictionary: `{full_response}`" 68 | ) 69 | return None 70 | 71 | if len(extracted_responses) > 1: 72 | print("Unexpected response > 1, continuing anyway...", extracted_responses) 73 | 74 | extracted_response = extracted_responses[0] 75 | extracted_str = extracted_response.group(0) 76 | return extracted_str 77 | 78 | 79 | def encode_image(image_path): 80 | with open(image_path, "rb") as image_file: 81 | return base64.b64encode(image_file.read()).decode("utf-8") -------------------------------------------------------------------------------- /perception/gpt_physic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ast 3 | from time import sleep 4 | import cv2 5 | import numpy as np 6 | import pymunk 7 | import openai 8 | from ruamel.yaml import YAML 9 | from ruamel.yaml.comments import CommentedMap 10 | 11 | from gpt.gpt_utils import find_json_response, encode_image, 
fit_polygon_from_mask, fit_circle_from_mask, compute_iou, is_mask_truncated 12 | 13 | 14 | class GPTV_physics: 15 | def __init__(self, query_prompt="gpt/gpt_configs/physics/user.txt", retry_limit=0): 16 | with open(query_prompt, "r") as file: 17 | self.query = file.read().strip() 18 | self.retry_limit = retry_limit 19 | 20 | def call(self, image_path, mask, max_tokens=300, tmp_dir="./"): 21 | mask_path = os.path.join(tmp_dir, "tmp_msk.png") 22 | mask = (mask * 255).astype(np.uint8) 23 | cv2.imwrite(mask_path, mask) 24 | query_image = encode_image(image_path) 25 | mask_image = encode_image(mask_path) 26 | 27 | try_count = 0 28 | while True: 29 | response = openai.ChatCompletion.create( 30 | model="gpt-4-vision-preview", 31 | messages=[{ 32 | "role": "system", 33 | "content": self.query 34 | }, 35 | { 36 | "role": "user", 37 | "content": [ 38 | { 39 | "type": "image_url", 40 | "image_url": { 41 | "url": f"data:image/{image_path[-3:]};base64,{query_image}" 42 | } 43 | }, 44 | { 45 | "type": "image_url", 46 | "image_url": { 47 | "url": f"data:image/{mask_path[-3:]};base64,{mask_image}" 48 | } 49 | }, 50 | ] 51 | } 52 | ], 53 | seed=100, 54 | max_tokens=max_tokens, 55 | ) 56 | response = response["choices"][0]["message"]["content"] 57 | try: 58 | result = find_json_response(response) 59 | result = ast.literal_eval(result.replace(' ', '').replace('\n', '')) 60 | break 61 | except: 62 | print(f"Unknown response: {response}") 63 | try_count += 1 64 | if try_count > self.retry_limit: 65 | raise ValueError(f"Over Limit: Unknown response: {response}") 66 | else: 67 | print("Retrying after 1s.") 68 | sleep(1) 69 | os.remove(mask_path) 70 | return result 71 | 72 | 73 | if __name__ == "__main__": 74 | import argparse 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("--input", type=str, default="../data/pig_ball") 77 | parser.add_argument("--output", type=str, default="../outputs/pig_ball") 78 | parser.add_argument("--apikey_path", type=str, default="gpt/gpt_configs/my_apikey") 79 | args = parser.parse_args() 80 | 81 | with open(args.apikey_path, "r") as file: 82 | apikey = file.read().strip() 83 | 84 | openai.api_key = apikey 85 | gpt = GPTV_physics(query_prompt="gpt/gpt_configs/physics/user.txt") 86 | 87 | image_path = os.path.join(args.input, "original.png") 88 | mask_path = os.path.join(args.input, "mask.png") 89 | save_dir = args.output 90 | os.makedirs(save_dir, exist_ok=True) 91 | save_path = os.path.join(save_dir, "physics.yaml") 92 | 93 | seg_mask = cv2.imread(mask_path, 0) 94 | seg_ids = np.unique(seg_mask) 95 | obj_info_list = {} 96 | for seg_id in seg_ids: 97 | if seg_id == 0: 98 | continue 99 | obj_info = {} 100 | mask = (seg_mask == seg_id) 101 | # fit primitive 102 | center, radius = fit_circle_from_mask((mask * 255).astype(np.uint8)) 103 | center = tuple(map(int, center)) 104 | radius = int(radius) 105 | pred_mask = np.zeros(mask.shape, dtype=np.uint8) 106 | cv2.circle(pred_mask, center, radius, (255), thickness=-1) 107 | area = np.count_nonzero(mask) 108 | pred_mask = pred_mask > 0 109 | iou = compute_iou(mask, pred_mask) 110 | if iou > 0.85: 111 | obj_info["primitive"] = 'circle' 112 | else: 113 | if is_mask_truncated(mask): 114 | continue 115 | obj_info["primitive"] = 'polygon' 116 | points = fit_polygon_from_mask(mask) 117 | points = tuple(map(tuple, points)) 118 | polygon = pymunk.Poly(None, points) 119 | area = polygon.area 120 | 121 | result = gpt.call(image_path, mask) 122 | for key in result: 123 | if key == "mass": 124 | if obj_info['primitive'] == 'polygon': 125 
| density = result["mass"] / area 126 | obj_info["mass"] = None 127 | obj_info["density"] = density 128 | else: 129 | obj_info["mass"] = result["mass"] 130 | obj_info["density"] = None 131 | else: 132 | obj_info[key] = result[key] 133 | 134 | obj_info_list[int(seg_id)] = obj_info 135 | 136 | yaml = YAML() 137 | yaml_data = CommentedMap() 138 | 139 | yaml_data['obj_info'] = obj_info_list 140 | yaml_data.yaml_set_comment_before_after_key('obj_info', before="physics properties of each object") 141 | 142 | with open(save_path, 'w') as yaml_file: 143 | yaml.dump(yaml_data, yaml_file) 144 | 145 | print(f"GPT-4V Physics reasoning results saved to {save_path}") -------------------------------------------------------------------------------- /perception/gpt_ram.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | import ast 4 | from time import sleep 5 | import json 6 | from gpt.gpt_utils import find_json_response, encode_image 7 | 8 | 9 | class GPTV_ram: 10 | def __init__(self, query_prompt="gpt/gpt_configs/movable/user.txt", retry_limit=3): 11 | 12 | with open(query_prompt, "r") as file: 13 | self.query = file.read().strip() 14 | self.retry_limit = retry_limit 15 | 16 | 17 | def call(self, image_path, max_tokens=300): 18 | query_image = encode_image(image_path) 19 | try_count = 0 20 | while True: 21 | response = openai.ChatCompletion.create( 22 | model="gpt-4-vision-preview", 23 | messages=[{ 24 | "role": "system", 25 | "content": self.query 26 | }, 27 | { 28 | "role": "user", 29 | "content": [ 30 | { 31 | "type": "text", 32 | "text": f"{self.query}" 33 | }, 34 | { 35 | "type": "image_url", 36 | "image_url": { 37 | "url": f"data:image/{image_path[-3:]};base64,{query_image}" 38 | } 39 | }, 40 | ] 41 | } 42 | ], 43 | seed=100, 44 | max_tokens=max_tokens, 45 | ) 46 | response = response["choices"][0]["message"]["content"] 47 | try: 48 | result = find_json_response(response) 49 | result = ast.literal_eval(result.replace(' ', '').replace('\n', '')) 50 | break 51 | except: 52 | try_count += 1 53 | if try_count > self.retry_limit: 54 | raise ValueError(f"Over Limit: Unknown response: {response}") 55 | else: 56 | print("Retrying after 1s.") 57 | sleep(1) 58 | return result 59 | 60 | if __name__ == "__main__": 61 | import argparse 62 | import inflect 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("--img_path", type=str, default="../data/domino/original.png") 65 | parser.add_argument("--save_path", type=str, default=None) 66 | parser.add_argument("--apikey_path", type=str, default="gpt/gpt_configs/my_apikey") 67 | args = parser.parse_args() 68 | 69 | with open(args.apikey_path, "r") as file: 70 | apikey = file.read().strip() 71 | 72 | openai.api_key = apikey 73 | gpt = GPTV_ram(query_prompt="gpt/gpt_configs/ram/user.txt") 74 | result = gpt.call(args.img_path) 75 | 76 | if args.save_path is None: 77 | save_dir = os.path.join(os.path.dirname(args.img_path), "intermediate") 78 | os.makedirs(save_dir, exist_ok=True) 79 | save_path = os.path.join(save_dir, "obj_movable.json") 80 | else: 81 | save_path = args.save_path 82 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 83 | 84 | p = inflect.engine() 85 | for obj_name in result: 86 | singular = p.singular_noun(obj_name) 87 | if singular: 88 | result[singular] = result.pop(obj_name) 89 | else: 90 | continue 91 | 92 | with open(save_path, "w") as file: 93 | json.dump(result, file) 94 | 95 | print("result:", result) 96 | print(f"GPT4V image movable objects results saved to 
{save_path}") 97 | -------------------------------------------------------------------------------- /perception/run_albedo_shading.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | from PIL import Image 7 | 8 | 9 | from intrinsic.pipeline import run_pipeline 10 | from intrinsic.model_util import load_models 11 | 12 | 13 | def load_image(path, bits=8): 14 | np_arr = np.array(Image.open(path)).astype(np.single) 15 | return np_arr / float((2 ** bits) - 1) 16 | 17 | 18 | def np_to_pil(img, bits=8): 19 | if bits == 8: 20 | int_img = (img * 255).astype(np.uint8) 21 | if bits == 16: 22 | int_img = (img * ((2 ** 16) - 1)).astype(np.uint16) 23 | 24 | return Image.fromarray(int_img) 25 | 26 | 27 | def view_scale(img, p=100): 28 | return (img / np.percentile(img, p)).clip(0, 1) 29 | 30 | 31 | def view(img, p=100): 32 | return view_scale(img ** (1/2.2), p=p) 33 | 34 | 35 | def uninvert(x, eps=0.001, clip=True): 36 | if clip: 37 | x = x.clip(eps, 1.0) 38 | 39 | out = (1.0 / x) - 1.0 40 | return out 41 | 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--input", type=str, default="../data/pig_ball", help="Input image or directory.") 46 | parser.add_argument("--output", type=str, default="../outputs/pig_ball", help="Output directory.") 47 | parser.add_argument("--vis", action="store_true", help="Visualize the results.") 48 | 49 | args = parser.parse_args() 50 | 51 | if os.path.isfile(args.input): 52 | image_path = args.input 53 | else: 54 | image_path = os.path.join(args.input, "original.png") 55 | output = args.output 56 | if output is None: 57 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input) 58 | os.makedirs(output, exist_ok=True) 59 | 60 | model = load_models('paper_weights') 61 | 62 | # Read input image 63 | img = load_image(image_path, bits=8) 64 | # run the model on the image using R_0 resizing 65 | 66 | results = run_pipeline(model, img, resize_conf=0.0, maintain_size=True) 67 | 68 | albedo = results['albedo'] 69 | inv_shd = results['inv_shading'] 70 | 71 | shd = uninvert(inv_shd) 72 | shd_save_path = os.path.join(output, "shading.npy") 73 | np.save(shd_save_path, shd) 74 | 75 | if args.vis: 76 | intermediate_dir = os.path.join(output, "intermediate") 77 | os.makedirs(intermediate_dir, exist_ok=True) 78 | 79 | alb_save_path = os.path.join(intermediate_dir, "albedo.npy") 80 | np.save(alb_save_path, albedo) 81 | np_to_pil(albedo).save(os.path.join(intermediate_dir, 'albedo_vis.png')) 82 | np_to_pil(view(shd)).save(os.path.join(intermediate_dir, 'shading_vis.png')) 83 | 84 | 85 | -------------------------------------------------------------------------------- /perception/run_depth_normal.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import logging 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from tqdm.auto import tqdm 9 | 10 | 11 | BASE_DIR = "GeoWizard/geowizard" 12 | sys.path.append(os.path.join(BASE_DIR)) 13 | 14 | from models.geowizard_pipeline import DepthNormalEstimationPipeline 15 | from utils.seed_all import seed_all 16 | 17 | 18 | if __name__=="__main__": 19 | 20 | logging.basicConfig(level=logging.INFO) 21 | 22 | parser = argparse.ArgumentParser( 23 | description="Run MonoDepthNormal Estimation using Stable Diffusion." 
24 | ) 25 | parser.add_argument( 26 | "--pretrained_model_path", 27 | type=str, 28 | default='lemonaddie/geowizard', 29 | help="pretrained model path from hugging face or local dir", 30 | ) 31 | 32 | parser.add_argument( 33 | "--domain", 34 | type=str, 35 | default='indoor', 36 | help="domain prediction", 37 | ) 38 | 39 | # inference setting 40 | parser.add_argument( 41 | "--denoise_steps", 42 | type=int, 43 | default=10, 44 | help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed.", 45 | ) 46 | parser.add_argument( 47 | "--ensemble_size", 48 | type=int, 49 | default=10, 50 | help="Number of predictions to be ensembled, more inference gives better results but runs slower.", 51 | ) 52 | parser.add_argument( 53 | "--half_precision", 54 | action="store_true", 55 | help="Run with half-precision (16-bit float), might lead to suboptimal result.", 56 | ) 57 | 58 | # resolution setting 59 | parser.add_argument( 60 | "--processing_res", 61 | type=int, 62 | default=768, 63 | help="Maximum resolution of processing. 0 for using input image resolution. Default: 768.", 64 | ) 65 | parser.add_argument( 66 | "--output_processing_res", 67 | action="store_true", 68 | help="When input is resized, out put depth at resized operating resolution. Default: False.", 69 | ) 70 | 71 | # depth map colormap 72 | parser.add_argument( 73 | "--color_map", 74 | type=str, 75 | default="Spectral", 76 | help="Colormap used to render depth predictions.", 77 | ) 78 | # other settings 79 | parser.add_argument("--seed", type=int, default=None, help="Random seed.") 80 | parser.add_argument( 81 | "--batch_size", 82 | type=int, 83 | default=0, 84 | help="Inference batch size. Default: 0 (will be set automatically).", 85 | ) 86 | 87 | # custom settings 88 | parser.add_argument("--input", type=str, default="../data/pig_ball", help="Input image or directory.") 89 | parser.add_argument("--output", type=str, default="../outputs/pig_ball", help="Output directory") 90 | parser.add_argument("--vis", action="store_true", help="Visualize the output.") 91 | 92 | args = parser.parse_args() 93 | 94 | checkpoint_path = args.pretrained_model_path 95 | denoise_steps = args.denoise_steps 96 | ensemble_size = args.ensemble_size 97 | 98 | if ensemble_size>15: 99 | logging.warning("long ensemble steps, low speed..") 100 | 101 | half_precision = args.half_precision 102 | 103 | processing_res = args.processing_res 104 | match_input_res = not args.output_processing_res 105 | domain = args.domain 106 | 107 | color_map = args.color_map 108 | seed = args.seed 109 | batch_size = args.batch_size 110 | 111 | if batch_size==0: 112 | batch_size = 1 # set default batchsize 113 | 114 | # -------------------- Preparation -------------------- 115 | # Random seed 116 | if seed is None: 117 | import time 118 | seed = int(time.time()) 119 | seed_all(seed) 120 | 121 | # -------------------- Device -------------------- 122 | if torch.cuda.is_available(): 123 | device = torch.device("cuda") 124 | else: 125 | device = torch.device("cpu") 126 | logging.warning("CUDA is not available. 
Running on CPU will be slow.") 127 | logging.info(f"device = {device}") 128 | 129 | # -------------------- Data -------------------- 130 | if os.path.isfile(args.input): 131 | image_path = args.input 132 | else: 133 | image_path = os.path.join(args.input, "original.png") 134 | output = args.output 135 | if output is None: 136 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input) 137 | os.makedirs(output, exist_ok=True) 138 | 139 | # -------------------- Model -------------------- 140 | if half_precision: 141 | dtype = torch.float16 142 | logging.info(f"Running with half precision ({dtype}).") 143 | else: 144 | dtype = torch.float32 145 | 146 | # declare a pipeline 147 | pipe = DepthNormalEstimationPipeline.from_pretrained(checkpoint_path, torch_dtype=dtype) 148 | logging.info("loading pipeline whole successfully.") 149 | 150 | try: 151 | pipe.enable_xformers_memory_efficient_attention() 152 | except: 153 | pass # run without xformers 154 | 155 | pipe = pipe.to(device) 156 | 157 | # -------------------- Inference and saving -------------------- 158 | with torch.no_grad(): 159 | 160 | # Read input image 161 | input_image = Image.open(image_path) 162 | 163 | # predict the depth here 164 | pipe_out = pipe(input_image, 165 | denoising_steps = denoise_steps, 166 | ensemble_size= ensemble_size, 167 | processing_res = processing_res, 168 | match_input_res = match_input_res, 169 | domain = domain, 170 | color_map = color_map, 171 | show_progress_bar = True, 172 | ) 173 | 174 | depth_pred: np.ndarray = pipe_out.depth_np 175 | depth_colored: Image.Image = pipe_out.depth_colored 176 | normal_pred: np.ndarray = pipe_out.normal_np 177 | normal_colored: Image.Image = pipe_out.normal_colored 178 | 179 | # Save as npy 180 | depth_npy_save_path = os.path.join(output, f"depth.npy") 181 | if os.path.exists(depth_npy_save_path): 182 | logging.warning(f"Existing file: '{depth_npy_save_path}' will be overwritten") 183 | np.save(depth_npy_save_path, depth_pred) 184 | 185 | normal_npy_save_path = os.path.join(output, f"normal.npy") 186 | if os.path.exists(normal_npy_save_path): 187 | logging.warning(f"Existing file: '{normal_npy_save_path}' will be overwritten") 188 | np.save(normal_npy_save_path, normal_pred) 189 | 190 | # Colorize 191 | if args.vis: 192 | intermediate_dir = os.path.join(output, "intermediate") 193 | os.makedirs(output, exist_ok=True) 194 | os.makedirs(intermediate_dir, exist_ok=True) 195 | 196 | depth_colored_save_path = os.path.join(intermediate_dir, f"depth_vis.png") 197 | if os.path.exists(depth_colored_save_path): 198 | logging.warning( 199 | f"Existing file: '{depth_colored_save_path}' will be overwritten" 200 | ) 201 | depth_colored.save(depth_colored_save_path) 202 | 203 | normal_colored_save_path = os.path.join(intermediate_dir, f"normal_vis.png") 204 | if os.path.exists(normal_colored_save_path): 205 | logging.warning( 206 | f"Existing file: '{normal_colored_save_path}' will be overwritten" 207 | ) 208 | normal_colored.save(normal_colored_save_path) 209 | print("Done.") -------------------------------------------------------------------------------- /perception/run_fg_bg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import os 5 | import json 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | def vis_seg(mask_img, text=False): 11 | norm = matplotlib.colors.Normalize(vmin=np.min(mask_img), vmax=np.max(mask_img)) 12 | norm_segmentation_map = 
norm(mask_img) 13 | cmap = "tab20" 14 | colormap = plt.get_cmap(cmap) 15 | colored_segmentation_map = colormap(norm_segmentation_map) 16 | colored_segmentation_map = (colored_segmentation_map[:, :, :3] * 255).astype(np.uint8) 17 | colored_segmentation_map[mask_img == 0] = [255, 255, 255] 18 | if text: 19 | unique_seg_ids = np.unique(mask_img) 20 | unique_seg_ids = unique_seg_ids[unique_seg_ids != 0] 21 | for seg_id in unique_seg_ids: 22 | mask_indices = np.where(mask_img == seg_id) 23 | if len(mask_indices[0]) > 0: 24 | center_y = int(np.mean(mask_indices[0])) 25 | center_x = int(np.mean(mask_indices[1])) 26 | cv2.putText( 27 | colored_segmentation_map, 28 | str(seg_id), 29 | (center_x, center_y), 30 | cv2.FONT_HERSHEY_SIMPLEX, 31 | 0.5, 32 | (0, 0, 0), 33 | 2, 34 | cv2.LINE_AA 35 | ) 36 | return colored_segmentation_map 37 | 38 | 39 | def refine_segmentation_mask(mask, seg_mask, value, min_size): 40 | """ 41 | Refine a segmentation mask by removing small connected components. 42 | """ 43 | num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask.astype(np.uint8), connectivity=8) 44 | 45 | # Filter out small connected components 46 | for label in range(1, num_labels): 47 | if stats[label, cv2.CC_STAT_AREA] >= min_size: 48 | seg_mask[labels == label] = value 49 | value = value + 1 50 | 51 | return seg_mask, value 52 | 53 | 54 | def find_neighbor_segmentation_classes(segmentation_map, class_id): 55 | """ 56 | Find the neighboring segmentation classes of a specified class in a segmentation map. 57 | """ 58 | class_mask = np.uint8(segmentation_map == class_id) 59 | dilated_mask = cv2.dilate(class_mask, np.ones((3, 3), np.uint8), iterations=3) 60 | neighbor_class_ids = np.unique(segmentation_map[dilated_mask > 0]) 61 | neighbor_class_ids = neighbor_class_ids[neighbor_class_ids != class_id] 62 | return list(neighbor_class_ids) 63 | 64 | 65 | 66 | def compute_angle_to_direction(normal, direction): 67 | """ 68 | Compute the angle (in degrees) between a surface normal and a direction vector 69 | """ 70 | dot_product = np.dot(normal, direction) 71 | norm_normal = np.linalg.norm(normal, axis=-1) 72 | norm_direction = np.linalg.norm(direction) 73 | cos_theta = dot_product / (norm_normal * norm_direction) 74 | return np.abs(cos_theta) 75 | 76 | 77 | def normal_planes(normals): 78 | ''' 79 | normal coordinate: 80 | y 81 | | 82 | o -- x 83 | / 84 | z 85 | ''' 86 | upward_direction = np.array([0, -1, 0]) 87 | downward_direction = np.array([0, 1, 0]) 88 | leftward_direction = np.array([-1, 0, 0]) 89 | rightward_direction = np.array([1, 0, 0]) 90 | forward_direction = np.array([0, 0, 1]) 91 | inward_direction = np.array([0, 0, -1]) 92 | 93 | # Compute the angles between each pixel's normal and the 6 directions 94 | angle_leftward = compute_angle_to_direction(normals, leftward_direction) 95 | angle_rightward = compute_angle_to_direction(normals, rightward_direction) 96 | angle_X = np.maximum(angle_leftward, angle_rightward).mean() 97 | 98 | angle_upward = compute_angle_to_direction(normals, upward_direction) 99 | angle_downward = compute_angle_to_direction(normals, downward_direction) 100 | angle_Y = np.maximum(angle_upward, angle_downward).mean() 101 | 102 | angle_forward = compute_angle_to_direction(normals, forward_direction) 103 | angle_inward = compute_angle_to_direction(normals, inward_direction) 104 | angle_Z = np.maximum(angle_forward, angle_inward).mean() 105 | 106 | angles = np.array([angle_X, angle_Y, angle_Z]) 107 | index = np.argmax(angles) 108 | if index == 0: 109 | return "X" 110 | 
elif index == 1: 111 | return "Y" 112 | else: 113 | return "Z" 114 | 115 | 116 | def find_upper_contour(binary_mask): 117 | binary_mask = binary_mask.astype(np.uint8) 118 | contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 119 | contour = contours[0] 120 | leftmost = tuple(contour[contour[:,:,0].argmin()][0]) 121 | rightmost = tuple(contour[contour[:,:,0].argmax()][0]) 122 | 123 | start_index = np.where((contour[:, 0] == leftmost).all(axis=1))[0][0] 124 | end_index = np.where((contour[:, 0] == rightmost).all(axis=1))[0][0] 125 | 126 | if start_index > end_index: 127 | start_index, end_index = end_index, start_index 128 | 129 | subarray1 = contour[start_index:end_index+1] 130 | subarray2 = np.concatenate((contour[end_index:], contour[:start_index+1])) 131 | if subarray1[:, 0, 1].mean() < subarray2[:, 0, 1].mean(): 132 | upper_contour = subarray1 133 | else: 134 | upper_contour = subarray2 135 | 136 | start_x, end_x = upper_contour[:, 0, 0].min(), upper_contour[:, 0, 0].max() 137 | start_x, end_x = int(start_x), int(end_x) 138 | y = upper_contour[:, 0, 1].mean() 139 | y = int(y) 140 | edge = [(start_x, y), (end_x, y)] 141 | return edge 142 | 143 | 144 | def is_mask_truncated(mask): 145 | if np.any(mask[0, :] == 1) or np.any(mask[-1, :] == 1): 146 | return True 147 | if np.any(mask[:, 0] == 1) or np.any(mask[:, -1] == 1): 148 | return True 149 | return False 150 | 151 | def draw_edge_on_image(image, edges, color=(0, 255, 0), thickness=2): 152 | output_image = image.copy() 153 | for edge in edges: 154 | for i in range(1, len(edge)): 155 | start_point = (edge[i-1][0], edge[i-1][1]) 156 | end_point = (edge[i][0], edge[i][1]) 157 | cv2.line(output_image, start_point, end_point, color, thickness) 158 | 159 | return output_image 160 | 161 | 162 | 163 | if __name__ == '__main__': 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('--input', type=str, default="../data/pig_ball", help="input path") 166 | parser.add_argument("--output", type=str, default=None, help="output directory") 167 | parser.add_argument("--vis_edge", action="store_true", help="visualize edges") 168 | 169 | args = parser.parse_args() 170 | 171 | output = args.output 172 | if output is None: 173 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input) 174 | os.makedirs(output, exist_ok=True) 175 | 176 | seg_path = os.path.join(args.input, "intermediate", 'mask.png') 177 | seg_mask = cv2.imread(seg_path, 0) 178 | 179 | depth_path = os.path.join(args.input, 'depth.npy') 180 | depth = np.load(depth_path) 181 | 182 | normal_path = os.path.join(args.input, 'normal.npy') 183 | normal = np.load(normal_path) 184 | 185 | seg_info_path = os.path.join(args.input, "intermediate", 'mask.json') 186 | with open(seg_info_path, 'r') as f: 187 | seg_info = json.load(f) 188 | 189 | seg_ids = np.unique(seg_mask) 190 | 191 | obj_infos = seg_info 192 | for seg_id in seg_ids: 193 | 194 | mask = seg_mask == seg_id 195 | seg_depth = depth[mask] 196 | 197 | min_depth = np.percentile(seg_depth, 5) 198 | max_depth = np.percentile(seg_depth, 95) 199 | 200 | obj_infos[seg_id]["depth"] = [min_depth, max_depth] 201 | 202 | movable_seg_ids = [seg_id for seg_id in seg_ids if obj_infos[seg_id]['movable']] 203 | fg_seg_mask = np.zeros_like(seg_mask) 204 | fg_truncated_mask = np.zeros_like(seg_mask) 205 | edges = [] 206 | for seg_id in movable_seg_ids: 207 | mask = seg_mask == seg_id 208 | if mask.sum() < 100: # ignore small objects 209 | seg_mask[mask] = 0 210 | continue 211 | if 
is_mask_truncated(mask): 212 | fg_truncated_mask[mask] = seg_id 213 | else: 214 | fg_seg_mask[mask] = seg_id 215 | 216 | seg_mask_path = os.path.join(output, 'mask.png') # save the foreground mask as final mask 217 | fg_mask = fg_seg_mask + fg_truncated_mask 218 | cv2.imwrite(seg_mask_path, fg_mask) 219 | 220 | vis_seg_mask = vis_seg(fg_seg_mask, text=True) 221 | vis_save_path = os.path.join(output, "intermediate", 'fg_mask_vis.png') 222 | cv2.imwrite(vis_save_path, vis_seg_mask) 223 | 224 | for seg_id in np.unique(fg_truncated_mask): 225 | if seg_id == 0: 226 | continue 227 | mask = fg_truncated_mask == seg_id 228 | points = cv2.findNonZero((mask> 0).astype(np.uint8)) 229 | 230 | x, y, w, h = cv2.boundingRect(points) 231 | 232 | top_edge = [[x, y], [x + w, y]] 233 | left_edge = [[x, y], [x, y + h]] 234 | right_edge = [[x + w, y], [x + w, y + h]] 235 | edges.extend([right_edge, left_edge, top_edge]) 236 | 237 | bg_seg_mask = np.zeros_like(seg_mask) 238 | value = 1 239 | nonmovable_seg_ids = [seg_id for seg_id in seg_ids if not obj_infos[seg_id]['movable']] 240 | for seg_id in nonmovable_seg_ids: 241 | seg_area_ratio = np.sum(seg_mask == seg_id) / seg_mask.size 242 | if seg_id == 0 or seg_info[seg_id]['label'] == "background" or seg_area_ratio > 0.5: 243 | depth_threshold = np.array([obj_infos[seg_id]["depth"][0] for seg_id in movable_seg_ids]).min() 244 | mask = np.logical_and(seg_mask == seg_id, depth <= depth_threshold) 245 | min_size = int(500 / (512 * 512) * mask.size) 246 | bg_seg_mask, value = refine_segmentation_mask(mask, bg_seg_mask, value, min_size=min_size) 247 | 248 | else: 249 | neighbor_seg_ids = find_neighbor_segmentation_classes(seg_mask, seg_id) 250 | neighbor_seg_ids = [seg_id for seg_id in neighbor_seg_ids if seg_id in movable_seg_ids] 251 | if len(neighbor_seg_ids) == 0: 252 | depth_array = np.array([obj_infos[seg_id]["depth"][1] for seg_id in movable_seg_ids]) 253 | else: 254 | depth_array = np.array([obj_infos[seg_id]["depth"][1] for seg_id in neighbor_seg_ids]) 255 | 256 | min_threshold = depth_array.min() 257 | max_threshold = depth_array.max() 258 | 259 | min_size = int(500 / (512 * 512) * mask.size) 260 | mask = np.logical_and(seg_mask == seg_id, depth <= min_threshold) 261 | old_value = value 262 | bg_seg_mask, value = refine_segmentation_mask(mask, bg_seg_mask, value, min_size=min_size) 263 | 264 | if value == old_value: 265 | mask = np.logical_and(seg_mask == seg_id, min_threshold < depth) 266 | mask = np.logical_and(mask, depth <= max_threshold) 267 | bg_seg_mask, value = refine_segmentation_mask(mask, bg_seg_mask, value, min_size=min_size) 268 | 269 | seg_save_path = os.path.join(output, 'intermediate', 'bg_mask.png') 270 | cv2.imwrite(seg_save_path, bg_seg_mask) 271 | 272 | vis_seg_mask = vis_seg(bg_seg_mask) 273 | vis_seg_path = os.path.join(output, 'intermediate', 'bg_mask_vis.png') 274 | cv2.imwrite(vis_seg_path, vis_seg_mask) 275 | 276 | seg_ids = np.unique(bg_seg_mask) 277 | edge_map = np.zeros_like(bg_seg_mask) 278 | for seg_id in seg_ids: 279 | if seg_id == 0: 280 | continue 281 | else: 282 | mask = (bg_seg_mask == seg_id) 283 | normals = normal[mask] 284 | axis = normal_planes(normals) 285 | if axis == "X" or axis == "Z": 286 | points = cv2.findNonZero((mask> 0).astype(np.uint8)) 287 | 288 | if points is not None: 289 | x, y, w, h = cv2.boundingRect(points) 290 | 291 | if axis == "X": 292 | right_col = x + w - 1 293 | edge = [[right_col, y], [right_col, y + h]] 294 | edges.append(edge) 295 | else: # Z 296 | top_edge = [[x, y], [x + w, y]] 297 | 
left_edge = [[x, y], [x, y + h]] 298 | right_edge = [[x + w, y], [x + w, y + h]] 299 | edges.extend([right_edge, left_edge, top_edge]) 300 | 301 | elif axis == "Y": 302 | edge= find_upper_contour(mask) 303 | edges.append(edge) 304 | if args.vis_edge: 305 | img_path = os.path.join(args.input, 'original.png') 306 | img = cv2.imread(img_path) 307 | vis_edge_mask = draw_edge_on_image(img, edges, color=(0, 0, 255)) 308 | vis_edge_save_path = os.path.join(output, 'intermediate', 'edge_vis.png') 309 | cv2.imwrite(vis_edge_save_path, vis_edge_mask) 310 | 311 | with open(os.path.join(output, 'edges.json'), 'w') as f: 312 | json.dump(edges, f) 313 | print("Done!") 314 | 315 | 316 | 317 | 318 | -------------------------------------------------------------------------------- /perception/run_gsam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import json 5 | import torch 6 | import torchvision 7 | from PIL import Image 8 | import cv2 9 | import numpy as np 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | BASE_DIR = "Grounded-Segment-Anything" 15 | sys.path.append(os.path.join(BASE_DIR)) 16 | 17 | # Grounding DINO 18 | import GroundingDINO.groundingdino.datasets.transforms as T 19 | from GroundingDINO.groundingdino.models import build_model 20 | from GroundingDINO.groundingdino.util.slconfig import SLConfig 21 | from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap 22 | 23 | 24 | # segment anything 25 | from segment_anything import ( 26 | sam_model_registry, 27 | sam_hq_model_registry, 28 | SamPredictor 29 | ) 30 | 31 | 32 | def load_image(image_path, return_transform=False): 33 | image_pil = Image.open(image_path).convert("RGB") 34 | 35 | transform = T.Compose( 36 | [ 37 | T.RandomResize([800], max_size=1333), 38 | T.ToTensor(), 39 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 40 | ] 41 | ) 42 | image, _ = transform(image_pil, None) # 3, h, w 43 | if not return_transform: 44 | return image_pil, image 45 | else: 46 | return image_pil, image, transform 47 | 48 | 49 | def load_model(model_config_path, model_checkpoint_path, device): 50 | args = SLConfig.fromfile(model_config_path) 51 | args.device = device 52 | model = build_model(args) 53 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu") 54 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 55 | print(load_res) 56 | _ = model.eval() 57 | return model 58 | 59 | 60 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"): 61 | caption = caption.lower() 62 | caption = caption.strip() 63 | if not caption.endswith("."): 64 | caption = caption + "." 
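    # GroundingDINO expects a lowercase caption that ends with "."; the caption built in
    # __main__ below joins the obj_movable.json category names with ". ".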
65 | model = model.to(device) 66 | image = image.to(device) 67 | with torch.no_grad(): 68 | outputs = model(image[None], captions=[caption]) 69 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256) 70 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4) 71 | logits.shape[0] 72 | 73 | # filter output 74 | logits_filt = logits.clone() 75 | boxes_filt = boxes.clone() 76 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold 77 | logits_filt = logits_filt[filt_mask] # num_filt, 256 78 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4 79 | logits_filt.shape[0] 80 | 81 | # get phrase 82 | tokenlizer = model.tokenizer 83 | tokenized = tokenlizer(caption) 84 | # build pred 85 | pred_phrases = [] 86 | scores = [] 87 | for logit, box in zip(logits_filt, boxes_filt): 88 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer) 89 | if with_logits: 90 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())})") 91 | else: 92 | pred_phrases.append(pred_phrase) 93 | scores.append(logit.max().item()) 94 | 95 | return boxes_filt, torch.Tensor(scores), pred_phrases 96 | 97 | 98 | def save_mask_data(output_dir, mask_list, label_list, movable_dict): 99 | value = 0 # 0 for background 100 | mask_img = torch.zeros(mask_list.shape[-2:]) 101 | for idx, mask in enumerate(mask_list): 102 | mask_img[mask.cpu().numpy() == True] = value + idx + 1 103 | 104 | mask_img = mask_img.numpy().astype(np.uint8) 105 | Image.fromarray(mask_img).save(os.path.join(output_dir, "mask.png")) 106 | print("number of classes: ", len(np.unique(mask_img))) 107 | 108 | norm = matplotlib.colors.Normalize(vmin=np.min(mask_img), vmax=np.max(mask_img)) 109 | norm_segmentation_map = norm(mask_img) 110 | cmap = "tab20" 111 | colormap = plt.get_cmap(cmap) 112 | colored_segmentation_map = colormap(norm_segmentation_map) 113 | colored_segmentation_map = (colored_segmentation_map[:, :, :3] * 255).astype(np.uint8) 114 | colored_segmentation_map[mask_img == 0] = [255, 255, 255] 115 | 116 | cv2.imwrite(os.path.join(output_dir, "vis_mask.jpg"), colored_segmentation_map[..., ::-1]) 117 | print("seg visualization saved to ", os.path.join(output_dir, "vis_mask.jpg")) 118 | 119 | json_data = [{ 120 | 'value': value, 121 | 'label': 'background', 122 | 'movable': False 123 | }] 124 | for label in label_list: 125 | value += 1 126 | movable = movable_dict[label] if (label in movable_dict and movable_dict[label]) else False 127 | json_data.append({ 128 | 'value': value, 129 | 'label': label, 130 | 'movable': movable 131 | }) 132 | with open(os.path.join(output_dir, 'mask.json'), 'w') as f: 133 | json.dump(json_data, f) 134 | 135 | 136 | def masks_postprocess(masks, phrases, movable_dict): 137 | new_masks_list = [] 138 | new_phrases_list = [] 139 | for idx, (label, mask) in enumerate(zip(phrases, masks)): 140 | mask = mask.numpy() 141 | 142 | # post-processing label 143 | if (label != "background") or label not in movable_dict: 144 | found_keys = [key for key in movable_dict if key in label] 145 | if found_keys: 146 | phrases[idx] = found_keys[0] 147 | label = found_keys[0] 148 | 149 | if label == "background": # background 150 | continue 151 | 152 | if label in movable_dict: 153 | if movable_dict[label]: # movable object 154 | contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 155 | if len(contours) == 1: 156 | new_masks_list.append(mask) 157 | new_phrases_list.append(label) 158 | else: 159 | for contour in contours: 160 | component_mask = np.zeros_like(mask, 
dtype=np.uint8) 161 | cv2.drawContours(component_mask, [contour], -1, 1, thickness=cv2.FILLED) 162 | new_masks_list.append(component_mask) 163 | new_phrases_list.append(label) 164 | else: 165 | new_masks_list.append(mask) 166 | new_phrases_list.append(label) 167 | else: # merge into background 168 | continue 169 | 170 | masks = np.stack(new_masks_list, axis=0) 171 | masks = torch.from_numpy(masks) 172 | assert len(masks) == len(new_phrases_list) 173 | return masks, new_phrases_list 174 | 175 | 176 | if __name__ == "__main__": 177 | parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True) 178 | parser.add_argument("--config", type=str, default="Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", help="path to config file") 179 | parser.add_argument("--grounded_checkpoint", type=str, default="Grounded-Segment-Anything/groundingdino_swint_ogc.pth", help="path to checkpoint file" 180 | ) 181 | parser.add_argument("--sam_version", type=str, default="vit_h", required=False, help="SAM ViT version: vit_b / vit_l / vit_h" 182 | ) 183 | parser.add_argument( 184 | "--sam_checkpoint", type=str, default="Grounded-Segment-Anything/sam_vit_h_4b8939.pth", help="path to sam checkpoint file" 185 | ) 186 | parser.add_argument( 187 | "--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file" 188 | ) 189 | parser.add_argument( 190 | "--use_sam_hq", action="store_true", help="using sam-hq for prediction" 191 | ) 192 | 193 | parser.add_argument("--device", type=str, default="cuda", help="running on cpu only!, default=False") 194 | 195 | # custom setting 196 | parser.add_argument("--box_threshold", type=float, default=0.2, help="box threshold") 197 | parser.add_argument("--text_threshold", type=float, default=0.2, help="text threshold") 198 | parser.add_argument("--iou_threshold", type=float, default=0.15, help="nms threshold") 199 | parser.add_argument('--disable_nms', action='store_true', help='Disable nms') 200 | 201 | parser.add_argument("--input", type=str, default="../data/pig_ball", help="input path") 202 | parser.add_argument("--output", type=str, default=None, help="output directory") 203 | parser.add_argument("--prompts_path", type=str, default=None, help="prompt file under same path as input") 204 | 205 | args = parser.parse_args() 206 | 207 | # cfg 208 | config_file = args.config # change the path of the model config file 209 | grounded_checkpoint = args.grounded_checkpoint # change the path of the model 210 | sam_version = args.sam_version 211 | sam_checkpoint = args.sam_checkpoint 212 | sam_hq_checkpoint = args.sam_hq_checkpoint 213 | use_sam_hq = args.use_sam_hq 214 | box_threshold = args.box_threshold 215 | text_threshold = args.text_threshold 216 | iou_threshold = args.iou_threshold 217 | device = args.device 218 | 219 | if os.path.isfile(args.input): 220 | image_path = args.input 221 | else: 222 | image_path = os.path.join(args.input, "original.png") 223 | output = args.output 224 | if output is None: 225 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input) 226 | os.makedirs(output, exist_ok=True) 227 | 228 | if args.prompts_path is None: 229 | prompts_path = os.path.join(args.input, "intermediate", "obj_movable.json") 230 | else: 231 | prompts_path = args.prompts_path 232 | 233 | with open(prompts_path, "r") as f: 234 | movable_dict = json.load(f) 235 | 236 | model = load_model(config_file, grounded_checkpoint, device=device) 237 | 238 | if use_sam_hq: 239 | predictor = 
SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device)) 240 | else: 241 | predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device)) 242 | 243 | image_pil, image_dino_input, transform = load_image(image_path, return_transform=True) 244 | 245 | text_prompt = "" 246 | for idx, tag in enumerate(movable_dict): 247 | 248 | if idx < len(movable_dict) - 1: 249 | text_prompt += tag + ". " 250 | else: 251 | text_prompt += tag 252 | 253 | print(f"Text Prompt input: {text_prompt}") 254 | 255 | # run grounding dino model 256 | boxes_filt, scores, pred_phrases = get_grounding_output( 257 | model, image_dino_input, text_prompt, box_threshold, text_threshold, with_logits=False, device=device) 258 | 259 | size = image_pil.size 260 | H, W = size[1], size[0] 261 | for i in range(boxes_filt.size(0)): 262 | boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) 263 | boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 264 | boxes_filt[i][2:] += boxes_filt[i][:2] 265 | 266 | boxes_filt = boxes_filt.cpu() # boxes_filt: xyxy 267 | 268 | # use NMS to handle overlapped boxes 269 | if not args.disable_nms: 270 | print(f"Before NMS: {boxes_filt.shape[0]} boxes") 271 | nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist() 272 | boxes_filt = boxes_filt[nms_idx] 273 | pred_phrases = [pred_phrases[idx] for idx in nms_idx] 274 | scores = scores[nms_idx] 275 | print(f"After NMS: {boxes_filt.shape[0]} boxes") 276 | 277 | image_input = np.array(image_pil) 278 | predictor.set_image(image_input) 279 | transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_input.shape[:2]).to(device) 280 | masks, _, _ = predictor.predict_torch( 281 | point_coords = None, 282 | point_labels = None, 283 | boxes = transformed_boxes.to(device), 284 | multimask_output = False, 285 | ) 286 | masks = masks.cpu().squeeze(1) # (N, H, W) 287 | 288 | 289 | image_input2 = image_input.copy() 290 | image_pil2 = Image.fromarray(image_input2) 291 | 292 | for idx, (label, box, mask) in enumerate(zip(pred_phrases, boxes_filt, masks)): 293 | mask = mask.numpy() 294 | if label == "background": # background 295 | continue 296 | 297 | if label in movable_dict and movable_dict[label]: # movable object 298 | contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 299 | if len(contours) == 1: 300 | image_input2[mask == True, :] = 0 # set as black 301 | 302 | 303 | image_pil2 = Image.fromarray(image_input2) 304 | image_dino_input2, _ = transform(image_pil2, None) 305 | 306 | boxes_filt2, scores2, pred_phrases2 = get_grounding_output( 307 | model, image_dino_input2, text_prompt, 0.8 * box_threshold, 0.8 * text_threshold, with_logits=False, device=device 308 | ) 309 | 310 | size = image_pil.size 311 | H, W = size[1], size[0] 312 | for i in range(boxes_filt2.size(0)): 313 | boxes_filt2[i] = boxes_filt2[i] * torch.Tensor([W, H, W, H]) 314 | boxes_filt2[i][:2] -= boxes_filt2[i][2:] / 2 315 | boxes_filt2[i][2:] += boxes_filt2[i][:2] 316 | 317 | boxes_filt2 = boxes_filt2.cpu() # boxes_filt: xyxy 318 | 319 | # use NMS to handle overlapped boxes 320 | if len(boxes_filt2) > 0: 321 | 322 | scores2 = 0.5 * scores2 # reduce the score of the second round prediction 323 | if not args.disable_nms: 324 | boxes_combine = torch.cat([boxes_filt, boxes_filt2], dim=0) 325 | phrases_combine = pred_phrases + pred_phrases2 326 | sc_combine = torch.cat([scores, scores2], dim=0) 327 | 328 | print(f"Before NMS: {boxes_filt2.shape[0]} boxes") 
329 | nms_idx = torchvision.ops.nms(boxes_combine, sc_combine, 1.25 * iou_threshold).numpy().tolist() 330 | nms_idx = [idx for idx in nms_idx if idx >= len(boxes_filt)] 331 | boxes_filt2 = boxes_combine[nms_idx] 332 | pred_phrases2 = [phrases_combine[idx] for idx in nms_idx] 333 | scores2 = sc_combine[nms_idx] 334 | print(f"After NMS: {boxes_filt2.shape[0]} boxes") 335 | 336 | if len(boxes_filt2) > 0: 337 | image_input2 = np.array(image_pil2) 338 | predictor.set_image(image_input2) 339 | transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt2, image_input2.shape[:2]).to(device) 340 | masks2, _, _ = predictor.predict_torch( 341 | point_coords = None, 342 | point_labels = None, 343 | boxes = transformed_boxes.to(device), 344 | multimask_output = False, 345 | ) 346 | masks2 = masks2.cpu().squeeze(1) 347 | 348 | pred_phrases = pred_phrases + pred_phrases2 349 | masks = torch.cat([masks, masks2], dim=0) 350 | masks, pred_phrases = masks_postprocess(masks, pred_phrases, movable_dict) 351 | else: 352 | masks, pred_phrases = masks_postprocess(masks, pred_phrases, movable_dict) 353 | else: 354 | masks, pred_phrases = masks_postprocess(masks, pred_phrases, movable_dict) 355 | 356 | intermediate_dir = os.path.join(output, "intermediate") # save intermediate results and visualization 357 | os.makedirs(intermediate_dir, exist_ok=True) 358 | save_mask_data(intermediate_dir, masks, pred_phrases, movable_dict) 359 | 360 | -------------------------------------------------------------------------------- /perception/run_inpaint.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import numpy as np 5 | import torch 6 | import cv2 7 | from PIL import Image 8 | 9 | BASE_DIR = "Inpaint-Anything" 10 | sys.path.append(os.path.join(BASE_DIR)) 11 | 12 | from lama_inpaint import inpaint_img_with_lama 13 | 14 | 15 | def load_img_to_array(img_p): 16 | img = Image.open(img_p) 17 | if img.mode == "RGBA": 18 | img = img.convert("RGB") 19 | return np.array(img) 20 | 21 | 22 | def save_array_to_img(img_arr, img_p): 23 | Image.fromarray(img_arr.astype(np.uint8)).save(img_p) 24 | 25 | 26 | def dilate_mask(mask, dilate_factor=20): 27 | mask = mask.astype(np.uint8) 28 | mask = cv2.dilate( 29 | mask, 30 | np.ones((dilate_factor, dilate_factor), np.uint8), 31 | iterations=1 32 | ) 33 | return mask 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--input', type=str, default="../data/pig_ball", help="input path") 39 | parser.add_argument("--output", type=str, default=None, help="output directory") 40 | parser.add_argument("--dilate_kernel_size", type=int, default=20, help="Dilate kernel size") 41 | parser.add_argument("--lama_config", type=str, default="Inpaint-Anything/lama/configs/prediction/default.yaml", 42 | help="The path to the config file of lama model.") 43 | parser.add_argument("--lama_ckpt", type=str, default="Inpaint-Anything/pretrained_models/big-lama", 44 | help="The path to the lama checkpoint.") 45 | 46 | args = parser.parse_args() 47 | 48 | if os.path.isfile(args.input): 49 | image_path = args.input 50 | else: 51 | image_path = os.path.join(args.input, "original.png") 52 | output = args.output 53 | if output is None: 54 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input) 55 | os.makedirs(output, exist_ok=True) 56 | 57 | device = "cuda" if torch.cuda.is_available() else "cpu" 58 | 59 | image_path = os.path.join(args.input, 
"original.png") 60 | img = load_img_to_array(image_path) 61 | seg_path = os.path.join(args.input, 'mask.png') 62 | seg_mask = cv2.imread(seg_path, 0) 63 | mask = ((seg_mask > 0)*255).astype(np.uint8) 64 | 65 | if args.dilate_kernel_size is not None: 66 | mask = dilate_mask(mask, args.dilate_kernel_size) 67 | 68 | img_inpainted = inpaint_img_with_lama(img, mask, args.lama_config, args.lama_ckpt, device=device) 69 | save_array_to_img(img_inpainted, os.path.join(output, "inpaint.png")) -------------------------------------------------------------------------------- /relight/relight.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import copy 5 | import torch.nn.functional as F 6 | from tqdm import tqdm 7 | import torch 8 | import kornia as K 9 | 10 | from relight_utils import get_light_coeffs, generate_shd, writing_video 11 | 12 | 13 | class Relight: 14 | def __init__(self, img, mask_img, nrm, shd): 15 | # nrm: +X right, +Y down, +Z from screen to me, between 0 and 1 16 | self.img = img # RGB between 0 and 1 17 | self.mask_img = mask_img 18 | self.nrm = nrm 19 | 20 | src_msks = [] 21 | seg_ids = np.unique(mask_img) 22 | seg_ids.sort() 23 | for seg_id in np.unique(mask_img): 24 | if seg_id == 0: 25 | continue 26 | src_msk = (mask_img == seg_id) 27 | src_msks.append(src_msk) 28 | src_msks = np.stack(src_msks, axis=0) # (N, H, W) 29 | self.src_msks = torch.from_numpy(src_msks) 30 | self.num_classes = src_msks.shape[0] 31 | 32 | self.coeffs = get_light_coeffs(shd, self.nrm, self.img) # lambertian fitting 33 | shd[self.mask_img > 0] = generate_shd(self.nrm, self.coeffs, self.mask_img > 0) 34 | self.alb = (self.img ** 2.2) / shd[:, :, None].clip(1e-4) 35 | self.shd = shd 36 | 37 | 38 | def relight(self, comp_img, target_obj_mask, trans): 39 | # target_obj_mask (H, W) segmentation map 40 | # trans: (N, 2, 3) 41 | binary_seg_mask = [] 42 | for seg_id in np.unique(self.mask_img): 43 | if seg_id == 0: 44 | continue 45 | seg_mask = target_obj_mask == seg_id 46 | binary_seg_mask.append(seg_mask) 47 | 48 | binary_seg_mask = np.stack(binary_seg_mask, axis=0) # (N, H, W) 49 | binary_seg_mask = torch.from_numpy(binary_seg_mask[:, None, :, :]) # (N, 1, H, W) 50 | 51 | assert binary_seg_mask.shape[0] == self.num_classes, "number of segmentation ids should be equal to the number of foreground objects" 52 | 53 | src_normal = torch.from_numpy(self.nrm).unsqueeze(dim=0).repeat(self.num_classes, 1, 1, 1) # (N, H, W, 3) 54 | rot_matrix = trans[:, :2, :2] # (N, 2, 2) 55 | full_rot = torch.cat([rot_matrix, torch.zeros(self.num_classes, 1, 2)], dim=1) # (N, 3, 2) 56 | full_rot = torch.cat([full_rot, torch.tensor([0, 0, 1]).unsqueeze(dim=0).unsqueeze(dim=-1).repeat(self.num_classes, 1, 1)], dim=2) # (N, 3, 3) 57 | 58 | nrm = (src_normal * 2.0) - 1.0 59 | nrm = nrm.reshape(self.num_classes, -1, 3) # (N, H*W, 3) 60 | trans_nrm = torch.bmm(nrm, full_rot) # the core rot is created by kornia, so there is no transpose # (N, H*W, 3) 61 | trans_nrm = (trans_nrm +1) / 2.0 62 | trans_nrm = trans_nrm.reshape(self.num_classes, self.img.shape[0], self.img.shape[1], 3) # (N, H, W, 3) 63 | 64 | trans_nrm = trans_nrm.permute(0, 3, 1, 2) # (N, 3, H, W) 65 | out = K.geometry.warp_affine(trans_nrm, trans, (mask_img.shape[0], mask_img.shape[1])) # (N, 3, H, W) 66 | target_nrm = (out * binary_seg_mask.float()).sum(dim=0) # (3, H, W) 67 | target_nrm = target_nrm.permute(1, 2, 0).numpy() # (H, W, 3) 68 | 69 | comp_nrm = copy.deepcopy(self.nrm) 70 
| comp_nrm[target_obj_mask > 0] = target_nrm[target_obj_mask > 0] 71 | 72 | comp_shd = copy.deepcopy(self.shd) 73 | comp_shd[target_obj_mask > 0] = generate_shd(comp_nrm, self.coeffs, target_obj_mask > 0) 74 | comp_alb = (comp_img ** 2.2) / comp_shd[:, :, None].clip(1e-4) 75 | 76 | # compose albedo from src 77 | assert trans.shape[0] == self.num_classes, "number of transformations should be equal to the number of foreground objects" 78 | src_alb = torch.from_numpy(self.alb).permute(2, 0, 1).unsqueeze(dim=0).repeat(self.num_classes, 1, 1, 1) # (N, 3, H, W) 79 | out = K.geometry.warp_affine(src_alb, trans, (self.mask_img.shape[0], self.mask_img.shape[1])) # (N, 3, H, W) 80 | 81 | foreground_alb = (out * binary_seg_mask.float()).sum(dim=0) # (3, H, W) 82 | foreground_alb = foreground_alb.permute(1, 2, 0).numpy() # (H, W, 3) 83 | comp_alb[target_obj_mask > 0, :] = foreground_alb[target_obj_mask > 0, :] 84 | 85 | compose_img = ((comp_alb * comp_shd[:, :, None])** (1/2.2)).clip(0, 1) 86 | 87 | return compose_img 88 | 89 | 90 | def prepare_input(video_path, mask_video_path, trans_list_path): 91 | comp_video = torch.load(video_path) # (f, 3, H, W) (16, 3, 512, 512) between 0 and 255 92 | if comp_video.max().item() > 1.: 93 | comp_video = comp_video / 255.0 94 | if comp_video.shape[1] ==3: # (f, 3, H, W) -> (f, H, W, 3) 95 | comp_video = comp_video.permute(0, 2, 3, 1) 96 | T, H, W = comp_video.shape[:3] 97 | obj_masks = torch.load(mask_video_path).squeeze().float() # (f, H, W) 98 | trans_list = torch.load(trans_list_path) 99 | if trans_list.ndim == 3: # (f, 2, 3) -> (f, 1, 2, 3) 100 | trans_list = trans_list.unsqueeze(dim=1) 101 | if trans_list.shape[-2] == 3: # (f, *, 3, 3) -> (f, *, 2, 3) 102 | trans_list = trans_list[:, :, :2, :] 103 | assert comp_video.shape[0] == obj_masks.shape[0] == trans_list.shape[0], "video and mask should have the same length" 104 | comp_video = comp_video.numpy() 105 | if obj_masks.ndim == 4: 106 | obj_masks = obj_masks[:, :, 0] # (f, H, W) 107 | obj_masks = obj_masks.numpy() # (f, H, W) 108 | return comp_video, obj_masks, trans_list 109 | 110 | 111 | if __name__ == '__main__': 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--perception_input", type=str, default="../data/pool", help='input dir') 114 | parser.add_argument('--previous_output', type=str, default="../outputs/pool", help='previous output dir') 115 | 116 | args = parser.parse_args() 117 | 118 | perception_input = args.perception_input 119 | previous_output = args.previous_output 120 | 121 | video_path = os.path.join(previous_output, "composite.pt") 122 | mask_video_path = os.path.join(previous_output, "mask_video.pt") 123 | trans_list_path = os.path.join(previous_output, "trans_list.pt") 124 | comp_video, obj_masks, trans_list = prepare_input(video_path, mask_video_path, trans_list_path) 125 | 126 | normal_path = os.path.join(perception_input, "normal.npy") 127 | shading_path = os.path.join(perception_input, "shading.npy") 128 | normal = np.load(normal_path) # (H, W, 3) between -1 and 1 129 | # convert geowizard normal to OMNI normal 130 | normal[:, :, 0]= -normal[:, :, 0] 131 | normal = (normal + 1) / 2.0 132 | shading = np.load(shading_path) 133 | 134 | 135 | output = args.previous_output 136 | os.makedirs(output, exist_ok=True) 137 | 138 | T = comp_video.shape[0] 139 | img = comp_video[0] # (H, W, 3) 140 | mask_img = obj_masks[0] 141 | model = Relight(img, mask_img, normal, shading) 142 | 143 | relight_list = [img] 144 | for time_idx in tqdm(range(1, T)): 145 | comp_img = 
comp_video[time_idx] 146 | target_obj_mask = obj_masks[time_idx] # segmentation msk 147 | trans = trans_list[time_idx] # (*, 2, 3) 148 | 149 | compose_img = model.relight(comp_img, target_obj_mask, trans) 150 | relight_list.append(compose_img) 151 | 152 | relight_video = np.stack(relight_list, axis=0) 153 | torch.save(torch.from_numpy(relight_video).permute(0, 3, 1, 2), f'{output}/relight.pt') # (0, 1) 154 | relight_video = (relight_video * 255).astype(np.uint8) 155 | writing_video(relight_video[..., ::-1], f'{output}/relight.mp4', frame_rate=7) 156 | print('done!') -------------------------------------------------------------------------------- /relight/relight_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | from torch.optim import Adam 5 | 6 | def invert(x): 7 | out = 1.0 / (x + 1.0) 8 | return out 9 | 10 | 11 | def uninvert(x, eps=0.001, clip=True): 12 | if clip: 13 | x = x.clip(eps, 1.0) 14 | 15 | out = (1.0 / x) - 1.0 16 | return out 17 | 18 | 19 | def spherical2cart(r, theta, phi): 20 | return [ 21 | r * torch.sin(theta) * torch.cos(phi), 22 | r * torch.sin(theta) * torch.sin(phi), 23 | r * torch.cos(theta) 24 | ] 25 | 26 | 27 | def run_optimization(params, A, b): 28 | 29 | optim = Adam([params], lr=0.01) 30 | prev_loss = 1000 31 | 32 | for i in range(500): 33 | optim.zero_grad() 34 | 35 | x, y, z = spherical2cart(params[2], params[0], params[1]) 36 | 37 | dir_shd = (A[:, 0] * x) + (A[:, 1] * y) + (A[:, 2] * z) 38 | pred_shd = dir_shd + params[3] 39 | 40 | loss = torch.nn.functional.mse_loss(pred_shd.reshape(-1), b) 41 | 42 | loss.backward() 43 | 44 | optim.step() 45 | 46 | # theta can range from 0 -> pi/2 (0 to 90 degrees) 47 | # phi can range from 0 -> 2pi (0 to 360 degrees) 48 | with torch.no_grad(): 49 | if params[0] < 0: 50 | params[0] = 0 51 | 52 | if params[0] > np.pi / 2: 53 | params[0] = np.pi / 2 54 | 55 | if params[1] < 0: 56 | params[1] = 0 57 | 58 | if params[1] > 2 * np.pi: 59 | params[1] = 2 * np.pi 60 | 61 | if params[2] < 0: 62 | params[2] = 0 63 | 64 | if params[3] < 0.1: 65 | params[3] = 0.1 66 | 67 | delta = prev_loss - loss 68 | 69 | if delta < 0.0001: 70 | break 71 | 72 | prev_loss = loss 73 | 74 | return loss, params 75 | 76 | 77 | def test_init(params, A, b): 78 | x, y, z = spherical2cart(params[2], params[0], params[1]) 79 | 80 | dir_shd = (A[:, 0] * x) + (A[:, 1] * y) + (A[:, 2] * z) 81 | pred_shd = dir_shd + params[3] 82 | 83 | loss = torch.nn.functional.mse_loss(pred_shd.reshape(-1), b) 84 | return loss 85 | 86 | 87 | def get_light_coeffs(shd, nrm, img, mask=None): 88 | valid = (img.mean(-1) > 0.05) * (img.mean(-1) < 0.95) 89 | 90 | if mask is not None: 91 | valid *= (mask == 0) 92 | 93 | nrm = (nrm * 2.0) - 1.0 94 | 95 | A = nrm[valid == 1] 96 | A /= np.linalg.norm(A, axis=1, keepdims=True) 97 | b = shd[valid == 1] 98 | 99 | # parameters are theta, phi, and bias (c) 100 | A = torch.from_numpy(A) 101 | b = torch.from_numpy(b) 102 | 103 | min_init = 1000 104 | for t in np.arange(0, np.pi/2, 0.1): 105 | for p in np.arange(0, 2*np.pi, 0.25): 106 | params = torch.nn.Parameter(torch.tensor([t, p, 1, 0.5])) 107 | init_loss = test_init(params, A, b) 108 | 109 | if init_loss < min_init: 110 | best_init = params 111 | min_init = init_loss 112 | 113 | loss, params = run_optimization(best_init, A, b) 114 | 115 | x, y, z = spherical2cart(params[2], params[0], params[1]) 116 | 117 | coeffs = torch.tensor([x, y, z]).reshape(3, 1).detach().numpy() 118 | coeffs = 
np.array([x.item(), y.item(), z.item(), params[3].item()]) 119 | return coeffs 120 | 121 | 122 | def generate_shd(nrm, coeffs, msk, bias=True): 123 | 124 | nrm = (nrm * 2.0) - 1.0 125 | 126 | A = nrm.reshape(-1, 3) 127 | A /= np.linalg.norm(A, axis=1, keepdims=True) 128 | 129 | A_fg = nrm[msk == 1] 130 | A_fg /= np.linalg.norm(A_fg, axis=1, keepdims=True) 131 | 132 | if bias: 133 | A = np.concatenate((A, np.ones((A.shape[0], 1))), 1) 134 | A_fg = np.concatenate((A_fg, np.ones((A_fg.shape[0], 1))), 1) 135 | 136 | inf_shd = (A_fg @ coeffs) 137 | inf_shd = inf_shd.clip(0) + 0.2 138 | return inf_shd 139 | 140 | 141 | def writing_video(rgb_list, save_path: str, frame_rate: int = 30): 142 | fourcc = cv2.VideoWriter_fourcc(*'MP4V') 143 | h, w, _ = rgb_list[0].shape 144 | out = cv2.VideoWriter(save_path, fourcc, frame_rate, (w, h)) 145 | 146 | for img in rgb_list: 147 | out.write(img) 148 | 149 | out.release() 150 | return -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # basic simulation requirements 2 | omegaconf 3 | pygame 4 | pymunk 5 | opencv-python 6 | imantics 7 | pillow 8 | kornia 9 | av 10 | # seine requirements 11 | numpy==1.26.4 12 | matplotlib==3.9.2 13 | torch==2.0.1 14 | torchaudio==2.0.2 15 | torchvision==0.15.2 16 | decord==0.6.0 17 | diffusers==0.15.0 18 | imageio==2.29.0 19 | transformers==4.29.2 20 | xformers==0.0.20 21 | einops 22 | tensorboard==2.15.1 23 | timm==0.9.10 24 | rotary-embedding-torch==0.3.5 25 | natsort==8.4.0 26 | openai==0.28 27 | ruamel.yaml 28 | 29 | 30 | -------------------------------------------------------------------------------- /scripts/run_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME=$1 4 | 5 | cd simulation 6 | python animate.py --config ../data/${NAME}/sim.yaml --save_root ../outputs 7 | 8 | cd ../relight 9 | python relight.py --perception_input ../data/${NAME} --previous_output ../outputs/${NAME} 10 | 11 | cd ../diffusion 12 | python video_diffusion.py --perception_input ../data/${NAME} --previous_output ../outputs/${NAME} -------------------------------------------------------------------------------- /simulation/animate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | import argparse 4 | from omegaconf import OmegaConf 5 | import pygame 6 | import pymunk 7 | import pymunk.pygame_util 8 | import numpy as np 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | from matplotlib import cm 12 | import matplotlib.colors as colors 13 | from imantics import Mask 14 | 15 | from animate_utils import AnimateObj, AnimateImgSeg 16 | from sim_utils import fit_circle_from_mask, list_to_numpy 17 | 18 | 19 | 20 | class AnimateUni(AnimateObj): 21 | 22 | def __init__(self, 23 | save_dir, 24 | mask_path, 25 | obj_info: dict, 26 | edge_list: Optional[np.array]=None, 27 | init_velocity: Optional[dict]=None, 28 | init_acc: Optional[dict]=None, 29 | gravity=980, 30 | ground_elasticity=0, 31 | ground_friction=1, 32 | num_steps=200, 33 | size=(512, 512), 34 | display=True, 35 | save_snapshot: Optional[bool]=False, 36 | snapshot_frames: Optional[int]=16, 37 | colormap: Optional[str]="tab20") -> None: 38 | 39 | 40 | super(AnimateUni, self).__init__(save_dir, size, gravity, display) 41 | self.num_steps = num_steps 42 | self.objs = {} # key is seg_id, value is pymunk shape 43 | 
mask_img = cv2.imread(mask_path, 0) 44 | self.mask_img = mask_img 45 | 46 | self.save_snapshot = save_snapshot 47 | if self.save_snapshot: 48 | self.snapshot_dir = os.path.join(save_dir, "snapshot") 49 | os.makedirs(self.snapshot_dir, exist_ok=True) 50 | self.snapshot_frames = snapshot_frames 51 | self.colormap = colormap 52 | 53 | self.obj_info = obj_info 54 | 55 | for seg_id in self.obj_info: 56 | class_name = self.obj_info[seg_id]["primitive"] 57 | density = self.obj_info[seg_id]["density"] 58 | mass = self.obj_info[seg_id]["mass"] 59 | elasticity = self.obj_info[seg_id]["elasticity"] 60 | friction = self.obj_info[seg_id]["friction"] 61 | if class_name == "polygon": 62 | mask = self.mask_img == seg_id 63 | polygons = Mask(mask).polygons() 64 | if len(polygons.points) > 1: # find the largest polygon 65 | areas = [] 66 | for points in polygons.points: 67 | points = points.reshape((-1, 1, 2)).astype(np.int32) 68 | img = np.zeros((self.mask_img.shape[0], self.mask_img.shape[1], 3), dtype=np.uint8) 69 | img = cv2.fillPoly(img, [points], color=[0, 255, 0]) 70 | mask = img[:, :, 1] > 0 71 | area = np.count_nonzero(mask) 72 | areas.append(area) 73 | areas = np.array(areas) 74 | largest_idx = np.argmax(areas) 75 | points = polygons.points[largest_idx] 76 | else: 77 | points = polygons.points[0] 78 | 79 | points = tuple(map(tuple, points)) 80 | poly = self._create_poly(density, points, elasticity, friction) 81 | self.objs[seg_id] = poly 82 | 83 | elif class_name == "circle": 84 | mask = self.mask_img == int(seg_id) 85 | mask = (mask * 255).astype(np.uint8) 86 | center, radius = fit_circle_from_mask(mask) 87 | ball = self._create_ball(mass, radius, center, elasticity, friction) 88 | self.objs[seg_id] = ball 89 | 90 | if edge_list is not None: 91 | self._create_wall_segments(edge_list, ground_elasticity, ground_friction) 92 | 93 | self.init_velocity = init_velocity 94 | self.init_acc = init_acc 95 | 96 | def _create_wall_segments(self, edge_list, elasticity=0.05, friction=0.9): 97 | """Create a number of wall segments connecting the points""" 98 | for edge in edge_list: 99 | point_1, point_2 = edge 100 | v1 = pymunk.Vec2d(point_1[0], point_1[1]) 101 | v2 = pymunk.Vec2d(point_2[0], point_2[1]) 102 | wall_body = pymunk.Body(body_type=pymunk.Body.STATIC) 103 | wall_shape = pymunk.Segment(wall_body, v1, v2, 0.0) 104 | wall_shape.collision_type = 0 105 | wall_shape.elasticity = elasticity 106 | wall_shape.friction = friction 107 | self._space.add(wall_body, wall_shape) 108 | 109 | def _create_ball(self, mass, radius, position, elasticity=0.5, friction=0.4): 110 | """ 111 | Create a ball defined by mass, radius and position 112 | """ 113 | inertia = pymunk.moment_for_circle(mass, 0, radius, (0, 0)) 114 | body = pymunk.Body(mass, inertia) 115 | body.position = position 116 | shape = pymunk.Circle(body, radius, (0, 0)) 117 | shape.elasticity = elasticity 118 | shape.friction = friction 119 | self._space.add(body, shape) 120 | return shape 121 | 122 | def _create_poly(self, density, points, elasticity=0.5, friction=0.4, collision_type=0): 123 | """ 124 | Create a poly defined by density, points 125 | """ 126 | body = pymunk.Body() 127 | shape = pymunk.Poly(body, points) 128 | shape.density = density 129 | shape.elasticity = elasticity 130 | shape.friction = friction 131 | shape.collision_type = collision_type 132 | self._space.add(body, shape) 133 | return shape 134 | 135 | 136 | def _draw_objects(self) -> None: 137 | jet = plt.get_cmap(self.colormap) 138 | cNorm = colors.Normalize(vmin=0, 
vmax=len(self.objs)) 139 | scalarMap = cm.ScalarMappable(norm=cNorm, cmap=jet) 140 | for seg_id in self.objs: 141 | shape = self.objs[seg_id] 142 | color = scalarMap.to_rgba(int(seg_id)) 143 | color = (color[0]*255, color[1]*255, color[2]*255, 255) 144 | if isinstance(shape, pymunk.Circle): 145 | self.draw_ball(shape, color) 146 | elif isinstance(shape, pymunk.Poly): 147 | self.draw_poly(shape, color) 148 | for shape in self._space.shapes: 149 | if isinstance(shape, pymunk.Segment): 150 | self.draw_wall(shape) 151 | 152 | def draw_ball(self, ball, color=(0, 0, 255, 255)): 153 | body = ball.body 154 | v = body.position + ball.offset.cpvrotate(body.rotation_vector) 155 | r = ball.radius 156 | pygame.draw.circle(self._screen, pygame.Color(color), v, int(r), 0) 157 | 158 | def draw_wall(self, wall, color=(252, 3, 169, 255), width=3): 159 | body = wall.body 160 | pv1 = body.position + wall.a.cpvrotate(body.rotation_vector) 161 | pv2 = body.position + wall.b.cpvrotate(body.rotation_vector) 162 | pygame.draw.lines(self._screen, pygame.Color(color), False, [pv1, pv2], width=width) 163 | 164 | def draw_poly(self, poly, color=(0, 255, 0, 255)): 165 | body = poly.body 166 | ps = [p.rotated(body.angle) + body.position for p in poly.get_vertices()] 167 | ps.append(ps[0]) 168 | pygame.draw.polygon(self._screen, pygame.Color(color), ps) 169 | 170 | def get_transform(self): 171 | # get current timestep objs state transformation 172 | """ 173 | rotation_matrix: np.array([[math.cos(rad), -math.sin(rad)], [math.sin(rad), math.cos(rad)]]) 174 | in Kornia, the rotation matrix created by K.geometry.transform.get_rotation_matrix2d should be transpose, 175 | equivalent to use -angle in record 176 | """ 177 | state = {} 178 | keys = list(self.objs.keys()) 179 | keys = sorted(keys, key=lambda x: int(x)) 180 | for seg_id in keys: 181 | shape = self.objs[seg_id] 182 | if isinstance(shape, pymunk.Poly): 183 | ps = [p.rotated(shape.body.angle) + shape.body.position for p in shape.get_vertices()] 184 | ps = np.array(ps) 185 | center = np.mean(ps, axis=0) 186 | angle = shape.body.angle 187 | 188 | elif isinstance(shape, pymunk.Circle): 189 | center = shape.body.position 190 | angle = shape.body.angle 191 | 192 | state[seg_id] = (center, angle) 193 | return state 194 | 195 | def init_condition(self, init_velocity, init_acc): 196 | for seg_id in self.objs: 197 | shape = self.objs[seg_id] 198 | query_seg_id = seg_id 199 | if init_velocity is not None and query_seg_id in init_velocity: 200 | ins_init_vel = list(init_velocity[query_seg_id]) if not isinstance(init_velocity[query_seg_id], list) else init_velocity[query_seg_id] 201 | shape.body.velocity = ins_init_vel 202 | if init_acc is not None and query_seg_id in init_acc: 203 | ins_init_acc = init_acc[query_seg_id] 204 | shape.body.apply_impulse_at_local_point((ins_init_acc[0] * shape.body.mass, ins_init_acc[1] * shape.body.mass), (0, 0)) 205 | 206 | def run(self) -> None: 207 | """ 208 | The main loop of the game. 
209 | :return: None 210 | """ 211 | # Main loop 212 | num_steps = self.num_steps // self._physics_steps_per_frame 213 | count = 0 214 | self.init_condition(self.init_velocity, self.init_acc) 215 | if self.save_snapshot: 216 | # save snapshot of the simulation of #animate frames 217 | snapshot_indices = np.arange(num_steps)[::num_steps//(self.snapshot_frames)][:self.snapshot_frames] 218 | snapshot_indices[0] = 0 219 | save_path = os.path.join(self.snapshot_dir, "snapshot_{:03d}.png") 220 | self._process_events() 221 | self._clear_screen() 222 | self._draw_objects() 223 | pygame.display.flip() 224 | pygame.image.save(self._screen, save_path.format(0)) 225 | 226 | while self._running and count < num_steps: 227 | # Progress time forward 228 | count += 1 229 | for x in range(self._physics_steps_per_frame): 230 | self._space.step(self._dt) 231 | state = self.get_transform() 232 | self.history.append(state) 233 | if self.display: 234 | self._process_events() 235 | self._clear_screen() 236 | self._draw_objects() 237 | pygame.display.flip() 238 | if self.save_snapshot and count in snapshot_indices: 239 | index = np.where(snapshot_indices==count)[0].item() 240 | pygame.image.save(self._screen, save_path.format(index)) 241 | 242 | self._clock.tick(50) 243 | pygame.display.set_caption("fps: " + str(self._clock.get_fps())) 244 | self.save_state() 245 | 246 | 247 | def main(args, data_root, save_root): 248 | save_dir = os.path.join(save_root, args.cat.lower()) 249 | os.makedirs(save_dir, exist_ok=True) 250 | 251 | data_dir = os.path.join(data_root, args.cat) 252 | mask_path=os.path.join(data_dir, "mask.png") 253 | 254 | anim = AnimateUni( 255 | mask_path=mask_path, save_dir=save_dir, 256 | obj_info=args.obj_info, edge_list=args.edge_list, 257 | init_velocity=args.init_velocity, init_acc=args.init_acc, 258 | num_steps=args.num_steps, 259 | gravity=args.gravity, 260 | ground_elasticity=getattr(args, "ground_elasticity", 0), 261 | ground_friction=getattr(args, "ground_friction", 1), 262 | size=args.size, display=args.display, save_snapshot=args.save_snapshot, snapshot_frames=args.animation_frames) 263 | anim.run() 264 | 265 | history_path = os.path.join(save_dir, "history.pkl") 266 | animation_frames = getattr(args, "animation_frames", 16) 267 | replay = AnimateImgSeg(data_dir=data_dir, save_dir=save_dir, history_path=history_path, animate_frames=animation_frames) 268 | replay.record() 269 | 270 | 271 | if __name__ == "__main__": 272 | parser = argparse.ArgumentParser() 273 | parser.add_argument("--data_root", type=str, default="../data") 274 | parser.add_argument("--save_root", type=str, default="../outputs") 275 | parser.add_argument("--config", type=str, default="../data/pool/sim.yaml") 276 | args = parser.parse_args() 277 | config = OmegaConf.load(args.config) 278 | config = list_to_numpy(config) 279 | main(config, data_root=args.data_root, save_root=args.save_root) -------------------------------------------------------------------------------- /simulation/animate_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import List 4 | import numpy as np 5 | import pickle 6 | import pygame 7 | import pymunk 8 | import pymunk.pygame_util 9 | 10 | import kornia as K 11 | from sim_utils import writing_video, prep_data, composite_trans 12 | 13 | 14 | class AnimateObj(object): 15 | 16 | def __init__(self, save_dir, size=(512,512), gravity=980, display=True) -> None: 17 | self.size = size 18 | self.save_dir = save_dir 19 | 
os.makedirs(self.save_dir, exist_ok=True) 20 | self.history = [] # record state 21 | 22 | # Space 23 | self._space = pymunk.Space() 24 | # Physics 25 | # Time step 26 | self._space.gravity = (0.0, gravity) # positive_y_is_down 27 | self._dt = 1.0 / 60.0 28 | # Number of physics steps per screen frame 29 | self._physics_steps_per_frame = 1 # larger, faster 30 | self.history = [] 31 | # pygame 32 | self.display = display 33 | if self.display: 34 | pygame.init() 35 | self._screen = pygame.display.set_mode(size) 36 | self._clock = pygame.time.Clock() 37 | self._draw_options = pymunk.pygame_util.DrawOptions(self._screen) 38 | 39 | # Execution control and time until the next ball spawns 40 | self._running = True 41 | 42 | def run(self) -> None: 43 | """Custom implement func 44 | The main loop of the game. 45 | :return: None 46 | """ 47 | # Main loop 48 | while self._running: 49 | # Progress time forward 50 | for x in range(self._physics_steps_per_frame): 51 | self._space.step(self._dt) 52 | 53 | state = self.get_transform() 54 | self.history.append(state) 55 | if self.display: 56 | self._process_events() 57 | self.draw() 58 | # self._clear_screen() 59 | # self._draw_objects() 60 | pygame.display.flip() 61 | # Delay fixed time between frames 62 | self._clock.tick(50) 63 | pygame.display.set_caption("fps: " + str(self._clock.get_fps())) 64 | self.save_state() 65 | 66 | 67 | def get_transform(self, verbose=False) -> List[int]: 68 | # custom func 69 | center = self.ball.body.position 70 | angle = self.poly.body.angle 71 | if verbose: 72 | print(center, angle) 73 | return center, angle 74 | 75 | def save_state(self): 76 | # save self.state 77 | with open(os.path.join(self.save_dir, 'history.pkl'), 'wb') as f: 78 | pickle.dump(self.history, f) 79 | 80 | def _process_events(self) -> None: 81 | """ 82 | Handle game and events like keyboard input. Call once per frame only. 83 | :return: None 84 | """ 85 | for event in pygame.event.get(): 86 | if event.type == pygame.QUIT: 87 | self._running = False 88 | elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE: 89 | self._running = False 90 | 91 | def _clear_screen(self) -> None: 92 | """ 93 | Clears the screen. 94 | :return: None 95 | """ 96 | self._screen.fill(pygame.Color("white")) 97 | 98 | def _draw_objects(self) -> None: 99 | """ 100 | Draw the objects. 
101 | :return: None 102 | """ 103 | self._space.debug_draw(self._draw_options) 104 | 105 | 106 | class AnimateImgSeg(object): 107 | # Animate Img from Segmentation 108 | def __init__(self, data_dir, save_dir, history_path, animate_frames=16): 109 | self.save_dir = save_dir 110 | os.makedirs(self.save_dir, exist_ok=True) 111 | 112 | self.img, self.mask_img, inpaint_img = prep_data(data_dir) 113 | 114 | with open(history_path, 'rb') as f: 115 | history = pickle.load(f) 116 | self.animate_frames = animate_frames 117 | self.history = history[::len(history)//(self.animate_frames)][:animate_frames] 118 | assert len(self.history) == self.animate_frames 119 | 120 | active_keys = list(self.history[0].keys()) 121 | self.active_keys = sorted(active_keys, key=lambda x: int(x)) 122 | for seg_id in np.unique(self.mask_img): 123 | if seg_id == 0: # background 124 | continue 125 | if seg_id not in active_keys: 126 | mask = self.mask_img == int(seg_id) 127 | inpaint_img[mask] = self.img[mask] 128 | self.mask_img[mask] = 0 # set as background 129 | self.inpaint_img = inpaint_img # inpaint_img is the background 130 | 131 | init_centers = [] 132 | for seg_id in active_keys: 133 | if seg_id == "0": 134 | raise ValueError("seg_id should not be 0") 135 | mask = self.mask_img == int(seg_id) 136 | init_center = np.mean(np.argwhere(mask), axis=0) 137 | init_center = np.array([init_center[1], init_center[0]]) # center (y,x) to (x,y) 138 | init_centers.append(init_center) 139 | init_centers = np.stack(init_centers, axis=0) # (num_objs, 2) 140 | self.init_centers = init_centers.astype(np.float32) 141 | 142 | def record(self): 143 | H, W = self.img.shape[:2] 144 | masked_src_imgs = [] 145 | for seg_id in self.active_keys: 146 | mask = self.mask_img == int(seg_id) 147 | src_img = torch.from_numpy(self.img * mask[:, :, None]).permute(2, 0, 1).float() # (3, H, W) 148 | masked_src_imgs.append(src_img) 149 | masked_src_imgs = torch.stack(masked_src_imgs, dim=0) # tensor (num_objs, 3, H, W) 150 | init_centers = torch.from_numpy(self.init_centers).float()# tensor (num_objs, 2) 151 | 152 | imgs = [] 153 | msk_list = [] 154 | trans_list = [] # (animate_frames [dict{seg_id: Trans(2, 3)}] 155 | for i in range(len(self.history)): 156 | if i == 0: # use original image 157 | trans = [] 158 | imgs.append((self.img*255).astype(np.uint8)) 159 | # active segmentation mask 160 | seg_mask = np.zeros_like(self.mask_img) 161 | for seg_id in self.active_keys: 162 | seg_mask[self.mask_img == int(seg_id)] = seg_id 163 | trans.append(torch.eye(3)[:2, :]) # (2, 3) 164 | msk_list.append(seg_mask) 165 | trans = torch.stack(trans, dim=0) # (num_objs, 2, 3) 166 | trans_list.append(trans) 167 | else: 168 | history = self.history[i] # dict of seg_id: (center, angle) 169 | centers, scales, angles = [], [], [] 170 | for seg_id in self.active_keys: 171 | center, angle = history[seg_id] 172 | 173 | # the rotation matrix used in K is pymunk rotation transpose, thus use -angle 174 | angle = -angle 175 | 176 | center = torch.tensor([center]).float() # (1, 2) 177 | angle = torch.tensor([angle/np.pi * 180]).float() # [1] 178 | scale = torch.ones(1, 2).float() 179 | centers.append(center) 180 | scales.append(scale) 181 | angles.append(angle) 182 | centers = torch.cat(centers, dim=0) # (num_objs, 2) 183 | scales = torch.cat(scales, dim=0) # (num_objs, 2) 184 | angles = torch.cat(angles, dim=0) # (num_objs) 185 | 186 | trans = K.geometry.transform.get_rotation_matrix2d(init_centers, angles, scales) 187 | trans[:, :, 2] += centers - init_centers # (num_objs, 2, 3) 
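# Descriptive note on the two lines above: get_rotation_matrix2d rotates each object about its
# initial center; adding (centers - init_centers) to the translation column then moves it to
# the simulated center, giving one rigid 2x3 affine per object for this frame.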
188 | 189 | active_list = list(map(int, self.active_keys)) 190 | final_frame, seg_mask = composite_trans(masked_src_imgs, trans, self.inpaint_img, active_list) 191 | 192 | imgs.append(final_frame) 193 | msk_list.append(seg_mask) 194 | trans_list.append(trans) # (num_objs, 2, 3) 195 | 196 | imgs = np.stack(imgs, axis=0) # (animate_frames, H, W, 3) 197 | writing_video(imgs[..., ::-1], os.path.join(self.save_dir, 'composite.mp4'), frame_rate=7) 198 | 199 | msk_list = np.stack(msk_list, axis=0) # (animate_frames, H, W) 200 | msk_list = torch.from_numpy(msk_list) 201 | imgs = torch.from_numpy(imgs).permute(0, 3, 1, 2) 202 | trans_list = np.stack(trans_list, axis=0) # (animate_frames, num_objs, 2, 3) 203 | trans_list = torch.from_numpy(trans_list) 204 | assert len(imgs) == len(msk_list) == len(trans_list) 205 | 206 | torch.save(msk_list, os.path.join(self.save_dir,'mask_video.pt')) # foreground objs segmentation mask 207 | torch.save(imgs, os.path.join(self.save_dir,'composite.pt')) 208 | torch.save(trans_list, os.path.join(self.save_dir, "trans_list.pt")) -------------------------------------------------------------------------------- /simulation/sim_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | from PIL import Image 7 | import kornia as K 8 | from omegaconf import OmegaConf 9 | 10 | 11 | def get_bbox(mask): 12 | rows = np.any(mask, axis=1) 13 | cols = np.any(mask, axis=0) 14 | rmin, rmax = np.where(rows)[0][[0, -1]] 15 | cmin, cmax = np.where(cols)[0][[0, -1]] 16 | 17 | return rmin, rmax, cmin, cmax 18 | 19 | 20 | def writing_video(rgb_list, save_path: str, frame_rate: int = 30): 21 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 22 | h, w, _ = rgb_list[0].shape 23 | out = cv2.VideoWriter(save_path, fourcc, frame_rate, (w, h)) 24 | 25 | for img in rgb_list: 26 | out.write(img) 27 | 28 | out.release() 29 | return 30 | 31 | 32 | def prep_data(data_dir): 33 | 34 | img_path = os.path.join(data_dir, "original.png") 35 | mask_path = os.path.join(data_dir, "mask.png") 36 | inpaint_path = os.path.join(data_dir, "inpaint.png") 37 | img = np.array(Image.open(img_path)) / 255 38 | mask_img = np.array(Image.open(mask_path)) 39 | if mask_img.ndim == 3: 40 | mask_img = mask_img[:, :, 0] 41 | inpaint_img = np.array(Image.open(inpaint_path)).astype(np.float32) / 255 42 | return img, mask_img, inpaint_img 43 | 44 | 45 | def fit_circle_from_mask(mask_image): 46 | contours, _ = cv2.findContours(mask_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 47 | 48 | if len(contours) == 0: 49 | print("No contours found in the mask image.") 50 | return None 51 | max_contour = max(contours, key=cv2.contourArea) 52 | (x, y), radius = cv2.minEnclosingCircle(max_contour) 53 | 54 | center = (x, y) 55 | return center, radius 56 | 57 | 58 | def composite_trans(masked_src_imgs, trans, inpaint_img, active_seg_ids): 59 | """params: 60 | src_imgs: (N, C, H, W) 61 | src_seg: (H, W) original segmentaion mask 62 | trans: (N, 2, 3) 63 | inpaint_img: (C, H, W) 64 | active_seg_ids: (N, ) record segmentaion id 65 | thre: threshold for foreground segmentaion mask 66 | return: 67 | 68 | """ 69 | H, W = inpaint_img.shape[:2] 70 | out = K.geometry.warp_affine(masked_src_imgs, trans, (H, W)) # (N, C, H, W) 71 | src_binary_seg = (masked_src_imgs.sum(dim=1, keepdim=True) > 0).float() # (N, 1, H, W) 72 | out_binary_seg = K.geometry.warp_affine(src_binary_seg, trans, (H, W)) # (N, C, H, W) 73 | 74 | 
foreground_msk = out_binary_seg.sum(dim=0).sum(dim=0) > 0 # (H, W) 75 | seg_map = torch.zeros((H, W)).long() 76 | seg_map[~foreground_msk] = 0 77 | seg_mask = out_binary_seg.sum(dim=1).argmax(dim=0) + 1 # (H, W) 78 | seg_map[foreground_msk] = seg_mask[foreground_msk] # 0 is background, 1~N is the fake segmentaion id 79 | 80 | num_classes = len(active_seg_ids) + 1 81 | binary_seg_map = F.one_hot(seg_map, num_classes=num_classes) # (H, W, N+1) 82 | binary_seg_map = binary_seg_map.permute(2, 0, 1).float() # (N+1, H, W) 83 | binary_seg_map = binary_seg_map.unsqueeze(dim=1) # (N, 1, H, W) 84 | 85 | inpaint_img = torch.from_numpy(inpaint_img).permute(2, 0, 1).unsqueeze(0) # (1, C, H, W) 86 | input = torch.cat([inpaint_img, out], dim=0) # (N+1, C, H, W) 87 | 88 | composite = (binary_seg_map * input).sum(0) # (C, H, W) 89 | 90 | composite = composite.permute(1, 2, 0).numpy() # (H, W, C) 91 | final_frame = (composite*255).astype(np.uint8) 92 | 93 | final_seg_map = torch.zeros((H, W)).long() 94 | for idx, seg_id in enumerate(active_seg_ids): 95 | final_seg_map[seg_map==idx+1] = seg_id 96 | 97 | final_seg_map = final_seg_map.numpy() 98 | 99 | return final_frame, final_seg_map 100 | 101 | 102 | def list_to_numpy(data): 103 | if isinstance(data, list): 104 | try: 105 | return np.array(data) 106 | except ValueError: 107 | return data 108 | elif isinstance(data, dict) or isinstance(data, OmegaConf): 109 | for key in data: 110 | data[key] = list_to_numpy(data[key]) 111 | return data --------------------------------------------------------------------------------
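The simulation, relighting, and compositing code above all share one primitive: a per-object 2×3 rigid transform applied with kornia's `warp_affine`. Below is a minimal, self-contained sketch of that warp-and-composite step, using synthetic tensors in place of the repository's real masks and inpainted background; the shapes, names, and values are illustrative only and are not taken from the codebase.

```python
# Minimal sketch of the warp-and-composite primitive used throughout the
# simulation/relight code (illustrative values only; not part of the repo).
import torch
import kornia as K

H, W = 512, 512
background = torch.rand(1, 3, H, W)      # stand-in for the inpainted background
obj = torch.zeros(1, 3, H, W)
obj[:, :, 200:300, 200:300] = 1.0        # stand-in for one masked foreground object

center = torch.tensor([[250.0, 250.0]])  # (N, 2) initial object center in (x, y)
angle = torch.tensor([15.0])             # (N,) rotation in degrees
scale = torch.ones(1, 2)                 # (N, 2) no scaling for rigid bodies

# 2x3 affine: rotate about the initial center, then translate to the new center.
trans = K.geometry.transform.get_rotation_matrix2d(center, angle, scale)
trans[:, :, 2] += torch.tensor([[40.0, -20.0]])  # (new_center - init_center)

warped = K.geometry.warp_affine(obj, trans, (H, W))    # (N, 3, H, W)
fg = (warped.sum(dim=1, keepdim=True) > 0).float()     # binary mask after warping
composite = fg * warped + (1.0 - fg) * background      # paste object over background
print(composite.shape)                                 # torch.Size([1, 3, 512, 512])
```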