├── .gitmodules
├── README.md
├── assets
│ ├── car_final.gif
│ ├── car_relight.gif
│ ├── demo.gif
│ ├── domino_composite.gif
│ ├── domino_sim.gif
│ ├── method.png
│ ├── pig_ball_composite.gif
│ ├── pig_ball_relight.gif
│ ├── pool_composite.gif
│ └── pool_sim.gif
├── custom_data
│ ├── balls_shelf
│ │ └── original.png
│ ├── boxes
│ │ └── original.png
│ ├── kitchen
│ │ └── original.png
│ ├── table
│ │ └── original.png
│ └── wall_toy
│   └── original.png
├── data
│ ├── balls
│ │ ├── depth.npy
│ │ ├── inpaint.png
│ │ ├── intermediate
│ │ │ ├── albedo_vis.png
│ │ │ ├── depth_vis.png
│ │ │ ├── mask.png
│ │ │ ├── normal_vis.png
│ │ │ ├── obj_movable.json
│ │ │ ├── shading_vis.png
│ │ │ └── vis_mask.jpg
│ │ ├── mask.png
│ │ ├── normal.npy
│ │ ├── original.png
│ │ ├── shading.npy
│ │ ├── sim.yaml
│ │ └── vis.png
│ ├── car
│ │ ├── inpaint.png
│ │ ├── mask.png
│ │ ├── normal.npy
│ │ ├── original.png
│ │ ├── shading.npy
│ │ └── sim.yaml
│ ├── domino
│ │ ├── depth.npy
│ │ ├── inpaint.png
│ │ ├── intermediate
│ │ │ ├── albedo_vis.png
│ │ │ ├── depth_vis.png
│ │ │ ├── normal_vis.png
│ │ │ ├── obj_movable.json
│ │ │ ├── shading_vis.png
│ │ │ └── vis_mask.jpg
│ │ ├── mask.png
│ │ ├── normal.npy
│ │ ├── original.png
│ │ ├── shading.npy
│ │ ├── sim.yaml
│ │ └── vis.png
│ ├── pig_ball
│ │ ├── depth.npy
│ │ ├── inpaint.png
│ │ ├── intermediate
│ │ │ ├── albedo_vis.png
│ │ │ ├── bg_mask_vis.png
│ │ │ ├── depth_vis.png
│ │ │ ├── edge_vis.png
│ │ │ ├── fg_mask_vis.png
│ │ │ ├── normal_vis.png
│ │ │ ├── obj_movable.json
│ │ │ ├── shading_vis.png
│ │ │ └── vis_mask.jpg
│ │ ├── mask.png
│ │ ├── normal.npy
│ │ ├── original.png
│ │ ├── shading.npy
│ │ ├── sim.yaml
│ │ └── vis.png
│ └── pool
│   ├── depth.npy
│   ├── inpaint.png
│   ├── intermediate
│   │ ├── albedo_vis.png
│   │ ├── depth_vis.png
│   │ ├── normal_vis.png
│   │ ├── obj_movable.json
│   │ ├── shading_vis.png
│   │ └── vis_mask.jpg
│   ├── mask.png
│   ├── normal.npy
│   ├── original.png
│   ├── shading.npy
│   ├── sim.yaml
│   └── vis.png
├── diffusion
│ └── video_diffusion.py
├── perception
│ ├── README.md
│ ├── gpt
│ │ ├── __init__.py
│ │ ├── gpt_configs
│ │ │ ├── my_apikey
│ │ │ ├── physics
│ │ │ │ └── user.txt
│ │ │ └── ram
│ │ │   └── user.txt
│ │ └── gpt_utils.py
│ ├── gpt_physic.py
│ ├── gpt_ram.py
│ ├── run_albedo_shading.py
│ ├── run_depth_normal.py
│ ├── run_fg_bg.py
│ ├── run_gsam.py
│ └── run_inpaint.py
├── relight
│ ├── relight.py
│ └── relight_utils.py
├── requirements.txt
├── scripts
│ └── run_demo.sh
└── simulation
  ├── animate.py
  ├── animate_utils.py
  └── sim_utils.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "diffusion/SEINE"]
2 | path = diffusion/SEINE
3 | url = https://github.com/Vchitect/SEINE.git
4 | [submodule "perception/Grounded-Segment-Anything"]
5 | path = perception/Grounded-Segment-Anything
6 | url = https://github.com/IDEA-Research/Grounded-Segment-Anything.git
7 | [submodule "perception/GeoWizard"]
8 | path = perception/GeoWizard
9 | url = https://github.com/fuxiao0719/GeoWizard
10 | [submodule "perception/Inpaint-Anything"]
11 | path = perception/Inpaint-Anything
12 | url = https://github.com/geekyutao/Inpaint-Anything.git
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
PhysGen: Rigid-Body Physics-Grounded
Image-to-Video Generation
5 |
6 |
7 | ECCV, 2024
8 |
9 | Shaowei Liu
10 | ·
11 | Zhongzheng Ren
12 | ·
13 | Saurabh Gupta*
14 | ·
15 | Shenlong Wang*
16 | ·
37 | This repository contains the PyTorch implementation for the paper [PhysGen: Rigid-Body Physics-Grounded Image-to-Video Generation](https://stevenlsw.github.io/physgen/), ECCV 2024. In this paper, we present a novel training-free image-to-video generation pipeline that integrates physical simulation with a generative video diffusion prior.
38 |
39 | ## Overview
40 | *(method overview figure: `assets/method.png`)*
41 |
42 | ## 📄 Table of Contents
43 |
44 | - [Installation](#installation)
45 | - [Colab Notebook](#colab-notebook)
46 | - [Quick Demo](#quick-demo)
47 | - [Perception](#perception)
48 | - [Simulation](#simulation)
49 | - [Rendering](#rendering)
50 | - [All-in-One command](#all-in-one-command)
51 | - [Evaluation](#evaluation)
52 | - [Custom Image Video Generation](#custom-image-video-generation)
53 | - [Citation](#citation)
54 |
55 |
56 | ## Installation
57 | - Clone this repository:
58 | ```Shell
59 | git clone --recurse-submodules https://github.com/stevenlsw/physgen.git
60 | cd physgen
61 | ```
62 | - Install requirements by the following commands:
63 | ```Shell
64 | conda create -n physgen python=3.9
65 | conda activate physgen
66 | pip install -r requirements.txt
67 | ```
68 |
69 | ## Colab Notebook
70 | Run our [Colab notebook](https://colab.research.google.com/drive/1imGIms3Y4RRtddA6IuxZ9bkP7N2gVVC_) for quick start!
71 |
72 |
73 |
74 | ## Quick Demo
75 |
76 | - Run image-space dynamics simulation in just **3** seconds, **without a GPU, display device, or any additional setup** required!
77 | ```Shell
78 | export PYTHONPATH=$(pwd)
79 | name="pool"
80 | python simulation/animate.py --data_root data --save_root outputs --config data/${name}/sim.yaml
81 | ```
82 | - The output video is saved to `outputs/${name}/composite.mp4`. Try setting `name` to `domino`, `balls`, `pig_ball`, or `car` to explore the other scenes. Example outputs are shown below:
83 |
84 | | **Input Image** | **Simulation** | **Output Video** |
85 | |:---------------:|:--------------:|:----------------:|
86 | | *(example images and GIFs omitted; see `assets/*_sim.gif` and `assets/*_composite.gif`)* | | |
88 |
89 |
90 | ## Perception
91 | - Please see [perception/README.md](perception/README.md) for details.
92 |
93 | | **Input** | **Segmentation** | **Normal** | **Albedo** | **Shading** | **Inpainting** |
94 | |:---------:|:----------------:|:----------:|:----------:|:-----------:|:--------------:|
95 | | *(visualization images omitted; see `data/${name}/intermediate/`)* | | | | | |
96 |
97 |
98 |
99 | ## Simulation
100 | - Simulation requires the following input for each image:
101 | ```Shell
102 | image folder/
103 | ├── original.png
104 | ├── mask.png # segmentation mask
105 | ├── inpaint.png # background inpainting
106 | ├── sim.yaml # simulation configuration file
107 | ```
108 |
109 | - `sim.yaml` specifies the physical properties of each object and the initial conditions (force and velocity applied to each object). Please see `data/pig_ball/sim.yaml` for an example. Set `display` to `true` to visualize the simulation process on a display device, and set `save_snapshot` to `true` to save simulation snapshots. A minimal sketch of inspecting the config and the saved outputs is shown at the end of this section.
110 | - Run the simulation by the following command:
111 | ```Shell
112 | cd simulation
113 | python animate.py --data_root ../data --save_root ../outputs --config ../data/${name}/sim.yaml
114 | ```
115 | - The outputs are saved in `outputs/${name}` as follows:
116 | ```Shell
117 | output folder/
118 | ├── history.pkl # simulation history
119 | ├── composite.mp4 # composite video
120 | ├── composite.pt # composite video tensor
121 | ├── mask_video.pt # foreground masked video tensor
122 | ├── trans_list.pt # objects transformation list tensor
123 | ```
124 |
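Below is a minimal inspection sketch (not part of the released scripts) for checking the simulation config and the saved tensors from a Python shell; the exact tensor shapes are assumptions based on how `diffusion/video_diffusion.py` later loads these files.

```python
# Minimal inspection sketch (assumption: run from the repo root after the quick demo).
import torch
from omegaconf import OmegaConf

name = "pool"
config = OmegaConf.load(f"data/{name}/sim.yaml")
print(config.obj_info)               # per-object label, mass/density, friction, elasticity, primitive
print(config.get("init_velocity"))   # initial velocities keyed by object id

composite = torch.load(f"outputs/{name}/composite.pt")    # composite video tensor, (frames, 3, H, W) assumed
mask_video = torch.load(f"outputs/{name}/mask_video.pt")  # foreground mask video tensor
trans_list = torch.load(f"outputs/{name}/trans_list.pt")  # per-frame object transformation list
print(composite.shape, mask_video.shape, len(trans_list))
```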
125 | ## Rendering
126 |
127 | ### Relighting
128 | - Relighting requires the following input:
129 | ```Shell
130 | image folder/ #
131 | ├── normal.npy # normal map
132 | ├── shading.npy # shading map by intrinsic decomposition
133 | previous output folder/
134 | ├── composite.pt # composite video
135 | ├── mask_video.pt # foreground masked video tensor
136 | ├── trans_list.pt # objects transformation list tensor
137 |
138 | ```
139 | - The `perception_input` is the image folder that contains the perception results. The `previous_output` is the output folder from the previous simulation step.
140 | - Run the relighting by the following command:
141 | ```Shell
142 | cd relight
143 | python relight.py --perception_input ../data/${name} --previous_output ../outputs/${name}
144 | ```
145 | - The outputs `relight.mp4` and `relight.pt` are the relighted video and the corresponding video tensor.
146 | - Comparison between the composite video and the relighted video:
147 | | **Input Image** | **Composite Video** | **Relight Video** |
148 | |:---------------:|:-------------------:|:-----------------:|
149 | | *(GIFs omitted; see `assets/`)* | | |
150 |
151 |
152 |
153 | ### Video Diffusion Rendering
154 | - Download the [SEINE](https://github.com/Vchitect/SEINE/) model following the [instructions](https://github.com/Vchitect/SEINE/tree/main?tab=readme-ov-file#download-our-model-and-t2i-base-model)
155 |
156 | ```Shell
157 | # install git-lfs beforehand
158 | mkdir -p diffusion/SEINE/pretrained
159 | git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 diffusion/SEINE/pretrained/stable-diffusion-v1-4
160 | wget -P diffusion/SEINE/pretrained https://huggingface.co/Vchitect/SEINE/resolve/main/seine.pt
161 | ```
162 |
163 | - The video diffusion rendering requires the following input:
164 | ```Shell
165 | image folder/ #
166 | ├── original.png # input image
167 | ├── sim.yaml # simulation configuration file (optional)
168 | previous output folder/
169 | ├── relight.pt # relighted video tensor
170 | ├── mask_video.pt # foreground masked video tensor
171 | ```
172 | - Run the video diffusion rendering by the following command:
173 | ```Shell
174 | cd diffusion
175 | python video_diffusion.py --perception_input ../data/${name} --previous_output ../outputs/${name}
176 | ```
177 | `denoise_strength` and `prompt` can be adjusted in the above script. `denoise_strength` controls the amount of noise added: 0 means no denoising, while 1 means denoising from scratch, which introduces large variance relative to the input image. `prompt` is the text prompt for the video diffusion model; by default we use the foreground object names from the perception step as the prompt. Both settings are illustrated in the sketch below.
178 |
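The sketch below mirrors how `diffusion/video_diffusion.py` derives the default prompt and the number of denoising steps from these settings; the schedule length of 1000 timesteps is an assumption used only for illustration.

```python
# Self-contained sketch of the default prompt and noise-level arithmetic,
# mirroring diffusion/video_diffusion.py (illustrative values).
from omegaconf import OmegaConf

config = OmegaConf.load("data/pig_ball/sim.yaml")

# Default prompt: the (deduplicated) object labels from the perception step.
object_names = list({config.obj_info[seg_id]["label"] for seg_id in config.obj_info})
prompt = ", ".join(object_names)
print(prompt)  # e.g. "ball, toy"

# sim.yaml may carry a per-scene denoise_strength that overrides the CLI default.
denoise_strength = config.get("denoise_strength", 0.4)

# Only the last `noise_level` reverse-diffusion steps are run, so a smaller
# denoise_strength keeps the result closer to the relit input video.
num_timesteps = 1000  # assumption: the diffusion schedule length used by SEINE
noise_level = int(denoise_strength * num_timesteps)
print(f"{noise_level} of {num_timesteps} denoising steps")
```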
179 |
180 | - The output `final_video.mp4` is the rendered video.
181 |
182 | - Comparison between the relighted video and the diffusion-rendered video:
183 | | **Input Image** | **Relight Video** | **Final Video** |
184 | |:---------------:|:-----------------:|:---------------:|
185 | | *(GIFs omitted; see `assets/`)* | | |
186 |
187 |
188 |
189 | ## All-in-One command
190 | We integrate simulation, relighting, and video diffusion rendering into one script. Please follow [Video Diffusion Rendering](#video-diffusion-rendering) to download the SEINE model first.
191 | ```Shell
192 | bash scripts/run_demo.sh ${name}
193 | ```
194 |
195 | ## Evaluation
196 | We compare our method against the open-source image-to-video models [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [I2VGen-XL](https://github.com/ali-vilab/VGen), [SEINE](https://github.com/Vchitect/SEINE), and collected reference videos [GT]() in Sec. 4.3.
197 |
198 | - Install [pytorch-fid](https://github.com/mseitzer/pytorch-fid):
199 |
200 | ```
201 | pip install pytorch-fid
202 | ```
203 |
204 | - Download the evaluation data for all comparisons from [here](https://uofi.box.com/s/zl8au6w3jopke9sxb7v9sdyboglcozhl) and unzip it to the `evaluation` directory. Choose `${method name}` from `DynamiCrafter`, `I2VGen-XL`, `SEINE`, `ours`.
205 |
206 |
207 | - Evaluate image FID:
208 | ```Shell
209 | python -m pytorch_fid evaluation/${method name}/all evaluation/GT/all
210 | ```
211 |
212 | - Evaluate motion FID:
213 | ```Shell
214 | python -m pytorch_fid evaluation/${method name}/all_flow evaluation/GT/all_flow
215 | ```
216 |
217 | - For motion FID, we use [RAFT](https://github.com/princeton-vl/RAFT) to compute optical flow between neighboring frames. The video processing scripts can be found [here](https://drive.google.com/drive/folders/10KDXRGEdcYSJuxLp8v6u1N5EB8Ghs6Xk?usp=sharing); a minimal flow-extraction sketch is shown below.
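The sketch below turns a rendered video into per-frame optical-flow images that can then be scored with `pytorch-fid` as above. It assumes a recent torchvision with bundled RAFT weights rather than the original RAFT repository; the linked processing scripts are the reference implementation and may differ in details.

```python
# Minimal sketch (assumption: torchvision >= 0.13 with bundled RAFT weights).
import torch
import torchvision
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from torchvision.utils import flow_to_image

device = "cuda" if torch.cuda.is_available() else "cpu"
weights = Raft_Large_Weights.DEFAULT
model = raft_large(weights=weights).eval().to(device)
preprocess = weights.transforms()

# Read a rendered video as (T, 3, H, W) uint8 frames (path is illustrative).
video, _, _ = torchvision.io.read_video("outputs/pool/final_video.mp4", output_format="TCHW")
prev, nxt = preprocess(video[:-1], video[1:])  # neighboring frame pairs

with torch.no_grad():
    flow = model(prev.to(device), nxt.to(device))[-1]  # final flow estimate, (T-1, 2, H, W)

# Color-coded flow images, saved as PNGs for pytorch-fid.
flow_imgs = flow_to_image(flow).cpu()  # (T-1, 3, H, W) uint8
for i, img in enumerate(flow_imgs):
    torchvision.io.write_png(img, f"flow_{i:03d}.png")
```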
218 |
219 |
220 |
221 | ## Custom Image Video Generation
222 |
223 | - Our method should generally work for side-view and top-down view images. For custom images, please follow the [perception](#perception), [simulation](#simulation), [rendering](#rendering) pipeline to generate the video.
224 |
225 | - Critical steps (assuming a properly installed environment):
226 |
227 | - Input:
228 | ```Shell
229 | image folder/
230 | ├── original.png
231 | ```
232 | - Perception:
233 | ```Shell
234 | cd perception/
235 | python gpt_ram.py --img_path ${image folder}
236 | python run_gsam.py --input ${image folder}
237 | python run_depth_normal.py --input ${image folder} --vis
238 | python run_fg_bg.py --input ${image folder} --vis_edge
239 | python run_inpaint.py --input ${image folder} --dilate_kernel_size 20
240 | python run_albedo_shading.py --input ${image folder} --vis
241 | ```
242 |
243 | - After the perception step, you should get:
244 | ```Shell
245 | image folder/
246 | ├── original.png
247 | ├── mask.png # foreground segmentation mask
248 | ├── inpaint.png # background inpainting
249 | ├── normal.npy # normal map
250 | ├── shading.npy # shading map by intrinsic decomposition
251 | ├── edges.json # edges
252 | ├── physics.yaml # physics properties of foreground objects
253 | ```
254 |
255 | - Compose `${image folder}/sim.yaml` for simulation by specifying the initial conditions of each object (you can check the foreground object ids in `${image folder}/intermediate/fg_mask_vis.png`). Please see the example in `data/pig_ball/sim.yaml`: copy the content of `physics.yaml` into `sim.yaml`, together with the edge information from `edges.json`. A drafting sketch is shown after these steps.
256 |
257 | - Run simulation:
258 | ```Shell
259 | cd simulation/
260 | python animate.py --data_root ${image_folder} --save_root ${image_folder} --config ${image_folder}/sim.yaml
261 | ```
262 |
263 | - Run rendering:
264 | ```Shell
265 | cd relight/
266 | python relight.py --perception_input ${image_folder} --previous_output ${image_folder}
267 | cd ../diffusion/
268 | python video_diffusion.py --perception_input ${image_folder} --previous_output ${image_folder} --denoise_strength ${denoise_strength}
269 | ```
270 |
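Below is a hypothetical drafting helper, not part of the repo: it assumes `physics.yaml` holds the `obj_info` block and `edges.json` holds the edge list from the perception step, and it only pre-fills `sim.yaml`; gravity and the per-object initial velocities/accelerations still need to be tuned by hand following `data/pig_ball/sim.yaml`.

```python
# Hypothetical helper (assumptions: physics.yaml stores the obj_info block and
# edges.json stores the edge list; key names follow data/pig_ball/sim.yaml).
import json
from omegaconf import OmegaConf

image_folder = "custom_data/kitchen"  # example custom image folder

physics = OmegaConf.load(f"{image_folder}/physics.yaml")
with open(f"{image_folder}/edges.json") as f:
    edges = json.load(f)

sim = OmegaConf.create({
    "animation_frames": 16,
    "size": [512, 512],
    "num_steps": 200,
    "display": False,
    "save_snapshot": False,
    "gravity": 980,        # use 0 for top-down scenes such as pool/balls
    "init_velocity": {},   # fill in per object id, e.g. {1: [-100, 0]}
    "init_acc": {},
    "edge_list": edges,
    "obj_info": physics.get("obj_info", physics),
})
OmegaConf.save(sim, f"{image_folder}/sim.yaml")
```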
271 | - We put some custom images under the `custom_data` folder. You can play with each image by running the steps above and explore different physical simulations.
272 |
273 | | **Balls Shelf** | **Boxes** | **Kitchen** | **Table** | **Toy**
274 | | :---------------: | :-------: | :---------: | :-------: | :----------: |
275 | | *(images omitted; see `custom_data/`)* | | | | |
276 |
277 | ## Citation
278 |
279 | If you find our work useful in your research, please cite:
280 |
281 | ```BiBTeX
282 | @inproceedings{liu2024physgen,
283 | title={PhysGen: Rigid-Body Physics-Grounded Image-to-Video Generation},
284 | author={Liu, Shaowei and Ren, Zhongzheng and Gupta, Saurabh and Wang, Shenlong},
285 | booktitle={European Conference on Computer Vision ECCV},
286 | year={2024}
287 | }
288 | ```
289 |
290 |
291 | ## Acknowledgement
292 | * [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) for segmentation in [perception](#perception)
294 | * [GeoWizard](https://github.com/fuxiao0719/GeoWizard) for depth and normal estimation in [perception](#perception)
296 | * [Intrinsic](https://github.com/compphoto/Intrinsic/) for intrinsic image decomposition in [perception](#perception)
297 | * [Inpaint-Anything](https://github.com/geekyutao/Inpaint-Anything) for image inpainting in [perception](#perception)
298 | * [Pymunk](https://github.com/viblo/pymunk) for physics simulation in [simulation](#simulation)
299 | * [SEINE](https://github.com/Vchitect/SEINE/) for video diffusion in [rendering](#rendering)
300 |
--------------------------------------------------------------------------------
/assets/car_final.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/car_final.gif
--------------------------------------------------------------------------------
/assets/car_relight.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/car_relight.gif
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/demo.gif
--------------------------------------------------------------------------------
/assets/domino_composite.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/domino_composite.gif
--------------------------------------------------------------------------------
/assets/domino_sim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/domino_sim.gif
--------------------------------------------------------------------------------
/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/method.png
--------------------------------------------------------------------------------
/assets/pig_ball_composite.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/pig_ball_composite.gif
--------------------------------------------------------------------------------
/assets/pig_ball_relight.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/pig_ball_relight.gif
--------------------------------------------------------------------------------
/assets/pool_composite.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/pool_composite.gif
--------------------------------------------------------------------------------
/assets/pool_sim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/assets/pool_sim.gif
--------------------------------------------------------------------------------
/custom_data/balls_shelf/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/balls_shelf/original.png
--------------------------------------------------------------------------------
/custom_data/boxes/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/boxes/original.png
--------------------------------------------------------------------------------
/custom_data/kitchen/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/kitchen/original.png
--------------------------------------------------------------------------------
/custom_data/table/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/table/original.png
--------------------------------------------------------------------------------
/custom_data/wall_toy/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/custom_data/wall_toy/original.png
--------------------------------------------------------------------------------
/data/balls/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/depth.npy
--------------------------------------------------------------------------------
/data/balls/inpaint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/inpaint.png
--------------------------------------------------------------------------------
/data/balls/intermediate/albedo_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/albedo_vis.png
--------------------------------------------------------------------------------
/data/balls/intermediate/depth_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/depth_vis.png
--------------------------------------------------------------------------------
/data/balls/intermediate/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/mask.png
--------------------------------------------------------------------------------
/data/balls/intermediate/normal_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/normal_vis.png
--------------------------------------------------------------------------------
/data/balls/intermediate/obj_movable.json:
--------------------------------------------------------------------------------
1 | {"background": false, "puck": true}
--------------------------------------------------------------------------------
/data/balls/intermediate/shading_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/shading_vis.png
--------------------------------------------------------------------------------
/data/balls/intermediate/vis_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/intermediate/vis_mask.jpg
--------------------------------------------------------------------------------
/data/balls/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/mask.png
--------------------------------------------------------------------------------
/data/balls/normal.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/normal.npy
--------------------------------------------------------------------------------
/data/balls/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/original.png
--------------------------------------------------------------------------------
/data/balls/shading.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/shading.npy
--------------------------------------------------------------------------------
/data/balls/sim.yaml:
--------------------------------------------------------------------------------
1 | # general setting
2 | cat: balls
3 | # simulation setting
4 | animation_frames: 16
5 | size: [512, 512]
6 | num_steps: 200
7 | # enable for simulation visualization
8 | display: false
9 | save_snapshot: false
10 | gravity: 0
11 | init_velocity:
12 | 6: [40, -150]
13 | init_acc:
14 | # perception setting
15 | edge_list: [[[1, 1], [1, 511]], [[1, 511], [511, 511]], [[511, 511], [511, 1]], [
16 | [511, 1], [1, 1]]]
17 | obj_info:
18 | 1:
19 | label: puck
20 | mass: 1900
21 | density:
22 | friction: 0.1
23 | elasticity: 0.05
24 | primitive: circle
25 | 2:
26 | label: puck
27 | mass: 1900
28 | density:
29 | friction: 0.1
30 | elasticity: 0.05
31 | primitive: circle
32 | 3:
33 | label: puck
34 | mass: 1900
35 | density:
36 | friction: 0.1
37 | elasticity: 0.05
38 | primitive: circle
39 | 4:
40 | label: puck
41 | mass: 1900
42 | density:
43 | friction: 0.1
44 | elasticity: 0.05
45 | primitive: circle
46 | 5:
47 | label: puck
48 | mass: 1900
49 | density:
50 | friction: 0.1
51 | elasticity: 0.05
52 | primitive: circle
53 | 6:
54 | label: puck
55 | mass: 1900
56 | density:
57 | friction: 0.1
58 | elasticity: 0.05
59 | primitive: circle
60 | # diffusion setting (optional)
61 | denoise_strength: 0.5
62 |
--------------------------------------------------------------------------------
/data/balls/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/balls/vis.png
--------------------------------------------------------------------------------
/data/car/inpaint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/inpaint.png
--------------------------------------------------------------------------------
/data/car/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/mask.png
--------------------------------------------------------------------------------
/data/car/normal.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/normal.npy
--------------------------------------------------------------------------------
/data/car/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/original.png
--------------------------------------------------------------------------------
/data/car/shading.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/car/shading.npy
--------------------------------------------------------------------------------
/data/car/sim.yaml:
--------------------------------------------------------------------------------
1 | # general setting
2 | cat: car
3 | # simulation setting
4 | animation_frames: 16
5 | size: [512, 512]
6 | num_steps: 100
7 | # enable for simulation visualization
8 | display: false
9 | save_snapshot: false
10 | gravity: 980
11 | ground_friction: 0.9
12 | ground_elasticity: 1.0
13 | init_velocity:
14 | 1: [-400, 0]
15 | init_acc:
16 | # perception setting
17 | edge_list: [[[0, 303], [512, 303]]]
18 | obj_info:
19 | 1:
20 | label: toy car
21 | mass:
22 | density: 0.003
23 | friction: 0.2
24 | elasticity: 0.05
25 | primitive: polygon
26 | # diffusion setting (optional)
27 | denoise_strength: 0.65
28 |
--------------------------------------------------------------------------------
/data/domino/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/depth.npy
--------------------------------------------------------------------------------
/data/domino/inpaint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/inpaint.png
--------------------------------------------------------------------------------
/data/domino/intermediate/albedo_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/albedo_vis.png
--------------------------------------------------------------------------------
/data/domino/intermediate/depth_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/depth_vis.png
--------------------------------------------------------------------------------
/data/domino/intermediate/normal_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/normal_vis.png
--------------------------------------------------------------------------------
/data/domino/intermediate/obj_movable.json:
--------------------------------------------------------------------------------
1 | {"candle": true, "table": false, "background": false}
--------------------------------------------------------------------------------
/data/domino/intermediate/shading_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/shading_vis.png
--------------------------------------------------------------------------------
/data/domino/intermediate/vis_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/intermediate/vis_mask.jpg
--------------------------------------------------------------------------------
/data/domino/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/mask.png
--------------------------------------------------------------------------------
/data/domino/normal.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/normal.npy
--------------------------------------------------------------------------------
/data/domino/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/original.png
--------------------------------------------------------------------------------
/data/domino/shading.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/shading.npy
--------------------------------------------------------------------------------
/data/domino/sim.yaml:
--------------------------------------------------------------------------------
1 | # general setting
2 | cat: domino
3 | # simulation setting
4 | animation_frames: 16
5 | size: [512, 512]
6 | num_steps: 160
7 | # enable for simulation visualization
8 | display: false
9 | save_snapshot: false
10 | gravity: 980
11 | init_velocity:
12 | 6: [-80, 0]
13 | init_acc:
14 | # perception setting
15 | edge_list: [[[0, 457], [512, 457]]]
16 | obj_info:
17 | 1:
18 | label: white cuboid
19 | mass:
20 | density: 2.5
21 | friction: 0.5
22 | elasticity: 0.01
23 | primitive: polygon
24 | 2:
25 | label: white cuboid
26 | mass:
27 | density: 2.5
28 | friction: 0.5
29 | elasticity: 0.01
30 | primitive: polygon
31 | 3:
32 | label: white cuboid
33 | mass:
34 | density: 2.5
35 | friction: 0.5
36 | elasticity: 0.01
37 | primitive: polygon
38 | 4:
39 | label: white cuboid
40 | mass:
41 | density: 2.5
42 | friction: 0.5
43 | elasticity: 0.01
44 | primitive: polygon
45 | 5:
46 | label: white cuboid
47 | mass:
48 | density: 2.5
49 | friction: 0.5
50 | elasticity: 0.01
51 | primitive: polygon
52 | 6:
53 | label: white cuboid
54 | mass:
55 | density: 2.5
56 | friction: 0.5
57 | elasticity: 0.01
58 | primitive: polygon
59 | # diffusion setting (optional)
60 | denoise_strength: 0.45
61 |
--------------------------------------------------------------------------------
/data/domino/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/domino/vis.png
--------------------------------------------------------------------------------
/data/pig_ball/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/depth.npy
--------------------------------------------------------------------------------
/data/pig_ball/inpaint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/inpaint.png
--------------------------------------------------------------------------------
/data/pig_ball/intermediate/albedo_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/albedo_vis.png
--------------------------------------------------------------------------------
/data/pig_ball/intermediate/bg_mask_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/bg_mask_vis.png
--------------------------------------------------------------------------------
/data/pig_ball/intermediate/depth_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/depth_vis.png
--------------------------------------------------------------------------------
/data/pig_ball/intermediate/edge_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/edge_vis.png
--------------------------------------------------------------------------------
/data/pig_ball/intermediate/fg_mask_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/fg_mask_vis.png
--------------------------------------------------------------------------------
/data/pig_ball/intermediate/normal_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/normal_vis.png
--------------------------------------------------------------------------------
/data/pig_ball/intermediate/obj_movable.json:
--------------------------------------------------------------------------------
1 | {"ball": true, "shelf": false, "piggybank": true, "book": true, "wall": false, "floor": false, "baseboard": false}
--------------------------------------------------------------------------------
/data/pig_ball/intermediate/shading_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/shading_vis.png
--------------------------------------------------------------------------------
/data/pig_ball/intermediate/vis_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/intermediate/vis_mask.jpg
--------------------------------------------------------------------------------
/data/pig_ball/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/mask.png
--------------------------------------------------------------------------------
/data/pig_ball/normal.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/normal.npy
--------------------------------------------------------------------------------
/data/pig_ball/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/original.png
--------------------------------------------------------------------------------
/data/pig_ball/shading.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/shading.npy
--------------------------------------------------------------------------------
/data/pig_ball/sim.yaml:
--------------------------------------------------------------------------------
1 | # general setting
2 | cat: pig_ball
3 | # simulation setting
4 | animation_frames: 16
5 | size: [512, 512]
6 | num_steps: 220
7 | # enable for simulation visualization
8 | display: false
9 | save_snapshot: false
10 | gravity: 980
11 | init_velocity: {}
12 | init_acc:
13 | 3: [-100, 0]
14 | # perception setting
15 | edge_list: [[[409, 154], [512, 154]], [[182, 186], [512, 186]], [[0, 512], [512, 512]]]
16 | obj_info:
17 | 1:
18 | label: ball
19 | mass: 100
20 | density:
21 | friction: 0.7
22 | elasticity: 0.8
23 | primitive: circle
24 | 3:
25 | label: toy
26 | mass:
27 | density: 1.0
28 | friction: 0.5
29 | elasticity: 0.3
30 | primitive: polygon
31 | # diffusion setting (optional)
32 | denoise_strength: 0.3
33 |
--------------------------------------------------------------------------------
/data/pig_ball/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pig_ball/vis.png
--------------------------------------------------------------------------------
/data/pool/depth.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/depth.npy
--------------------------------------------------------------------------------
/data/pool/inpaint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/inpaint.png
--------------------------------------------------------------------------------
/data/pool/intermediate/albedo_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/albedo_vis.png
--------------------------------------------------------------------------------
/data/pool/intermediate/depth_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/depth_vis.png
--------------------------------------------------------------------------------
/data/pool/intermediate/normal_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/normal_vis.png
--------------------------------------------------------------------------------
/data/pool/intermediate/obj_movable.json:
--------------------------------------------------------------------------------
1 | {"billiardball": true, "table": false}
--------------------------------------------------------------------------------
/data/pool/intermediate/shading_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/shading_vis.png
--------------------------------------------------------------------------------
/data/pool/intermediate/vis_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/intermediate/vis_mask.jpg
--------------------------------------------------------------------------------
/data/pool/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/mask.png
--------------------------------------------------------------------------------
/data/pool/normal.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/normal.npy
--------------------------------------------------------------------------------
/data/pool/original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/original.png
--------------------------------------------------------------------------------
/data/pool/shading.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/shading.npy
--------------------------------------------------------------------------------
/data/pool/sim.yaml:
--------------------------------------------------------------------------------
1 | # general setting
2 | cat: pool
3 | # simulation setting
4 | animation_frames: 16
5 | size: [512, 512]
6 | num_steps: 100
7 | # enable for simulation visualization
8 | display: false
9 | save_snapshot: false
10 | gravity: 0
11 | init_velocity:
12 | 1: [-100, 0]
13 | init_acc:
14 | 1: [-5882.35, 0.0]
15 |
16 | # perception setting
17 | edge_list: [[[1, 1], [1, 511]], [[1, 511], [511, 511]], [[511, 511], [511, 1]], [
18 | [511, 1], [1, 1]]]
19 | obj_info:
20 | 1:
21 | label: billiard
22 | mass: 170
23 | density:
24 | friction: 0.3
25 | elasticity: 0.7
26 | primitive: circle
27 | 10:
28 | label: billiard
29 | mass: 170
30 | density:
31 | friction: 0.3
32 | elasticity: 0.7
33 | primitive: circle
34 | 11:
35 | label: billiard
36 | mass: 170
37 | density:
38 | friction: 0.3
39 | elasticity: 0.7
40 | primitive: circle
41 | 12:
42 | label: billiard
43 | mass: 170
44 | density:
45 | friction: 0.3
46 | elasticity: 0.7
47 | primitive: circle
48 | 13:
49 | label: billiard
50 | mass: 170
51 | density:
52 | friction: 0.3
53 | elasticity: 0.7
54 | primitive: circle
55 | 14:
56 | label: billiard
57 | mass: 170
58 | density:
59 | friction: 0.3
60 | elasticity: 0.7
61 | primitive: circle
62 | 15:
63 | label: billiard
64 | mass: 170
65 | density:
66 | friction: 0.3
67 | elasticity: 0.7
68 | primitive: circle
69 | 16:
70 | label: billiard
71 | mass: 170
72 | density:
73 | friction: 0.3
74 | elasticity: 0.7
75 | primitive: circle
76 | 2:
77 | label: billiard
78 | mass: 170
79 | density:
80 | friction: 0.3
81 | elasticity: 0.7
82 | primitive: circle
83 | 3:
84 | label: billiard
85 | mass: 170
86 | density:
87 | friction: 0.3
88 | elasticity: 0.7
89 | primitive: circle
90 | 4:
91 | label: billiard
92 | mass: 170
93 | density:
94 | friction: 0.3
95 | elasticity: 0.7
96 | primitive: circle
97 | 5:
98 | label: billiard
99 | mass: 170
100 | density:
101 | friction: 0.3
102 | elasticity: 0.7
103 | primitive: circle
104 | 6:
105 | label: billiard
106 | mass: 170
107 | density:
108 | friction: 0.3
109 | elasticity: 0.7
110 | primitive: circle
111 | 7:
112 | label: billiard
113 | mass: 170
114 | density:
115 | friction: 0.3
116 | elasticity: 0.7
117 | primitive: circle
118 | 8:
119 | label: billiard
120 | mass: 170
121 | density:
122 | friction: 0.3
123 | elasticity: 0.7
124 | primitive: circle
125 | 9:
126 | label: billiard
127 | mass: 170
128 | density:
129 | friction: 0.3
130 | elasticity: 0.7
131 | primitive: circle
132 | # diffusion setting (optional)
133 | denoise_strength: 0.2
134 |
--------------------------------------------------------------------------------
/data/pool/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/data/pool/vis.png
--------------------------------------------------------------------------------
/diffusion/video_diffusion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | from omegaconf import OmegaConf
5 | import argparse
6 | from PIL import Image
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision
10 | from torchvision import transforms
11 | torch.backends.cuda.matmul.allow_tf32 = True
12 | torch.backends.cudnn.allow_tf32 = True
13 | from einops import rearrange
14 | from tqdm.auto import tqdm
15 | from diffusers.models import AutoencoderKL
16 | from diffusers.utils.import_utils import is_xformers_available
17 |
18 |
19 | sys.path.insert(0, "SEINE")
20 | from datasets import video_transforms
21 | from diffusion import create_diffusion
22 | from models.clip import TextEmbedder
23 | from models import get_models
24 | from utils import mask_generation_before
25 |
26 |
27 | def prepare_input(relight_video_path, mask_video_path,
28 | image_h, image_w, latent_h, latent_w, device, use_fp16):
29 | relight_video = torch.load(relight_video_path).to(device) # (f, 3, H, W)
30 | if relight_video.max() > 1.:
31 | relight_video = relight_video / 255.0
32 | if relight_video.shape[1] !=3: # (f, H, W, 3) -> (f, 3, H, W)
33 | relight_video = relight_video.permute(0, 3, 1, 2) # (f, 3, H, W)
34 | if relight_video.shape[-2:] != (image_h, image_w):
35 | relight_video = F.interpolate(relight_video, size=(image_h, image_w))
36 | relight_video = 2 * relight_video - 1 # [-1, 1]
37 | if use_fp16:
38 | relight_video = relight_video.to(dtype=torch.float16)
39 |
40 | mask_video = torch.load(mask_video_path).to(device).float() # (f, 1, H, W)
41 | if mask_video.ndim == 3: # (f, H, W) -> (f, 4, H, W)
42 | mask_video = mask_video.unsqueeze(1).repeat(1, 4, 1, 1)
43 | elif mask_video.ndim == 4 and mask_video.shape[1] == 1: # (f, 1, H, W) -> (f, 4, H, W)
44 | mask_video = mask_video.repeat(1, 4, 1, 1)
45 | elif mask_video.ndim == 4 and mask_video.shape[0] == 1: # (1, f, H, W) -> (f, 4, H, W)
46 | mask_video = mask_video.repeat(4, 1, 1, 1).permute(1, 0, 2, 3) # (f, 4, H, W)
47 | if mask_video.shape[-2:] != (latent_h, latent_w):
48 | mask_video = F.interpolate(mask_video, size=(latent_h, latent_w)) # (f, 4, h, w)
49 | mask_video = mask_video > 0.5
50 | mask_video = rearrange(mask_video, 'f c h w -> c f h w').contiguous()
51 | mask_video = mask_video.unsqueeze(0) # (1, 4, f, h, w)
52 | return relight_video, mask_video
53 |
54 |
55 | def get_input(image_path, image_h, image_w, mask_type="first1", num_frames=16):
56 | transform_video = transforms.Compose([
57 | video_transforms.ToTensorVideo(),
58 | video_transforms.ResizeVideo((image_h, image_w)),
59 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)])
60 | print(f'loading video from {image_path}')
61 | _, full_file_name = os.path.split(image_path)
62 | file_name, extension = os.path.splitext(full_file_name)
63 | if extension == '.jpg' or extension == '.jpeg' or extension == '.png':
64 | print("loading the input image")
65 | video_frames = []
66 | num = int(mask_type.split('first')[-1])
67 | first_frame = torch.as_tensor(np.array(Image.open(image_path), dtype=np.uint8, copy=True)).unsqueeze(0)
68 | for i in range(num):
69 | video_frames.append(first_frame)
70 | num_zeros = num_frames-num
71 | for i in range(num_zeros):
72 | zeros = torch.zeros_like(first_frame)
73 | video_frames.append(zeros)
74 | n = 0
75 | video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
76 | video_frames = transform_video(video_frames)
77 | return video_frames, n
78 |
79 |
80 |
81 | if __name__ == "__main__":
82 | parser = argparse.ArgumentParser()
83 | # Model checkpoint
84 | parser.add_argument("--ckpt", default="SEINE/pretrained/seine.pt")
85 | parser.add_argument("--pretrained_model_path", default="SEINE/pretrained/stable-diffusion-v1-4")
86 | # Model config
87 | parser.add_argument('--model', type=str, default='UNet', help='Model architecture to use.')
88 | parser.add_argument('--num_frames', type=int, default=16, help='Number of frames to process.')
89 | parser.add_argument('--image_size', type=int, nargs=2, default=[512, 512], help='Resolution of the input images.')
90 |
91 | # Model speedup config
92 | parser.add_argument('--use_fp16', type=bool, default=True, help='Use FP16 for faster inference. Set to False if debugging with video loss.')
93 | parser.add_argument('--enable_xformers_memory_efficient_attention', type=bool, default=True, help='Enable xformers memory efficient attention.')
94 |
95 | # Sample config
96 | parser.add_argument('--seed', type=int, default=None, help='Random seed for reproducibility.')
97 | parser.add_argument('--run_time', type=int, default=13, help='Run time of the model.')
98 | parser.add_argument('--cfg_scale', type=float, default=8.0, help='Configuration scale factor.')
99 | parser.add_argument('--sample_method', type=str, default='ddim', choices=['ddim', 'ddpm'], help='Sampling method to use.')
100 | parser.add_argument('--do_classifier_free_guidance', type=bool, default=True, help='Enable classifier-free guidance.')
101 | parser.add_argument('--mask_type', type=str, default="first1", help='Type of mask to use.')
102 | parser.add_argument('--use_mask', type=bool, default=True, help='Whether to use a mask.')
103 | parser.add_argument('--num_sampling_steps', type=int, default=50, help='Number of sampling steps to perform.')
104 | parser.add_argument('--prompt', type=str, default=None, help='input prompt')
105 | parser.add_argument('--negative_prompt', type=str, default="", help='Negative prompt to use.')
106 |
107 | # Video diffusion config
108 | parser.add_argument('--denoise_strength', type=float, default=0.4, help='Denoise strength parameter.')
109 | parser.add_argument('--stop_idx', type=int, default=5, help='Stop index for the diffusion process.')
110 | parser.add_argument('--perception_input', type=str, default="../data/pool", help='input dir')
111 | parser.add_argument('--previous_output', type=str, default="../outputs/pool", help='previous output dir')
112 |
113 | # Parse the arguments
114 | args = parser.parse_args()
115 |
116 | output = args.previous_output
117 | os.makedirs(output, exist_ok=True)
118 |
119 | if args.seed:
120 | torch.manual_seed(args.seed)
121 | torch.set_grad_enabled(False)
122 | device = "cuda" if torch.cuda.is_available() else "cpu"
123 | if args.ckpt is None:
124 | raise ValueError("Please specify a checkpoint path using --ckpt ")
125 |
126 | latent_h = args.image_size[0] // 8
127 | latent_w = args.image_size[1] // 8
128 | image_h = args.image_size[0]
129 | image_w = args.image_size[1]
130 | latent_h = latent_h
131 | latent_w = latent_w
132 | print('loading model')
133 | model = get_models(args).to(device)
134 |
135 | if args.enable_xformers_memory_efficient_attention:
136 | if is_xformers_available():
137 | model.enable_xformers_memory_efficient_attention()
138 | else:
139 | raise ValueError("xformers is not available. Make sure it is installed correctly")
140 |
141 | ckpt_path = args.ckpt
142 | state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)['ema']
143 | model.load_state_dict(state_dict)
144 | print('loading succeed')
145 | model.eval()
146 |
147 | pretrained_model_path = args.pretrained_model_path
148 | diffusion = create_diffusion(str(args.num_sampling_steps))
149 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
150 | text_encoder = TextEmbedder(pretrained_model_path).to(device)
151 |
152 | sim_config_path = os.path.join(args.perception_input, "sim.yaml")
153 | config = OmegaConf.load(sim_config_path)
154 | objects = config.obj_info
155 | object_names = []
156 | for seg_id in objects:
157 | name = objects[seg_id]['label']
158 | object_names.append(name)
159 | object_names = list(set(object_names))
160 | if args.prompt is None:
161 | prompt = ", ".join(object_names)
162 | else:
163 | prompt = args.prompt
164 | print(f"input prompt: {prompt}")
165 | denoise_strength = getattr(config, 'denoise_strength', args.denoise_strength)
166 |
167 | if args.use_fp16:
168 | print('Warning: using half precision for inference!')
169 | vae.to(dtype=torch.float16)
170 | model.to(dtype=torch.float16)
171 | text_encoder.to(dtype=torch.float16)
172 |
173 |
174 | relight_video_path = os.path.join(args.previous_output, "relight.pt")
175 | mask_video_path = os.path.join(args.previous_output, "mask_video.pt")
176 | relight_video, mask_video = prepare_input(relight_video_path, mask_video_path, image_h, image_w, latent_h, latent_w, device, args.use_fp16)
177 |
178 | with torch.no_grad():
179 | ref_latent = vae.encode(relight_video).latent_dist.sample().mul_(0.18215) # (f, 4, h, w)
180 | ref_latent = ref_latent.permute(1, 0, 2, 3).contiguous().unsqueeze(0) # (1, 4, f, h, w)
181 |
182 | image_path = os.path.join(args.perception_input, "original.png")
183 | video, reserve_frames = get_input(image_path, image_h, image_w, args.mask_type, args.num_frames)
184 | video_input = video.unsqueeze(0).to(device) # b,f,c,h,w
185 | b,f,c,h,w=video_input.shape
186 | mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w [1, 16, 3, 512, 512])
187 | masked_video = video_input * (mask == 0)
188 |
189 | if args.use_fp16:
190 | masked_video = masked_video.to(dtype=torch.float16)
191 | mask = mask.to(dtype=torch.float16)
192 |
193 | masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
194 | masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
195 | masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
196 | mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
197 |
198 | if args.do_classifier_free_guidance:
199 | masked_video = torch.cat([masked_video] * 2)
200 | mask = torch.cat([mask] * 2)
201 | prompt_all = [prompt] + [args.negative_prompt]
202 | else:
203 | masked_video = masked_video
204 | mask = mask
205 | prompt_all = [prompt]
206 |
207 | text_prompt = text_encoder(text_prompts=prompt_all, train=False)
208 | model_kwargs = dict(encoder_hidden_states=text_prompt,
209 | class_labels=None,
210 | cfg_scale=args.cfg_scale,
211 | use_fp16=args.use_fp16) # tav unet
212 |
213 | indices = list(range(diffusion.num_timesteps))[::-1]
214 | noise_level = int(denoise_strength * diffusion.num_timesteps)
215 | indices = indices[-noise_level:]
216 |
217 | latent = diffusion.q_sample(ref_latent, torch.tensor([indices[0]], device=device))
218 | latent = torch.cat([latent] * 2)
219 | stop_idx = args.stop_idx
220 |
221 | for idx, i in tqdm(enumerate(indices)):
222 | t = torch.tensor([indices[idx]] * masked_video.shape[0], device=device)
223 | with torch.no_grad():
224 | out = diffusion.ddim_sample(
225 | model.forward_with_cfg,
226 | latent,
227 | t,
228 | clip_denoised=False,
229 | denoised_fn=None,
230 | cond_fn=None,
231 | model_kwargs=model_kwargs,
232 | eta=0.0,
233 | mask=mask,
234 | x_start=masked_video,
235 | use_concat=args.use_mask,
236 | )
237 |
238 | # update latent
239 | latent = out["sample"]
240 | if idx < len(indices)-stop_idx:
241 | x = diffusion.q_sample(ref_latent, torch.tensor([indices[idx+1]], device=device))
242 | pred_xstart = out["pred_xstart"]
243 |
244 | weight = min(idx / len(indices), 1.0)
245 | latent = (1 - mask_video.float()) * latent + mask_video.float() * ((1 - weight) * x + weight * latent)
246 |
247 | if args.use_fp16:
248 | latent = latent.to(dtype=torch.float16)
249 | latent = latent[0].permute(1, 0, 2, 3).contiguous() # (f, 4, h, w)
250 | video_output = vae.decode(latent / 0.18215).sample
251 | video_ = ((video_output * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
252 | save_video_path = os.path.join(output, f'final_video.mp4')
253 | torchvision.io.write_video(save_video_path, video_, fps=7)
254 | print("all done!")
--------------------------------------------------------------------------------
/perception/README.md:
--------------------------------------------------------------------------------
1 | # Perception
2 |
3 | ## Segmentation
4 |
5 | ### GPT-based Recognize Anything
6 | - We use GPT-4V to recognize objects in the image and use them as the input prompt for [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). Alternatively, you can use any other object recognition model (e.g. [RAM](https://github.com/xinyu1205/recognize-anything)) to get the objects in the given image.
7 |
8 | - We provide the pre-computed GPT-4V result for each example under `data/${name}/intermediate/obj_movable.json`. You can skip the steps below and run [segmentation](#grounded-segment-anything) directly if you don't want to re-run GPT-4V.
9 |
10 | - Copy the OpenAI API key into `gpt/gpt_configs/my_apikey`.
11 |
12 | - Install requirements
13 | ```bash
14 | pip install inflect openai==0.28
15 | ```
16 | - Run GPT-4V based RAM
17 | ```Shell
18 | python gpt_ram.py --img_path ../data/${name}
19 | ```
20 | - The default `save_path` is `../data/${name}/intermediate/obj_movable.json`, under the same folder as the input. The output is a JSON dictionary mapping each object name to whether it is movable.
21 |
22 | ```shell
23 | {
24 |     "obj_1": true,  # true if the object is movable, false if not
25 |     "obj_2": false,
26 |     ...
27 | }
28 | ```
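For reference, below is a minimal sketch (paths and object names are illustrative) of how the saved `obj_movable.json` can be loaded and turned into the period-separated text prompt that `run_gsam.py` feeds to Grounding DINO:

```python
import json

# Illustrative path; substitute your own ${name}.
with open("../data/pig_ball/intermediate/obj_movable.json", "r") as f:
    movable_dict = json.load(f)  # e.g. {"pig": true, "ball": true, "shelf": false}

# run_gsam.py joins the category names with ". " to build the Grounding DINO prompt.
text_prompt = ". ".join(movable_dict.keys())
print(text_prompt)   # e.g. "pig. ball. shelf"
print(movable_dict)  # the movability flags decide which segments become simulated objects
```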
29 |
30 | ### Grounded-Segment-Anything
31 | We use [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/753dd6675ea7935af401a983e88d159629ad4d5b) to segment the input image given the prompt input. We use this earlier checkout for the paper; you can adapt the code to the latest version of [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) or [Grounded-SAM-2](https://github.com/IDEA-Research/Grounded-SAM-2).
32 |
33 | - Follow the [Grounded-SAM setup](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/753dd6675ea7935af401a983e88d159629ad4d5b?tab=readme-ov-file#install-without-docker)
34 | ```Shell
35 | cd Grounded-Segment-Anything/
36 | git checkout 753dd6675ea7935af401a983e88d159629ad4d5b
37 |
38 | # Follow Grounded-SAM readme to install requirements
39 |
40 | # Download pretrained weights to current folder
41 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
42 | wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
43 |
44 | ```
45 | - Segmentation requires the input image `input` and a prompt file `prompts_path` listing each object in the image. The default prompt path is `../data/${name}/intermediate/obj_movable.json`.
46 |
47 | ```Shell
48 | python run_gsam.py --input ../data/${name}
49 | ```
50 | - The default `output` is saved under the same folder as input `../data/${name}` and visualizations under `../data/${name}/intermediate` as follows:
51 | ```Shell
52 | image folder/
53 | ├── intermediate/
54 | ├── mask.png # FG and BG segmentation mask
55 | ├── mask.json # segmentation id, object name, and movability
56 | ├── vis_mask.jpg # segmentation visualization
57 | ```
58 | | **Pool** | **Domino** | **Pig Ball** | **Balls** |
59 | |:---------:|:----------:|:------------:|:---------:|
60 | *(per-example input images and segmentation visualizations)*
61 |
62 |
63 |
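As a quick sanity check of the segmentation output, the following sketch (paths are illustrative) loads `mask.png` and `mask.json` and prints every segment:

```python
import json
import cv2
import numpy as np

# Illustrative paths; substitute your own ${name}.
seg_mask = cv2.imread("../data/pig_ball/intermediate/mask.png", 0)  # uint8; pixel value = segment id
with open("../data/pig_ball/intermediate/mask.json", "r") as f:
    seg_info = json.load(f)  # list of {"value": id, "label": name, "movable": bool}

for info in seg_info:
    area = int(np.sum(seg_mask == info["value"]))
    state = "movable" if info["movable"] else "static"
    print(f'segment {info["value"]}: {info["label"]} ({state}), {area} px')
```
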
64 | ## Depth and Normal Estimation
65 | - We use [GeoWizard](https://github.com/fuxiao0719/GeoWizard) to estimate the depth and normal of the input image. Follow the [GeoWizard setup](https://github.com/fuxiao0719/GeoWizard/blob/main/README.md#%EF%B8%8F-setup) to install requirements. We recommend creating a new conda environment.
66 |
67 | - Run GeoWizard on input image
68 | ```Shell
69 | python run_depth_normal.py --input ../data/${name} --output ../outputs/${name} --vis
70 | ```
71 | - Depth and normal are saved in `outputs/${name}`. Visualizations are saved in `outputs/${name}/intermediate`.
72 | ```Shell
73 | image folder/
74 | ├── depth.npy
75 | ├── normal.npy
76 | ├── intermediate/
77 | ├── depth_vis.png
78 | ├── normal_vis.png
79 | ```
80 |
81 | | **Input** | **Normal** | **Depth** |
82 | |:---------:|:----------:|:---------:|
83 | *(input image with predicted normal and depth visualizations)*
84 |
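For reference, a minimal sketch (paths are illustrative) of loading the saved arrays before passing them to the next stage:

```python
import numpy as np

# Illustrative paths; substitute your own ${name}.
depth = np.load("../outputs/pig_ball/depth.npy")    # (H, W) relative depth map
normal = np.load("../outputs/pig_ball/normal.npy")  # (H, W, 3) per-pixel surface normals

print("depth:", depth.shape, float(depth.min()), float(depth.max()))
print("normal:", normal.shape)
# run_fg_bg.py consumes both arrays to separate foreground/background and detect edges.
```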
85 |
86 | ## Foreground / Background & Edge Detection
87 | - We separate the foreground and background using the segmentation mask. Foreground objects with complete masks are used for physics reasoning and simulation, while truncated objects are treated as static. We use edges from static objects and the background as physical boundaries for simulation.
88 | - Module requires the following input for each image:
89 | ```Shell
90 | image folder/
91 | ├── depth.npy
92 | ├── normal.npy
93 | ├── original.png # optional: for visualization only
94 | ├── intermediate/
95 | ├── mask.png # complete image segmentation mask
96 | ├── mask.json # segmentation id, object name, and movability
97 | ```
98 | - Run foreground/background separation and edge detection
99 | ```Shell
100 | python run_fg_bg.py --input ../data/${name} --vis_edge
101 | ```
102 | - The default output is saved under the same folder as the input `../data/${name}`; it contains the final foreground object mask `mask.png` and the edge list `edges.json`.
103 | ```Shell
104 | image folder/
105 | ├── mask.png # final mask
106 | ├── edges.json
107 | ├── intermediate/
108 | ├── edge_vis.png # red line for edges
109 | ├── fg_mask_vis.png # text is the segmentation id
110 | ├── bg_mask_vis.png
111 | ├── bg_mask.png
112 | ```
113 |
114 | | **Input** | **Foreground** | **Background** | **Edges** |
115 | |:---------:|:--------------:|:--------------:|:---------:|
116 | *(input image with foreground mask, background mask, and edge visualizations)*
117 | - We can simulate any foreground object by specifying its velocity and acceleration via its segmentation id in the simulation config.
118 |
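A minimal sketch (paths are illustrative) of reloading `edges.json` and overlaying the detected boundaries on the input image, mirroring `draw_edge_on_image` in `run_fg_bg.py`:

```python
import json
import cv2

# Illustrative paths; substitute your own ${name}.
img = cv2.imread("../data/pig_ball/original.png")
with open("../data/pig_ball/edges.json", "r") as f:
    edges = json.load(f)  # list of polylines; each polyline is a list of [x, y] points

for edge in edges:
    for (x0, y0), (x1, y1) in zip(edge[:-1], edge[1:]):
        cv2.line(img, (int(x0), int(y0)), (int(x1), int(y1)), (0, 0, 255), 2)  # red in BGR
cv2.imwrite("edge_overlay.png", img)
```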
119 |
120 |
121 | ## Inpainting
122 | We use [Inpaint-Anything](https://github.com/geekyutao/Inpaint-Anything) to inpaint the background of the input image. You can adapt the code to any other recent inpainting model.
123 |
124 | - Follow the [Inpaint-Anything setup](https://github.com/geekyutao/Inpaint-Anything?tab=readme-ov-file#installation) to install requirements and download the [pretrained model](https://drive.google.com/drive/folders/1wpY-upCo4GIW4wVPnlMh_ym779lLIG2A). We recommend creating a new conda environment.
125 |
126 | ```Shell
127 | python -m pip install torch torchvision torchaudio
128 | python -m pip install -e segment_anything
129 | python -m pip install -r lama/requirements.txt
130 | # Download pretrained model under Inpaint-Anything/pretrained_models/
131 | ```
132 |
133 | - Inpainting requires the input image `../data/${name}/original.png` and the foreground mask `../data/${name}/mask.png` in the same folder.
134 | ```Shell
135 | python run_inpaint.py --input ../data/${name} --output ../outputs/${name} --dilate_kernel_size 20
136 | ```
137 | `dilate_kernel_size` can be adjusted in the above command. For images with heavy shadows, increase `dilate_kernel_size` to get better inpainting results; the sketch below shows how the mask is dilated.
138 |
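For intuition, a minimal sketch (path is illustrative) of the dilation applied to the foreground mask before LaMa inpainting, mirroring `dilate_mask` in `run_inpaint.py`:

```python
import cv2
import numpy as np

# Illustrative path; substitute your own ${name}.
mask = cv2.imread("../data/pig_ball/mask.png", 0)
mask = (mask > 0).astype(np.uint8)  # binarize: any foreground segment id -> 1

dilate_kernel_size = 20  # larger values also cover shadows cast around the objects
dilated = cv2.dilate(mask, np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8), iterations=1)
cv2.imwrite("dilated_mask_vis.png", dilated * 255)
```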
139 | - The output `inpaint.png` is saved in `outputs/${name}`.
140 |
141 | | **Input** | **Inpainting** |
142 | |:---------:|:--------------:|
143 | *(input image and inpainted background)*
144 |
145 |
146 | ## Physics Reasoning
147 |
148 | - Install requirements
149 | ```bash
150 | pip install openai==0.28 ruamel.yaml
151 | ```
152 | - Copy the OpenAI API key into `gpt/gpt_configs/my_apikey`.
153 |
154 | - Physics reasoning requires the following input for each image:
155 | ```Shell
156 | image folder/
157 | ├── original.png
158 | ├── mask.png # movable segmentation mask
159 | ```
160 | - Run GPT-4V physical property reasoning by the following command:
161 | ```Shell
162 | python gpt_physic.py --input ../data/${name} --output ../outputs/${name}
163 | ```
164 |
165 | - The output `physics.yaml` contains the physical properties and primitive shape of each object segment in the image. Note that GPT-4V outputs may vary across runs and may differ from the original settings in `data/${name}/sim.yaml`; users can adjust the simulation config accordingly for each run. An illustrative layout is sketched below.
166 |
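A sketch of the `physics.yaml` layout produced by `gpt_physic.py`; the segment ids and values below are illustrative and will differ for your run:

```yaml
# Illustrative only: keys follow gpt_physic.py, values vary per GPT-4V run.
obj_info:
  1:
    primitive: circle   # fitted primitive: circle or polygon
    mass: 500           # grams; set for circles
    density:            # set instead of mass for polygons (mass spread over the polygon area)
    friction: 0.5       # Coulomb friction; 0.0 is frictionless
    elasticity: 0.6     # 0.0 gives no bounce, 1.0 a perfect bounce
  2:
    primitive: polygon
    mass:
    density: 0.02
    friction: 0.4
    elasticity: 0.3
```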
167 |
168 | ## Albedo and Shading Estimation
169 | - We use [Intrinsic](https://github.com/compphoto/Intrinsic/tree/d9741e99b2997e679c4055e7e1f773498b791288) to infer the albedo and shading of the input image. Follow the [Intrinsic setup](https://github.com/compphoto/Intrinsic/tree/d9741e99b2997e679c4055e7e1f773498b791288?tab=readme-ov-file#setup) to install requirements. We recommend creating a new conda environment.
170 | ```
171 | git clone https://github.com/compphoto/Intrinsic
172 | cd Intrinsic/
173 | git checkout d9741e99b2997e679c4055e7e1f773498b791288
174 | pip install .
175 | ```
176 |
177 | - Run Intrinsic decomposition on input image
178 | ```Shell
179 | python run_albedo_shading.py --input ../data/${name} --output ../outputs/${name} --vis
180 | ```
181 |
182 | - `shading.npy` is saved in `outputs/${name}`. Visualizations of the albedo and shading are saved in `outputs/${name}/intermediate`.
183 |
184 | | **Input** | **Albedo** | **Shading** |
185 | |:---------:|:----------:|:-----------:|
186 | *(input image with estimated albedo and shading visualizations)*
187 |
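For reference, a minimal sketch (path is illustrative) for inspecting the saved shading, mirroring the `view` / `np_to_pil` helpers in `run_albedo_shading.py`:

```python
import numpy as np
from PIL import Image

# Illustrative path; substitute your own ${name}.
shd = np.load("../outputs/pig_ball/shading.npy")  # shading layer from the intrinsic decomposition

vis = shd ** (1 / 2.2)                            # gamma for display, as in view()
vis = (vis / np.percentile(vis, 100)).clip(0, 1)  # normalize by the maximum value
Image.fromarray((vis * 255).astype(np.uint8)).save("shading_check.png")
```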
188 | - [Intrinsic](https://github.com/compphoto/Intrinsic) has released an updated trained model with better results. Feel free to use the updated model or any other model for better performance.
189 |
--------------------------------------------------------------------------------
/perception/gpt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/perception/gpt/__init__.py
--------------------------------------------------------------------------------
/perception/gpt/gpt_configs/my_apikey:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stevenlsw/physgen/67fe0f2909bc332f0d8f1da1cd94c5fe245a9bfe/perception/gpt/gpt_configs/my_apikey
--------------------------------------------------------------------------------
/perception/gpt/gpt_configs/physics/user.txt:
--------------------------------------------------------------------------------
1 | You will be given an image and a binary mask specifying an object in the image. Analyze and provide your final answer about the object's physical properties. The query object will be enclosed in a white mask. The physical properties include the mass, the friction, and the elasticity. The mass is in grams. The friction uses the Coulomb friction model; a value of 0.0 is frictionless. The elasticity value of 0.0 gives no bounce, while a value of 1.0 will give a perfect bounce.
2 |
3 | Format Requirement:
4 | You must provide your answer in the following JSON format, as it will be parsed by a code script later. Your answer must look like:
5 | {
6 | "mass": number,
7 | "friction": number,
8 | "elasticity": number
9 | }
10 | The answer should be one exact number for each property. Do not include any other text in your answer, as it will be parsed by a code script later.
--------------------------------------------------------------------------------
/perception/gpt/gpt_configs/ram/user.txt:
--------------------------------------------------------------------------------
1 | Describe all unique object categories in the given image, ensuring all pixels are included and assigned to one of the categories. Do not miss any movable or static object that appears in the image. Each category name is a single word in singular noun format; do not include '-' in the name. Different categories should not be repeated or overlap with each other in the image. For each category, judge if the instances in the image are movable; the answer is True or False. If there are multiple instances of the same category in the image, the judgement is True only if the object category satisfies the following requirements: 1. The object category is things (objects with a well-defined shape, e.g. car, person) and not stuff (amorphous background regions, e.g. grass, sky, largest segmentation component). 2. All instances of this category in the image are movable, with complete shape and fully visible.
2 |
3 | Format Requirement:
4 | You must provide your answer in the following JSON format, as it will be parsed by a code script later. Your answer must look like:
5 | {
6 | "category-1": False,
7 | "category-2": True
8 |
9 | }
10 | Do not include any other text in your answer. Do not include unnecessary words besides the category name and True/False values.
--------------------------------------------------------------------------------
/perception/gpt/gpt_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from re import DOTALL, finditer
4 | import base64
5 | from io import BytesIO
6 | from PIL import Image
7 | from imantics import Mask
8 |
9 |
10 | def fit_polygon_from_mask(mask):
11 | # return polygon vertices
12 | polygons = Mask(mask).polygons()
13 | if len(polygons.points) > 1: # find the largest polygon
14 | areas = []
15 | for points in polygons.points:
16 | points = points.reshape((-1, 1, 2)).astype(np.int32)
17 | img = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
18 | img = cv2.fillPoly(img, [points], color=[0, 255, 0])
19 | mask = img[:, :, 1] > 0
20 | area = np.count_nonzero(mask)
21 | areas.append(area)
22 | areas = np.array(areas)
23 | largest_idx = np.argmax(areas)
24 | points = polygons.points[largest_idx]
25 | else:
26 | points = polygons.points[0]
27 | return points
28 |
29 |
30 | # same function in simulation/sim_utils.py
31 | def fit_circle_from_mask(mask_image):
32 | contours, _ = cv2.findContours(mask_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
33 |
34 | if len(contours) == 0:
35 | print("No contours found in the mask image.")
36 | return None
37 | max_contour = max(contours, key=cv2.contourArea)
38 | (x, y), radius = cv2.minEnclosingCircle(max_contour)
39 |
40 | center = (x, y)
41 | return center, radius
42 |
43 |
44 | def is_mask_truncated(mask):
45 | if np.any(mask[0, :] == 1) or np.any(mask[-1, :] == 1): # Top or bottom rows
46 | return True
47 | if np.any(mask[:, 0] == 1) or np.any(mask[:, -1] == 1): # Left or right columns
48 | return True
49 | return False
50 |
51 |
52 | def compute_iou(mask1, mask2):
53 | intersection = np.logical_and(mask1, mask2).sum()
54 | union = np.logical_or(mask1, mask2).sum()
55 | iou = intersection / union if union > 0 else 0.0
56 |
57 | return iou
58 |
59 |
60 | def find_json_response(full_response):
61 | extracted_responses = list(
62 | finditer(r"({[^}]*$|{.*})", full_response, flags=DOTALL)
63 | )
64 |
65 | if not extracted_responses:
66 | print(
67 | f"Unable to find any responses of the matching type dictionary: `{full_response}`"
68 | )
69 | return None
70 |
71 | if len(extracted_responses) > 1:
72 | print("Unexpected response > 1, continuing anyway...", extracted_responses)
73 |
74 | extracted_response = extracted_responses[0]
75 | extracted_str = extracted_response.group(0)
76 | return extracted_str
77 |
78 |
79 | def encode_image(image_path):
80 | with open(image_path, "rb") as image_file:
81 | return base64.b64encode(image_file.read()).decode("utf-8")
--------------------------------------------------------------------------------
/perception/gpt_physic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ast
3 | from time import sleep
4 | import cv2
5 | import numpy as np
6 | import pymunk
7 | import openai
8 | from ruamel.yaml import YAML
9 | from ruamel.yaml.comments import CommentedMap
10 |
11 | from gpt.gpt_utils import find_json_response, encode_image, fit_polygon_from_mask, fit_circle_from_mask, compute_iou, is_mask_truncated
12 |
13 |
14 | class GPTV_physics:
15 | def __init__(self, query_prompt="gpt/gpt_configs/physics/user.txt", retry_limit=0):
16 | with open(query_prompt, "r") as file:
17 | self.query = file.read().strip()
18 | self.retry_limit = retry_limit
19 |
20 | def call(self, image_path, mask, max_tokens=300, tmp_dir="./"):
21 | mask_path = os.path.join(tmp_dir, "tmp_msk.png")
22 | mask = (mask * 255).astype(np.uint8)
23 | cv2.imwrite(mask_path, mask)
24 | query_image = encode_image(image_path)
25 | mask_image = encode_image(mask_path)
26 |
27 | try_count = 0
28 | while True:
29 | response = openai.ChatCompletion.create(
30 | model="gpt-4-vision-preview",
31 | messages=[{
32 | "role": "system",
33 | "content": self.query
34 | },
35 | {
36 | "role": "user",
37 | "content": [
38 | {
39 | "type": "image_url",
40 | "image_url": {
41 | "url": f"data:image/{image_path[-3:]};base64,{query_image}"
42 | }
43 | },
44 | {
45 | "type": "image_url",
46 | "image_url": {
47 | "url": f"data:image/{mask_path[-3:]};base64,{mask_image}"
48 | }
49 | },
50 | ]
51 | }
52 | ],
53 | seed=100,
54 | max_tokens=max_tokens,
55 | )
56 | response = response["choices"][0]["message"]["content"]
57 | try:
58 | result = find_json_response(response)
59 | result = ast.literal_eval(result.replace(' ', '').replace('\n', ''))
60 | break
61 | except:
62 | print(f"Unknown response: {response}")
63 | try_count += 1
64 | if try_count > self.retry_limit:
65 | raise ValueError(f"Over Limit: Unknown response: {response}")
66 | else:
67 | print("Retrying after 1s.")
68 | sleep(1)
69 | os.remove(mask_path)
70 | return result
71 |
72 |
73 | if __name__ == "__main__":
74 | import argparse
75 | parser = argparse.ArgumentParser()
76 | parser.add_argument("--input", type=str, default="../data/pig_ball")
77 | parser.add_argument("--output", type=str, default="../outputs/pig_ball")
78 | parser.add_argument("--apikey_path", type=str, default="gpt/gpt_configs/my_apikey")
79 | args = parser.parse_args()
80 |
81 | with open(args.apikey_path, "r") as file:
82 | apikey = file.read().strip()
83 |
84 | openai.api_key = apikey
85 | gpt = GPTV_physics(query_prompt="gpt/gpt_configs/physics/user.txt")
86 |
87 | image_path = os.path.join(args.input, "original.png")
88 | mask_path = os.path.join(args.input, "mask.png")
89 | save_dir = args.output
90 | os.makedirs(save_dir, exist_ok=True)
91 | save_path = os.path.join(save_dir, "physics.yaml")
92 |
93 | seg_mask = cv2.imread(mask_path, 0)
94 | seg_ids = np.unique(seg_mask)
95 | obj_info_list = {}
96 | for seg_id in seg_ids:
97 | if seg_id == 0:
98 | continue
99 | obj_info = {}
100 | mask = (seg_mask == seg_id)
101 | # fit primitive
102 | center, radius = fit_circle_from_mask((mask * 255).astype(np.uint8))
103 | center = tuple(map(int, center))
104 | radius = int(radius)
105 | pred_mask = np.zeros(mask.shape, dtype=np.uint8)
106 | cv2.circle(pred_mask, center, radius, (255), thickness=-1)
107 | area = np.count_nonzero(mask)
108 | pred_mask = pred_mask > 0
109 | iou = compute_iou(mask, pred_mask)
110 | if iou > 0.85:
111 | obj_info["primitive"] = 'circle'
112 | else:
113 | if is_mask_truncated(mask):
114 | continue
115 | obj_info["primitive"] = 'polygon'
116 | points = fit_polygon_from_mask(mask)
117 | points = tuple(map(tuple, points))
118 | polygon = pymunk.Poly(None, points)
119 | area = polygon.area
120 |
121 | result = gpt.call(image_path, mask)
122 | for key in result:
123 | if key == "mass":
124 | if obj_info['primitive'] == 'polygon':
125 | density = result["mass"] / area
126 | obj_info["mass"] = None
127 | obj_info["density"] = density
128 | else:
129 | obj_info["mass"] = result["mass"]
130 | obj_info["density"] = None
131 | else:
132 | obj_info[key] = result[key]
133 |
134 | obj_info_list[int(seg_id)] = obj_info
135 |
136 | yaml = YAML()
137 | yaml_data = CommentedMap()
138 |
139 | yaml_data['obj_info'] = obj_info_list
140 | yaml_data.yaml_set_comment_before_after_key('obj_info', before="physics properties of each object")
141 |
142 | with open(save_path, 'w') as yaml_file:
143 | yaml.dump(yaml_data, yaml_file)
144 |
145 | print(f"GPT-4V Physics reasoning results saved to {save_path}")
--------------------------------------------------------------------------------
/perception/gpt_ram.py:
--------------------------------------------------------------------------------
1 | import os
2 | import openai
3 | import ast
4 | from time import sleep
5 | import json
6 | from gpt.gpt_utils import find_json_response, encode_image
7 |
8 |
9 | class GPTV_ram:
10 |     def __init__(self, query_prompt="gpt/gpt_configs/ram/user.txt", retry_limit=3):
11 |
12 | with open(query_prompt, "r") as file:
13 | self.query = file.read().strip()
14 | self.retry_limit = retry_limit
15 |
16 |
17 | def call(self, image_path, max_tokens=300):
18 | query_image = encode_image(image_path)
19 | try_count = 0
20 | while True:
21 | response = openai.ChatCompletion.create(
22 | model="gpt-4-vision-preview",
23 | messages=[{
24 | "role": "system",
25 | "content": self.query
26 | },
27 | {
28 | "role": "user",
29 | "content": [
30 | {
31 | "type": "text",
32 | "text": f"{self.query}"
33 | },
34 | {
35 | "type": "image_url",
36 | "image_url": {
37 | "url": f"data:image/{image_path[-3:]};base64,{query_image}"
38 | }
39 | },
40 | ]
41 | }
42 | ],
43 | seed=100,
44 | max_tokens=max_tokens,
45 | )
46 | response = response["choices"][0]["message"]["content"]
47 | try:
48 | result = find_json_response(response)
49 | result = ast.literal_eval(result.replace(' ', '').replace('\n', ''))
50 | break
51 | except:
52 | try_count += 1
53 | if try_count > self.retry_limit:
54 | raise ValueError(f"Over Limit: Unknown response: {response}")
55 | else:
56 | print("Retrying after 1s.")
57 | sleep(1)
58 | return result
59 |
60 | if __name__ == "__main__":
61 | import argparse
62 | import inflect
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument("--img_path", type=str, default="../data/domino/original.png")
65 | parser.add_argument("--save_path", type=str, default=None)
66 | parser.add_argument("--apikey_path", type=str, default="gpt/gpt_configs/my_apikey")
67 | args = parser.parse_args()
68 |
69 | with open(args.apikey_path, "r") as file:
70 | apikey = file.read().strip()
71 |
72 | openai.api_key = apikey
73 | gpt = GPTV_ram(query_prompt="gpt/gpt_configs/ram/user.txt")
74 | result = gpt.call(args.img_path)
75 |
76 | if args.save_path is None:
77 | save_dir = os.path.join(os.path.dirname(args.img_path), "intermediate")
78 | os.makedirs(save_dir, exist_ok=True)
79 | save_path = os.path.join(save_dir, "obj_movable.json")
80 | else:
81 | save_path = args.save_path
82 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
83 |
84 | p = inflect.engine()
85 |     for obj_name in list(result):  # iterate over a copy of the keys since entries are replaced in the loop
86 | singular = p.singular_noun(obj_name)
87 | if singular:
88 | result[singular] = result.pop(obj_name)
89 | else:
90 | continue
91 |
92 | with open(save_path, "w") as file:
93 | json.dump(result, file)
94 |
95 | print("result:", result)
96 | print(f"GPT4V image movable objects results saved to {save_path}")
97 |
--------------------------------------------------------------------------------
/perception/run_albedo_shading.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import glob
4 | import numpy as np
5 | from tqdm.auto import tqdm
6 | from PIL import Image
7 |
8 |
9 | from intrinsic.pipeline import run_pipeline
10 | from intrinsic.model_util import load_models
11 |
12 |
13 | def load_image(path, bits=8):
14 | np_arr = np.array(Image.open(path)).astype(np.single)
15 | return np_arr / float((2 ** bits) - 1)
16 |
17 |
18 | def np_to_pil(img, bits=8):
19 | if bits == 8:
20 | int_img = (img * 255).astype(np.uint8)
21 | if bits == 16:
22 | int_img = (img * ((2 ** 16) - 1)).astype(np.uint16)
23 |
24 | return Image.fromarray(int_img)
25 |
26 |
27 | def view_scale(img, p=100):
28 | return (img / np.percentile(img, p)).clip(0, 1)
29 |
30 |
31 | def view(img, p=100):
32 | return view_scale(img ** (1/2.2), p=p)
33 |
34 |
35 | def uninvert(x, eps=0.001, clip=True):
36 | if clip:
37 | x = x.clip(eps, 1.0)
38 |
39 | out = (1.0 / x) - 1.0
40 | return out
41 |
42 |
43 | if __name__ == '__main__':
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument("--input", type=str, default="../data/pig_ball", help="Input image or directory.")
46 | parser.add_argument("--output", type=str, default="../outputs/pig_ball", help="Output directory.")
47 | parser.add_argument("--vis", action="store_true", help="Visualize the results.")
48 |
49 | args = parser.parse_args()
50 |
51 | if os.path.isfile(args.input):
52 | image_path = args.input
53 | else:
54 | image_path = os.path.join(args.input, "original.png")
55 | output = args.output
56 | if output is None:
57 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input)
58 | os.makedirs(output, exist_ok=True)
59 |
60 | model = load_models('paper_weights')
61 |
62 | # Read input image
63 | img = load_image(image_path, bits=8)
64 | # run the model on the image using R_0 resizing
65 |
66 | results = run_pipeline(model, img, resize_conf=0.0, maintain_size=True)
67 |
68 | albedo = results['albedo']
69 | inv_shd = results['inv_shading']
70 |
71 | shd = uninvert(inv_shd)
72 | shd_save_path = os.path.join(output, "shading.npy")
73 | np.save(shd_save_path, shd)
74 |
75 | if args.vis:
76 | intermediate_dir = os.path.join(output, "intermediate")
77 | os.makedirs(intermediate_dir, exist_ok=True)
78 |
79 | alb_save_path = os.path.join(intermediate_dir, "albedo.npy")
80 | np.save(alb_save_path, albedo)
81 | np_to_pil(albedo).save(os.path.join(intermediate_dir, 'albedo_vis.png'))
82 | np_to_pil(view(shd)).save(os.path.join(intermediate_dir, 'shading_vis.png'))
83 |
84 |
85 |
--------------------------------------------------------------------------------
/perception/run_depth_normal.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import logging
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from tqdm.auto import tqdm
9 |
10 |
11 | BASE_DIR = "GeoWizard/geowizard"
12 | sys.path.append(os.path.join(BASE_DIR))
13 |
14 | from models.geowizard_pipeline import DepthNormalEstimationPipeline
15 | from utils.seed_all import seed_all
16 |
17 |
18 | if __name__=="__main__":
19 |
20 | logging.basicConfig(level=logging.INFO)
21 |
22 | parser = argparse.ArgumentParser(
23 | description="Run MonoDepthNormal Estimation using Stable Diffusion."
24 | )
25 | parser.add_argument(
26 | "--pretrained_model_path",
27 | type=str,
28 | default='lemonaddie/geowizard',
29 | help="pretrained model path from hugging face or local dir",
30 | )
31 |
32 | parser.add_argument(
33 | "--domain",
34 | type=str,
35 | default='indoor',
36 | help="domain prediction",
37 | )
38 |
39 | # inference setting
40 | parser.add_argument(
41 | "--denoise_steps",
42 | type=int,
43 | default=10,
44 | help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed.",
45 | )
46 | parser.add_argument(
47 | "--ensemble_size",
48 | type=int,
49 | default=10,
50 | help="Number of predictions to be ensembled, more inference gives better results but runs slower.",
51 | )
52 | parser.add_argument(
53 | "--half_precision",
54 | action="store_true",
55 | help="Run with half-precision (16-bit float), might lead to suboptimal result.",
56 | )
57 |
58 | # resolution setting
59 | parser.add_argument(
60 | "--processing_res",
61 | type=int,
62 | default=768,
63 | help="Maximum resolution of processing. 0 for using input image resolution. Default: 768.",
64 | )
65 | parser.add_argument(
66 | "--output_processing_res",
67 | action="store_true",
68 |         help="When input is resized, output depth at resized operating resolution. Default: False.",
69 | )
70 |
71 | # depth map colormap
72 | parser.add_argument(
73 | "--color_map",
74 | type=str,
75 | default="Spectral",
76 | help="Colormap used to render depth predictions.",
77 | )
78 | # other settings
79 | parser.add_argument("--seed", type=int, default=None, help="Random seed.")
80 | parser.add_argument(
81 | "--batch_size",
82 | type=int,
83 | default=0,
84 | help="Inference batch size. Default: 0 (will be set automatically).",
85 | )
86 |
87 | # custom settings
88 | parser.add_argument("--input", type=str, default="../data/pig_ball", help="Input image or directory.")
89 | parser.add_argument("--output", type=str, default="../outputs/pig_ball", help="Output directory")
90 | parser.add_argument("--vis", action="store_true", help="Visualize the output.")
91 |
92 | args = parser.parse_args()
93 |
94 | checkpoint_path = args.pretrained_model_path
95 | denoise_steps = args.denoise_steps
96 | ensemble_size = args.ensemble_size
97 |
98 | if ensemble_size>15:
99 | logging.warning("long ensemble steps, low speed..")
100 |
101 | half_precision = args.half_precision
102 |
103 | processing_res = args.processing_res
104 | match_input_res = not args.output_processing_res
105 | domain = args.domain
106 |
107 | color_map = args.color_map
108 | seed = args.seed
109 | batch_size = args.batch_size
110 |
111 | if batch_size==0:
112 | batch_size = 1 # set default batchsize
113 |
114 | # -------------------- Preparation --------------------
115 | # Random seed
116 | if seed is None:
117 | import time
118 | seed = int(time.time())
119 | seed_all(seed)
120 |
121 | # -------------------- Device --------------------
122 | if torch.cuda.is_available():
123 | device = torch.device("cuda")
124 | else:
125 | device = torch.device("cpu")
126 | logging.warning("CUDA is not available. Running on CPU will be slow.")
127 | logging.info(f"device = {device}")
128 |
129 | # -------------------- Data --------------------
130 | if os.path.isfile(args.input):
131 | image_path = args.input
132 | else:
133 | image_path = os.path.join(args.input, "original.png")
134 | output = args.output
135 | if output is None:
136 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input)
137 | os.makedirs(output, exist_ok=True)
138 |
139 | # -------------------- Model --------------------
140 | if half_precision:
141 | dtype = torch.float16
142 | logging.info(f"Running with half precision ({dtype}).")
143 | else:
144 | dtype = torch.float32
145 |
146 | # declare a pipeline
147 | pipe = DepthNormalEstimationPipeline.from_pretrained(checkpoint_path, torch_dtype=dtype)
148 | logging.info("loading pipeline whole successfully.")
149 |
150 | try:
151 | pipe.enable_xformers_memory_efficient_attention()
152 | except:
153 | pass # run without xformers
154 |
155 | pipe = pipe.to(device)
156 |
157 | # -------------------- Inference and saving --------------------
158 | with torch.no_grad():
159 |
160 | # Read input image
161 | input_image = Image.open(image_path)
162 |
163 | # predict the depth here
164 | pipe_out = pipe(input_image,
165 | denoising_steps = denoise_steps,
166 | ensemble_size= ensemble_size,
167 | processing_res = processing_res,
168 | match_input_res = match_input_res,
169 | domain = domain,
170 | color_map = color_map,
171 | show_progress_bar = True,
172 | )
173 |
174 | depth_pred: np.ndarray = pipe_out.depth_np
175 | depth_colored: Image.Image = pipe_out.depth_colored
176 | normal_pred: np.ndarray = pipe_out.normal_np
177 | normal_colored: Image.Image = pipe_out.normal_colored
178 |
179 | # Save as npy
180 | depth_npy_save_path = os.path.join(output, f"depth.npy")
181 | if os.path.exists(depth_npy_save_path):
182 | logging.warning(f"Existing file: '{depth_npy_save_path}' will be overwritten")
183 | np.save(depth_npy_save_path, depth_pred)
184 |
185 | normal_npy_save_path = os.path.join(output, f"normal.npy")
186 | if os.path.exists(normal_npy_save_path):
187 | logging.warning(f"Existing file: '{normal_npy_save_path}' will be overwritten")
188 | np.save(normal_npy_save_path, normal_pred)
189 |
190 | # Colorize
191 | if args.vis:
192 | intermediate_dir = os.path.join(output, "intermediate")
193 | os.makedirs(output, exist_ok=True)
194 | os.makedirs(intermediate_dir, exist_ok=True)
195 |
196 | depth_colored_save_path = os.path.join(intermediate_dir, f"depth_vis.png")
197 | if os.path.exists(depth_colored_save_path):
198 | logging.warning(
199 | f"Existing file: '{depth_colored_save_path}' will be overwritten"
200 | )
201 | depth_colored.save(depth_colored_save_path)
202 |
203 | normal_colored_save_path = os.path.join(intermediate_dir, f"normal_vis.png")
204 | if os.path.exists(normal_colored_save_path):
205 | logging.warning(
206 | f"Existing file: '{normal_colored_save_path}' will be overwritten"
207 | )
208 | normal_colored.save(normal_colored_save_path)
209 | print("Done.")
--------------------------------------------------------------------------------
/perception/run_fg_bg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import os
5 | import json
6 | import matplotlib
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | def vis_seg(mask_img, text=False):
11 | norm = matplotlib.colors.Normalize(vmin=np.min(mask_img), vmax=np.max(mask_img))
12 | norm_segmentation_map = norm(mask_img)
13 | cmap = "tab20"
14 | colormap = plt.get_cmap(cmap)
15 | colored_segmentation_map = colormap(norm_segmentation_map)
16 | colored_segmentation_map = (colored_segmentation_map[:, :, :3] * 255).astype(np.uint8)
17 | colored_segmentation_map[mask_img == 0] = [255, 255, 255]
18 | if text:
19 | unique_seg_ids = np.unique(mask_img)
20 | unique_seg_ids = unique_seg_ids[unique_seg_ids != 0]
21 | for seg_id in unique_seg_ids:
22 | mask_indices = np.where(mask_img == seg_id)
23 | if len(mask_indices[0]) > 0:
24 | center_y = int(np.mean(mask_indices[0]))
25 | center_x = int(np.mean(mask_indices[1]))
26 | cv2.putText(
27 | colored_segmentation_map,
28 | str(seg_id),
29 | (center_x, center_y),
30 | cv2.FONT_HERSHEY_SIMPLEX,
31 | 0.5,
32 | (0, 0, 0),
33 | 2,
34 | cv2.LINE_AA
35 | )
36 | return colored_segmentation_map
37 |
38 |
39 | def refine_segmentation_mask(mask, seg_mask, value, min_size):
40 | """
41 | Refine a segmentation mask by removing small connected components.
42 | """
43 | num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask.astype(np.uint8), connectivity=8)
44 |
45 | # Filter out small connected components
46 | for label in range(1, num_labels):
47 | if stats[label, cv2.CC_STAT_AREA] >= min_size:
48 | seg_mask[labels == label] = value
49 | value = value + 1
50 |
51 | return seg_mask, value
52 |
53 |
54 | def find_neighbor_segmentation_classes(segmentation_map, class_id):
55 | """
56 | Find the neighboring segmentation classes of a specified class in a segmentation map.
57 | """
58 | class_mask = np.uint8(segmentation_map == class_id)
59 | dilated_mask = cv2.dilate(class_mask, np.ones((3, 3), np.uint8), iterations=3)
60 | neighbor_class_ids = np.unique(segmentation_map[dilated_mask > 0])
61 | neighbor_class_ids = neighbor_class_ids[neighbor_class_ids != class_id]
62 | return list(neighbor_class_ids)
63 |
64 |
65 |
66 | def compute_angle_to_direction(normal, direction):
67 | """
68 | Compute the angle (in degrees) between a surface normal and a direction vector
69 | """
70 | dot_product = np.dot(normal, direction)
71 | norm_normal = np.linalg.norm(normal, axis=-1)
72 | norm_direction = np.linalg.norm(direction)
73 | cos_theta = dot_product / (norm_normal * norm_direction)
74 | return np.abs(cos_theta)
75 |
76 |
77 | def normal_planes(normals):
78 | '''
79 | normal coordinate:
80 | y
81 | |
82 | o -- x
83 | /
84 | z
85 | '''
86 | upward_direction = np.array([0, -1, 0])
87 | downward_direction = np.array([0, 1, 0])
88 | leftward_direction = np.array([-1, 0, 0])
89 | rightward_direction = np.array([1, 0, 0])
90 | forward_direction = np.array([0, 0, 1])
91 | inward_direction = np.array([0, 0, -1])
92 |
93 | # Compute the angles between each pixel's normal and the 6 directions
94 | angle_leftward = compute_angle_to_direction(normals, leftward_direction)
95 | angle_rightward = compute_angle_to_direction(normals, rightward_direction)
96 | angle_X = np.maximum(angle_leftward, angle_rightward).mean()
97 |
98 | angle_upward = compute_angle_to_direction(normals, upward_direction)
99 | angle_downward = compute_angle_to_direction(normals, downward_direction)
100 | angle_Y = np.maximum(angle_upward, angle_downward).mean()
101 |
102 | angle_forward = compute_angle_to_direction(normals, forward_direction)
103 | angle_inward = compute_angle_to_direction(normals, inward_direction)
104 | angle_Z = np.maximum(angle_forward, angle_inward).mean()
105 |
106 | angles = np.array([angle_X, angle_Y, angle_Z])
107 | index = np.argmax(angles)
108 | if index == 0:
109 | return "X"
110 | elif index == 1:
111 | return "Y"
112 | else:
113 | return "Z"
114 |
115 |
116 | def find_upper_contour(binary_mask):
117 | binary_mask = binary_mask.astype(np.uint8)
118 | contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
119 | contour = contours[0]
120 | leftmost = tuple(contour[contour[:,:,0].argmin()][0])
121 | rightmost = tuple(contour[contour[:,:,0].argmax()][0])
122 |
123 | start_index = np.where((contour[:, 0] == leftmost).all(axis=1))[0][0]
124 | end_index = np.where((contour[:, 0] == rightmost).all(axis=1))[0][0]
125 |
126 | if start_index > end_index:
127 | start_index, end_index = end_index, start_index
128 |
129 | subarray1 = contour[start_index:end_index+1]
130 | subarray2 = np.concatenate((contour[end_index:], contour[:start_index+1]))
131 | if subarray1[:, 0, 1].mean() < subarray2[:, 0, 1].mean():
132 | upper_contour = subarray1
133 | else:
134 | upper_contour = subarray2
135 |
136 | start_x, end_x = upper_contour[:, 0, 0].min(), upper_contour[:, 0, 0].max()
137 | start_x, end_x = int(start_x), int(end_x)
138 | y = upper_contour[:, 0, 1].mean()
139 | y = int(y)
140 | edge = [(start_x, y), (end_x, y)]
141 | return edge
142 |
143 |
144 | def is_mask_truncated(mask):
145 | if np.any(mask[0, :] == 1) or np.any(mask[-1, :] == 1):
146 | return True
147 | if np.any(mask[:, 0] == 1) or np.any(mask[:, -1] == 1):
148 | return True
149 | return False
150 |
151 | def draw_edge_on_image(image, edges, color=(0, 255, 0), thickness=2):
152 | output_image = image.copy()
153 | for edge in edges:
154 | for i in range(1, len(edge)):
155 | start_point = (edge[i-1][0], edge[i-1][1])
156 | end_point = (edge[i][0], edge[i][1])
157 | cv2.line(output_image, start_point, end_point, color, thickness)
158 |
159 | return output_image
160 |
161 |
162 |
163 | if __name__ == '__main__':
164 | parser = argparse.ArgumentParser()
165 | parser.add_argument('--input', type=str, default="../data/pig_ball", help="input path")
166 | parser.add_argument("--output", type=str, default=None, help="output directory")
167 | parser.add_argument("--vis_edge", action="store_true", help="visualize edges")
168 |
169 | args = parser.parse_args()
170 |
171 | output = args.output
172 | if output is None:
173 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input)
174 | os.makedirs(output, exist_ok=True)
175 |
176 | seg_path = os.path.join(args.input, "intermediate", 'mask.png')
177 | seg_mask = cv2.imread(seg_path, 0)
178 |
179 | depth_path = os.path.join(args.input, 'depth.npy')
180 | depth = np.load(depth_path)
181 |
182 | normal_path = os.path.join(args.input, 'normal.npy')
183 | normal = np.load(normal_path)
184 |
185 | seg_info_path = os.path.join(args.input, "intermediate", 'mask.json')
186 | with open(seg_info_path, 'r') as f:
187 | seg_info = json.load(f)
188 |
189 | seg_ids = np.unique(seg_mask)
190 |
191 | obj_infos = seg_info
192 | for seg_id in seg_ids:
193 |
194 | mask = seg_mask == seg_id
195 | seg_depth = depth[mask]
196 |
197 | min_depth = np.percentile(seg_depth, 5)
198 | max_depth = np.percentile(seg_depth, 95)
199 |
200 | obj_infos[seg_id]["depth"] = [min_depth, max_depth]
201 |
202 | movable_seg_ids = [seg_id for seg_id in seg_ids if obj_infos[seg_id]['movable']]
203 | fg_seg_mask = np.zeros_like(seg_mask)
204 | fg_truncated_mask = np.zeros_like(seg_mask)
205 | edges = []
206 | for seg_id in movable_seg_ids:
207 | mask = seg_mask == seg_id
208 | if mask.sum() < 100: # ignore small objects
209 | seg_mask[mask] = 0
210 | continue
211 | if is_mask_truncated(mask):
212 | fg_truncated_mask[mask] = seg_id
213 | else:
214 | fg_seg_mask[mask] = seg_id
215 |
216 | seg_mask_path = os.path.join(output, 'mask.png') # save the foreground mask as final mask
217 | fg_mask = fg_seg_mask + fg_truncated_mask
218 | cv2.imwrite(seg_mask_path, fg_mask)
219 |
220 | vis_seg_mask = vis_seg(fg_seg_mask, text=True)
221 | vis_save_path = os.path.join(output, "intermediate", 'fg_mask_vis.png')
222 | cv2.imwrite(vis_save_path, vis_seg_mask)
223 |
224 | for seg_id in np.unique(fg_truncated_mask):
225 | if seg_id == 0:
226 | continue
227 | mask = fg_truncated_mask == seg_id
228 | points = cv2.findNonZero((mask> 0).astype(np.uint8))
229 |
230 | x, y, w, h = cv2.boundingRect(points)
231 |
232 | top_edge = [[x, y], [x + w, y]]
233 | left_edge = [[x, y], [x, y + h]]
234 | right_edge = [[x + w, y], [x + w, y + h]]
235 | edges.extend([right_edge, left_edge, top_edge])
236 |
237 | bg_seg_mask = np.zeros_like(seg_mask)
238 | value = 1
239 | nonmovable_seg_ids = [seg_id for seg_id in seg_ids if not obj_infos[seg_id]['movable']]
240 | for seg_id in nonmovable_seg_ids:
241 | seg_area_ratio = np.sum(seg_mask == seg_id) / seg_mask.size
242 | if seg_id == 0 or seg_info[seg_id]['label'] == "background" or seg_area_ratio > 0.5:
243 | depth_threshold = np.array([obj_infos[seg_id]["depth"][0] for seg_id in movable_seg_ids]).min()
244 | mask = np.logical_and(seg_mask == seg_id, depth <= depth_threshold)
245 | min_size = int(500 / (512 * 512) * mask.size)
246 | bg_seg_mask, value = refine_segmentation_mask(mask, bg_seg_mask, value, min_size=min_size)
247 |
248 | else:
249 | neighbor_seg_ids = find_neighbor_segmentation_classes(seg_mask, seg_id)
250 | neighbor_seg_ids = [seg_id for seg_id in neighbor_seg_ids if seg_id in movable_seg_ids]
251 | if len(neighbor_seg_ids) == 0:
252 | depth_array = np.array([obj_infos[seg_id]["depth"][1] for seg_id in movable_seg_ids])
253 | else:
254 | depth_array = np.array([obj_infos[seg_id]["depth"][1] for seg_id in neighbor_seg_ids])
255 |
256 | min_threshold = depth_array.min()
257 | max_threshold = depth_array.max()
258 |
259 | min_size = int(500 / (512 * 512) * mask.size)
260 | mask = np.logical_and(seg_mask == seg_id, depth <= min_threshold)
261 | old_value = value
262 | bg_seg_mask, value = refine_segmentation_mask(mask, bg_seg_mask, value, min_size=min_size)
263 |
264 | if value == old_value:
265 | mask = np.logical_and(seg_mask == seg_id, min_threshold < depth)
266 | mask = np.logical_and(mask, depth <= max_threshold)
267 | bg_seg_mask, value = refine_segmentation_mask(mask, bg_seg_mask, value, min_size=min_size)
268 |
269 | seg_save_path = os.path.join(output, 'intermediate', 'bg_mask.png')
270 | cv2.imwrite(seg_save_path, bg_seg_mask)
271 |
272 | vis_seg_mask = vis_seg(bg_seg_mask)
273 | vis_seg_path = os.path.join(output, 'intermediate', 'bg_mask_vis.png')
274 | cv2.imwrite(vis_seg_path, vis_seg_mask)
275 |
276 | seg_ids = np.unique(bg_seg_mask)
277 | edge_map = np.zeros_like(bg_seg_mask)
278 | for seg_id in seg_ids:
279 | if seg_id == 0:
280 | continue
281 | else:
282 | mask = (bg_seg_mask == seg_id)
283 | normals = normal[mask]
284 | axis = normal_planes(normals)
285 | if axis == "X" or axis == "Z":
286 | points = cv2.findNonZero((mask> 0).astype(np.uint8))
287 |
288 | if points is not None:
289 | x, y, w, h = cv2.boundingRect(points)
290 |
291 | if axis == "X":
292 | right_col = x + w - 1
293 | edge = [[right_col, y], [right_col, y + h]]
294 | edges.append(edge)
295 | else: # Z
296 | top_edge = [[x, y], [x + w, y]]
297 | left_edge = [[x, y], [x, y + h]]
298 | right_edge = [[x + w, y], [x + w, y + h]]
299 | edges.extend([right_edge, left_edge, top_edge])
300 |
301 | elif axis == "Y":
302 | edge= find_upper_contour(mask)
303 | edges.append(edge)
304 | if args.vis_edge:
305 | img_path = os.path.join(args.input, 'original.png')
306 | img = cv2.imread(img_path)
307 | vis_edge_mask = draw_edge_on_image(img, edges, color=(0, 0, 255))
308 | vis_edge_save_path = os.path.join(output, 'intermediate', 'edge_vis.png')
309 | cv2.imwrite(vis_edge_save_path, vis_edge_mask)
310 |
311 | with open(os.path.join(output, 'edges.json'), 'w') as f:
312 | json.dump(edges, f)
313 | print("Done!")
314 |
315 |
316 |
317 |
318 |
--------------------------------------------------------------------------------
/perception/run_gsam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import json
5 | import torch
6 | import torchvision
7 | from PIL import Image
8 | import cv2
9 | import numpy as np
10 | import matplotlib
11 | import matplotlib.pyplot as plt
12 |
13 |
14 | BASE_DIR = "Grounded-Segment-Anything"
15 | sys.path.append(os.path.join(BASE_DIR))
16 |
17 | # Grounding DINO
18 | import GroundingDINO.groundingdino.datasets.transforms as T
19 | from GroundingDINO.groundingdino.models import build_model
20 | from GroundingDINO.groundingdino.util.slconfig import SLConfig
21 | from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
22 |
23 |
24 | # segment anything
25 | from segment_anything import (
26 | sam_model_registry,
27 | sam_hq_model_registry,
28 | SamPredictor
29 | )
30 |
31 |
32 | def load_image(image_path, return_transform=False):
33 | image_pil = Image.open(image_path).convert("RGB")
34 |
35 | transform = T.Compose(
36 | [
37 | T.RandomResize([800], max_size=1333),
38 | T.ToTensor(),
39 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
40 | ]
41 | )
42 | image, _ = transform(image_pil, None) # 3, h, w
43 | if not return_transform:
44 | return image_pil, image
45 | else:
46 | return image_pil, image, transform
47 |
48 |
49 | def load_model(model_config_path, model_checkpoint_path, device):
50 | args = SLConfig.fromfile(model_config_path)
51 | args.device = device
52 | model = build_model(args)
53 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
54 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
55 | print(load_res)
56 | _ = model.eval()
57 | return model
58 |
59 |
60 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
61 | caption = caption.lower()
62 | caption = caption.strip()
63 | if not caption.endswith("."):
64 | caption = caption + "."
65 | model = model.to(device)
66 | image = image.to(device)
67 | with torch.no_grad():
68 | outputs = model(image[None], captions=[caption])
69 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
70 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
71 | logits.shape[0]
72 |
73 | # filter output
74 | logits_filt = logits.clone()
75 | boxes_filt = boxes.clone()
76 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold
77 | logits_filt = logits_filt[filt_mask] # num_filt, 256
78 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4
79 | logits_filt.shape[0]
80 |
81 | # get phrase
82 | tokenlizer = model.tokenizer
83 | tokenized = tokenlizer(caption)
84 | # build pred
85 | pred_phrases = []
86 | scores = []
87 | for logit, box in zip(logits_filt, boxes_filt):
88 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
89 | if with_logits:
90 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())})")
91 | else:
92 | pred_phrases.append(pred_phrase)
93 | scores.append(logit.max().item())
94 |
95 | return boxes_filt, torch.Tensor(scores), pred_phrases
96 |
97 |
98 | def save_mask_data(output_dir, mask_list, label_list, movable_dict):
99 | value = 0 # 0 for background
100 | mask_img = torch.zeros(mask_list.shape[-2:])
101 | for idx, mask in enumerate(mask_list):
102 | mask_img[mask.cpu().numpy() == True] = value + idx + 1
103 |
104 | mask_img = mask_img.numpy().astype(np.uint8)
105 | Image.fromarray(mask_img).save(os.path.join(output_dir, "mask.png"))
106 | print("number of classes: ", len(np.unique(mask_img)))
107 |
108 | norm = matplotlib.colors.Normalize(vmin=np.min(mask_img), vmax=np.max(mask_img))
109 | norm_segmentation_map = norm(mask_img)
110 | cmap = "tab20"
111 | colormap = plt.get_cmap(cmap)
112 | colored_segmentation_map = colormap(norm_segmentation_map)
113 | colored_segmentation_map = (colored_segmentation_map[:, :, :3] * 255).astype(np.uint8)
114 | colored_segmentation_map[mask_img == 0] = [255, 255, 255]
115 |
116 | cv2.imwrite(os.path.join(output_dir, "vis_mask.jpg"), colored_segmentation_map[..., ::-1])
117 | print("seg visualization saved to ", os.path.join(output_dir, "vis_mask.jpg"))
118 |
119 | json_data = [{
120 | 'value': value,
121 | 'label': 'background',
122 | 'movable': False
123 | }]
124 | for label in label_list:
125 | value += 1
126 | movable = movable_dict[label] if (label in movable_dict and movable_dict[label]) else False
127 | json_data.append({
128 | 'value': value,
129 | 'label': label,
130 | 'movable': movable
131 | })
132 | with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
133 | json.dump(json_data, f)
134 |
135 |
136 | def masks_postprocess(masks, phrases, movable_dict):
137 | new_masks_list = []
138 | new_phrases_list = []
139 | for idx, (label, mask) in enumerate(zip(phrases, masks)):
140 | mask = mask.numpy()
141 |
142 | # post-processing label
143 | if (label != "background") or label not in movable_dict:
144 | found_keys = [key for key in movable_dict if key in label]
145 | if found_keys:
146 | phrases[idx] = found_keys[0]
147 | label = found_keys[0]
148 |
149 | if label == "background": # background
150 | continue
151 |
152 | if label in movable_dict:
153 | if movable_dict[label]: # movable object
154 | contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
155 | if len(contours) == 1:
156 | new_masks_list.append(mask)
157 | new_phrases_list.append(label)
158 | else:
159 | for contour in contours:
160 | component_mask = np.zeros_like(mask, dtype=np.uint8)
161 | cv2.drawContours(component_mask, [contour], -1, 1, thickness=cv2.FILLED)
162 | new_masks_list.append(component_mask)
163 | new_phrases_list.append(label)
164 | else:
165 | new_masks_list.append(mask)
166 | new_phrases_list.append(label)
167 | else: # merge into background
168 | continue
169 |
170 | masks = np.stack(new_masks_list, axis=0)
171 | masks = torch.from_numpy(masks)
172 | assert len(masks) == len(new_phrases_list)
173 | return masks, new_phrases_list
174 |
175 |
176 | if __name__ == "__main__":
177 | parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
178 | parser.add_argument("--config", type=str, default="Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", help="path to config file")
179 | parser.add_argument("--grounded_checkpoint", type=str, default="Grounded-Segment-Anything/groundingdino_swint_ogc.pth", help="path to checkpoint file"
180 | )
181 | parser.add_argument("--sam_version", type=str, default="vit_h", required=False, help="SAM ViT version: vit_b / vit_l / vit_h"
182 | )
183 | parser.add_argument(
184 | "--sam_checkpoint", type=str, default="Grounded-Segment-Anything/sam_vit_h_4b8939.pth", help="path to sam checkpoint file"
185 | )
186 | parser.add_argument(
187 | "--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file"
188 | )
189 | parser.add_argument(
190 | "--use_sam_hq", action="store_true", help="using sam-hq for prediction"
191 | )
192 |
193 |     parser.add_argument("--device", type=str, default="cuda", help="device to run inference on, e.g. cuda or cpu")
194 |
195 | # custom setting
196 | parser.add_argument("--box_threshold", type=float, default=0.2, help="box threshold")
197 | parser.add_argument("--text_threshold", type=float, default=0.2, help="text threshold")
198 | parser.add_argument("--iou_threshold", type=float, default=0.15, help="nms threshold")
199 | parser.add_argument('--disable_nms', action='store_true', help='Disable nms')
200 |
201 | parser.add_argument("--input", type=str, default="../data/pig_ball", help="input path")
202 | parser.add_argument("--output", type=str, default=None, help="output directory")
203 | parser.add_argument("--prompts_path", type=str, default=None, help="prompt file under same path as input")
204 |
205 | args = parser.parse_args()
206 |
207 | # cfg
208 | config_file = args.config # change the path of the model config file
209 | grounded_checkpoint = args.grounded_checkpoint # change the path of the model
210 | sam_version = args.sam_version
211 | sam_checkpoint = args.sam_checkpoint
212 | sam_hq_checkpoint = args.sam_hq_checkpoint
213 | use_sam_hq = args.use_sam_hq
214 | box_threshold = args.box_threshold
215 | text_threshold = args.text_threshold
216 | iou_threshold = args.iou_threshold
217 | device = args.device
218 |
219 | if os.path.isfile(args.input):
220 | image_path = args.input
221 | else:
222 | image_path = os.path.join(args.input, "original.png")
223 | output = args.output
224 | if output is None:
225 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input)
226 | os.makedirs(output, exist_ok=True)
227 |
228 | if args.prompts_path is None:
229 | prompts_path = os.path.join(args.input, "intermediate", "obj_movable.json")
230 | else:
231 | prompts_path = args.prompts_path
232 |
233 | with open(prompts_path, "r") as f:
234 | movable_dict = json.load(f)
235 |
236 | model = load_model(config_file, grounded_checkpoint, device=device)
237 |
238 | if use_sam_hq:
239 | predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
240 | else:
241 | predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
242 |
243 | image_pil, image_dino_input, transform = load_image(image_path, return_transform=True)
244 |
245 | text_prompt = ""
246 | for idx, tag in enumerate(movable_dict):
247 |
248 | if idx < len(movable_dict) - 1:
249 | text_prompt += tag + ". "
250 | else:
251 | text_prompt += tag
252 |
253 | print(f"Text Prompt input: {text_prompt}")
254 |
255 | # run grounding dino model
256 | boxes_filt, scores, pred_phrases = get_grounding_output(
257 | model, image_dino_input, text_prompt, box_threshold, text_threshold, with_logits=False, device=device)
258 |
259 | size = image_pil.size
260 | H, W = size[1], size[0]
261 | for i in range(boxes_filt.size(0)):
262 | boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
263 | boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
264 | boxes_filt[i][2:] += boxes_filt[i][:2]
265 |
266 | boxes_filt = boxes_filt.cpu() # boxes_filt: xyxy
267 |
268 | # use NMS to handle overlapped boxes
269 | if not args.disable_nms:
270 | print(f"Before NMS: {boxes_filt.shape[0]} boxes")
271 | nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
272 | boxes_filt = boxes_filt[nms_idx]
273 | pred_phrases = [pred_phrases[idx] for idx in nms_idx]
274 | scores = scores[nms_idx]
275 | print(f"After NMS: {boxes_filt.shape[0]} boxes")
276 |
277 | image_input = np.array(image_pil)
278 | predictor.set_image(image_input)
279 | transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_input.shape[:2]).to(device)
280 | masks, _, _ = predictor.predict_torch(
281 | point_coords = None,
282 | point_labels = None,
283 | boxes = transformed_boxes.to(device),
284 | multimask_output = False,
285 | )
286 | masks = masks.cpu().squeeze(1) # (N, H, W)
287 |
288 |
289 | image_input2 = image_input.copy()
290 | image_pil2 = Image.fromarray(image_input2)
291 |
292 | for idx, (label, box, mask) in enumerate(zip(pred_phrases, boxes_filt, masks)):
293 | mask = mask.numpy()
294 | if label == "background": # background
295 | continue
296 |
297 | if label in movable_dict and movable_dict[label]: # movable object
298 | contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
299 | if len(contours) == 1:
300 |                 image_input2[mask, :] = 0 # black out this object so the second pass can detect occluded instances
301 |
302 |
303 | image_pil2 = Image.fromarray(image_input2)
304 | image_dino_input2, _ = transform(image_pil2, None)
305 |
306 | boxes_filt2, scores2, pred_phrases2 = get_grounding_output(
307 | model, image_dino_input2, text_prompt, 0.8 * box_threshold, 0.8 * text_threshold, with_logits=False, device=device
308 | )
309 |
310 | size = image_pil.size
311 | H, W = size[1], size[0]
312 | for i in range(boxes_filt2.size(0)):
313 | boxes_filt2[i] = boxes_filt2[i] * torch.Tensor([W, H, W, H])
314 | boxes_filt2[i][:2] -= boxes_filt2[i][2:] / 2
315 | boxes_filt2[i][2:] += boxes_filt2[i][:2]
316 |
317 | boxes_filt2 = boxes_filt2.cpu() # boxes_filt: xyxy
318 |
319 | # use NMS to handle overlapped boxes
320 | if len(boxes_filt2) > 0:
321 |
322 | scores2 = 0.5 * scores2 # reduce the score of the second round prediction
323 | if not args.disable_nms:
324 | boxes_combine = torch.cat([boxes_filt, boxes_filt2], dim=0)
325 | phrases_combine = pred_phrases + pred_phrases2
326 | sc_combine = torch.cat([scores, scores2], dim=0)
327 |
328 | print(f"Before NMS: {boxes_filt2.shape[0]} boxes")
329 | nms_idx = torchvision.ops.nms(boxes_combine, sc_combine, 1.25 * iou_threshold).numpy().tolist()
330 | nms_idx = [idx for idx in nms_idx if idx >= len(boxes_filt)]
331 | boxes_filt2 = boxes_combine[nms_idx]
332 | pred_phrases2 = [phrases_combine[idx] for idx in nms_idx]
333 | scores2 = sc_combine[nms_idx]
334 | print(f"After NMS: {boxes_filt2.shape[0]} boxes")
335 |
336 | if len(boxes_filt2) > 0:
337 | image_input2 = np.array(image_pil2)
338 | predictor.set_image(image_input2)
339 | transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt2, image_input2.shape[:2]).to(device)
340 | masks2, _, _ = predictor.predict_torch(
341 | point_coords = None,
342 | point_labels = None,
343 | boxes = transformed_boxes.to(device),
344 | multimask_output = False,
345 | )
346 | masks2 = masks2.cpu().squeeze(1)
347 |
348 | pred_phrases = pred_phrases + pred_phrases2
349 | masks = torch.cat([masks, masks2], dim=0)
350 | masks, pred_phrases = masks_postprocess(masks, pred_phrases, movable_dict)
351 | else:
352 | masks, pred_phrases = masks_postprocess(masks, pred_phrases, movable_dict)
353 | else:
354 | masks, pred_phrases = masks_postprocess(masks, pred_phrases, movable_dict)
355 |
356 | intermediate_dir = os.path.join(output, "intermediate") # save intermediate results and visualization
357 | os.makedirs(intermediate_dir, exist_ok=True)
358 | save_mask_data(intermediate_dir, masks, pred_phrases, movable_dict)
359 |
360 |
--------------------------------------------------------------------------------
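Note on the detection logic above: `run_gsam.py` runs GroundingDINO + SAM twice. After the first pass, every movable object whose mask forms a single connected contour is blacked out, the same text prompt is re-run with the box/text thresholds scaled by 0.8, the second-round scores are halved, and (unless NMS is disabled) the combined boxes go through a looser NMS (1.25x the IoU threshold), so that instances occluded by the first detections can still be recovered before `masks_postprocess`.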
/perception/run_inpaint.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 | import numpy as np
5 | import torch
6 | import cv2
7 | from PIL import Image
8 |
9 | BASE_DIR = "Inpaint-Anything"
10 | sys.path.append(os.path.join(BASE_DIR))
11 |
12 | from lama_inpaint import inpaint_img_with_lama
13 |
14 |
15 | def load_img_to_array(img_p):
16 | img = Image.open(img_p)
17 | if img.mode == "RGBA":
18 | img = img.convert("RGB")
19 | return np.array(img)
20 |
21 |
22 | def save_array_to_img(img_arr, img_p):
23 | Image.fromarray(img_arr.astype(np.uint8)).save(img_p)
24 |
25 |
26 | def dilate_mask(mask, dilate_factor=20):
27 | mask = mask.astype(np.uint8)
28 | mask = cv2.dilate(
29 | mask,
30 | np.ones((dilate_factor, dilate_factor), np.uint8),
31 | iterations=1
32 | )
33 | return mask
34 |
35 |
36 | if __name__ == "__main__":
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument('--input', type=str, default="../data/pig_ball", help="input path")
39 | parser.add_argument("--output", type=str, default=None, help="output directory")
40 | parser.add_argument("--dilate_kernel_size", type=int, default=20, help="Dilate kernel size")
41 | parser.add_argument("--lama_config", type=str, default="Inpaint-Anything/lama/configs/prediction/default.yaml",
42 | help="The path to the config file of lama model.")
43 | parser.add_argument("--lama_ckpt", type=str, default="Inpaint-Anything/pretrained_models/big-lama",
44 | help="The path to the lama checkpoint.")
45 |
46 | args = parser.parse_args()
47 |
48 | if os.path.isfile(args.input):
49 | image_path = args.input
50 | else:
51 | image_path = os.path.join(args.input, "original.png")
52 | output = args.output
53 | if output is None:
54 | output = args.input if os.path.isdir(args.input) else os.path.dirname(args.input)
55 | os.makedirs(output, exist_ok=True)
56 |
57 | device = "cuda" if torch.cuda.is_available() else "cpu"
58 |
59 |     input_dir = args.input if os.path.isdir(args.input) else os.path.dirname(args.input)
60 |     img = load_img_to_array(image_path)
61 |     seg_path = os.path.join(input_dir, 'mask.png')
62 |     seg_mask = cv2.imread(seg_path, 0)
63 | mask = ((seg_mask > 0)*255).astype(np.uint8)
64 |
65 | if args.dilate_kernel_size is not None:
66 | mask = dilate_mask(mask, args.dilate_kernel_size)
67 |
68 | img_inpainted = inpaint_img_with_lama(img, mask, args.lama_config, args.lama_ckpt, device=device)
69 | save_array_to_img(img_inpainted, os.path.join(output, "inpaint.png"))
--------------------------------------------------------------------------------
/relight/relight.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import copy
5 | import torch.nn.functional as F
6 | from tqdm import tqdm
7 | import torch
8 | import kornia as K
9 |
10 | from relight_utils import get_light_coeffs, generate_shd, writing_video
11 |
12 |
13 | class Relight:
14 | def __init__(self, img, mask_img, nrm, shd):
15 |         # nrm: +X right, +Y down, +Z out of the screen toward the viewer, values in [0, 1]
16 | self.img = img # RGB between 0 and 1
17 | self.mask_img = mask_img
18 | self.nrm = nrm
19 |
20 | src_msks = []
21 | seg_ids = np.unique(mask_img)
22 | seg_ids.sort()
23 | for seg_id in np.unique(mask_img):
24 | if seg_id == 0:
25 | continue
26 | src_msk = (mask_img == seg_id)
27 | src_msks.append(src_msk)
28 | src_msks = np.stack(src_msks, axis=0) # (N, H, W)
29 | self.src_msks = torch.from_numpy(src_msks)
30 | self.num_classes = src_msks.shape[0]
31 |
32 | self.coeffs = get_light_coeffs(shd, self.nrm, self.img) # lambertian fitting
33 | shd[self.mask_img > 0] = generate_shd(self.nrm, self.coeffs, self.mask_img > 0)
34 | self.alb = (self.img ** 2.2) / shd[:, :, None].clip(1e-4)
35 | self.shd = shd
36 |
37 |
38 | def relight(self, comp_img, target_obj_mask, trans):
39 | # target_obj_mask (H, W) segmentation map
40 | # trans: (N, 2, 3)
41 | binary_seg_mask = []
42 | for seg_id in np.unique(self.mask_img):
43 | if seg_id == 0:
44 | continue
45 | seg_mask = target_obj_mask == seg_id
46 | binary_seg_mask.append(seg_mask)
47 |
48 | binary_seg_mask = np.stack(binary_seg_mask, axis=0) # (N, H, W)
49 | binary_seg_mask = torch.from_numpy(binary_seg_mask[:, None, :, :]) # (N, 1, H, W)
50 |
51 | assert binary_seg_mask.shape[0] == self.num_classes, "number of segmentation ids should be equal to the number of foreground objects"
52 |
53 | src_normal = torch.from_numpy(self.nrm).unsqueeze(dim=0).repeat(self.num_classes, 1, 1, 1) # (N, H, W, 3)
54 | rot_matrix = trans[:, :2, :2] # (N, 2, 2)
55 | full_rot = torch.cat([rot_matrix, torch.zeros(self.num_classes, 1, 2)], dim=1) # (N, 3, 2)
56 | full_rot = torch.cat([full_rot, torch.tensor([0, 0, 1]).unsqueeze(dim=0).unsqueeze(dim=-1).repeat(self.num_classes, 1, 1)], dim=2) # (N, 3, 3)
57 |
58 | nrm = (src_normal * 2.0) - 1.0
59 | nrm = nrm.reshape(self.num_classes, -1, 3) # (N, H*W, 3)
60 |         trans_nrm = torch.bmm(nrm, full_rot) # the rotation comes from kornia and is already the transpose, so no extra transpose is needed # (N, H*W, 3)
61 |         trans_nrm = (trans_nrm + 1) / 2.0
62 | trans_nrm = trans_nrm.reshape(self.num_classes, self.img.shape[0], self.img.shape[1], 3) # (N, H, W, 3)
63 |
64 | trans_nrm = trans_nrm.permute(0, 3, 1, 2) # (N, 3, H, W)
65 |         out = K.geometry.warp_affine(trans_nrm, trans, (self.mask_img.shape[0], self.mask_img.shape[1])) # (N, 3, H, W)
66 | target_nrm = (out * binary_seg_mask.float()).sum(dim=0) # (3, H, W)
67 | target_nrm = target_nrm.permute(1, 2, 0).numpy() # (H, W, 3)
68 |
69 | comp_nrm = copy.deepcopy(self.nrm)
70 | comp_nrm[target_obj_mask > 0] = target_nrm[target_obj_mask > 0]
71 |
72 | comp_shd = copy.deepcopy(self.shd)
73 | comp_shd[target_obj_mask > 0] = generate_shd(comp_nrm, self.coeffs, target_obj_mask > 0)
74 | comp_alb = (comp_img ** 2.2) / comp_shd[:, :, None].clip(1e-4)
75 |
76 | # compose albedo from src
77 | assert trans.shape[0] == self.num_classes, "number of transformations should be equal to the number of foreground objects"
78 | src_alb = torch.from_numpy(self.alb).permute(2, 0, 1).unsqueeze(dim=0).repeat(self.num_classes, 1, 1, 1) # (N, 3, H, W)
79 | out = K.geometry.warp_affine(src_alb, trans, (self.mask_img.shape[0], self.mask_img.shape[1])) # (N, 3, H, W)
80 |
81 | foreground_alb = (out * binary_seg_mask.float()).sum(dim=0) # (3, H, W)
82 | foreground_alb = foreground_alb.permute(1, 2, 0).numpy() # (H, W, 3)
83 | comp_alb[target_obj_mask > 0, :] = foreground_alb[target_obj_mask > 0, :]
84 |
85 | compose_img = ((comp_alb * comp_shd[:, :, None])** (1/2.2)).clip(0, 1)
86 |
87 | return compose_img
88 |
89 |
90 | def prepare_input(video_path, mask_video_path, trans_list_path):
91 | comp_video = torch.load(video_path) # (f, 3, H, W) (16, 3, 512, 512) between 0 and 255
92 | if comp_video.max().item() > 1.:
93 | comp_video = comp_video / 255.0
94 | if comp_video.shape[1] ==3: # (f, 3, H, W) -> (f, H, W, 3)
95 | comp_video = comp_video.permute(0, 2, 3, 1)
96 | T, H, W = comp_video.shape[:3]
97 | obj_masks = torch.load(mask_video_path).squeeze().float() # (f, H, W)
98 | trans_list = torch.load(trans_list_path)
99 | if trans_list.ndim == 3: # (f, 2, 3) -> (f, 1, 2, 3)
100 | trans_list = trans_list.unsqueeze(dim=1)
101 | if trans_list.shape[-2] == 3: # (f, *, 3, 3) -> (f, *, 2, 3)
102 | trans_list = trans_list[:, :, :2, :]
103 | assert comp_video.shape[0] == obj_masks.shape[0] == trans_list.shape[0], "video and mask should have the same length"
104 | comp_video = comp_video.numpy()
105 | if obj_masks.ndim == 4:
106 |         obj_masks = obj_masks[:, 0]  # assume channel-first masks: (f, C, H, W) -> (f, H, W)
107 | obj_masks = obj_masks.numpy() # (f, H, W)
108 | return comp_video, obj_masks, trans_list
109 |
110 |
111 | if __name__ == '__main__':
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument("--perception_input", type=str, default="../data/pool", help='input dir')
114 | parser.add_argument('--previous_output', type=str, default="../outputs/pool", help='previous output dir')
115 |
116 | args = parser.parse_args()
117 |
118 | perception_input = args.perception_input
119 | previous_output = args.previous_output
120 |
121 | video_path = os.path.join(previous_output, "composite.pt")
122 | mask_video_path = os.path.join(previous_output, "mask_video.pt")
123 | trans_list_path = os.path.join(previous_output, "trans_list.pt")
124 | comp_video, obj_masks, trans_list = prepare_input(video_path, mask_video_path, trans_list_path)
125 |
126 | normal_path = os.path.join(perception_input, "normal.npy")
127 | shading_path = os.path.join(perception_input, "shading.npy")
128 | normal = np.load(normal_path) # (H, W, 3) between -1 and 1
129 |     # convert the GeoWizard normal convention to the Omnidata (OMNI) convention by flipping X
130 |     normal[:, :, 0] = -normal[:, :, 0]
131 | normal = (normal + 1) / 2.0
132 | shading = np.load(shading_path)
133 |
134 |
135 | output = args.previous_output
136 | os.makedirs(output, exist_ok=True)
137 |
138 | T = comp_video.shape[0]
139 | img = comp_video[0] # (H, W, 3)
140 | mask_img = obj_masks[0]
141 | model = Relight(img, mask_img, normal, shading)
142 |
143 | relight_list = [img]
144 | for time_idx in tqdm(range(1, T)):
145 | comp_img = comp_video[time_idx]
146 |         target_obj_mask = obj_masks[time_idx] # segmentation mask
147 | trans = trans_list[time_idx] # (*, 2, 3)
148 |
149 | compose_img = model.relight(comp_img, target_obj_mask, trans)
150 | relight_list.append(compose_img)
151 |
152 | relight_video = np.stack(relight_list, axis=0)
153 | torch.save(torch.from_numpy(relight_video).permute(0, 3, 1, 2), f'{output}/relight.pt') # (0, 1)
154 | relight_video = (relight_video * 255).astype(np.uint8)
155 | writing_video(relight_video[..., ::-1], f'{output}/relight.mp4', frame_rate=7)
156 | print('done!')
--------------------------------------------------------------------------------
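The `Relight` class above relies on a simple gamma-2.2 image formation model: the frame is linearized, divided by the fitted shading to obtain albedo, and re-composed once the shading has been regenerated under the rotated normals. A minimal sketch of that round trip, with illustrative helper names not taken from the repo:

```python
import numpy as np

def decompose(img_srgb, shading, eps=1e-4):
    """Linearize with gamma 2.2 and divide out shading to recover albedo."""
    linear = img_srgb ** 2.2                          # (H, W, 3), values in [0, 1]
    return linear / shading[:, :, None].clip(eps)

def recompose(albedo, shading):
    """Multiply albedo by the (possibly updated) shading and re-apply gamma."""
    return ((albedo * shading[:, :, None]) ** (1 / 2.2)).clip(0, 1)

# sanity check: recompose(decompose(img, shd), shd) should reproduce img
```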
/relight/relight_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import cv2
4 | from torch.optim import Adam
5 |
6 | def invert(x):
7 | out = 1.0 / (x + 1.0)
8 | return out
9 |
10 |
11 | def uninvert(x, eps=0.001, clip=True):
12 | if clip:
13 | x = x.clip(eps, 1.0)
14 |
15 | out = (1.0 / x) - 1.0
16 | return out
17 |
18 |
19 | def spherical2cart(r, theta, phi):
20 | return [
21 | r * torch.sin(theta) * torch.cos(phi),
22 | r * torch.sin(theta) * torch.sin(phi),
23 | r * torch.cos(theta)
24 | ]
25 |
26 |
27 | def run_optimization(params, A, b):
28 |
29 | optim = Adam([params], lr=0.01)
30 | prev_loss = 1000
31 |
32 | for i in range(500):
33 | optim.zero_grad()
34 |
35 | x, y, z = spherical2cart(params[2], params[0], params[1])
36 |
37 | dir_shd = (A[:, 0] * x) + (A[:, 1] * y) + (A[:, 2] * z)
38 | pred_shd = dir_shd + params[3]
39 |
40 | loss = torch.nn.functional.mse_loss(pred_shd.reshape(-1), b)
41 |
42 | loss.backward()
43 |
44 | optim.step()
45 |
46 | # theta can range from 0 -> pi/2 (0 to 90 degrees)
47 | # phi can range from 0 -> 2pi (0 to 360 degrees)
48 | with torch.no_grad():
49 | if params[0] < 0:
50 | params[0] = 0
51 |
52 | if params[0] > np.pi / 2:
53 | params[0] = np.pi / 2
54 |
55 | if params[1] < 0:
56 | params[1] = 0
57 |
58 | if params[1] > 2 * np.pi:
59 | params[1] = 2 * np.pi
60 |
61 | if params[2] < 0:
62 | params[2] = 0
63 |
64 | if params[3] < 0.1:
65 | params[3] = 0.1
66 |
67 | delta = prev_loss - loss
68 |
69 | if delta < 0.0001:
70 | break
71 |
72 | prev_loss = loss
73 |
74 | return loss, params
75 |
76 |
77 | def test_init(params, A, b):
78 | x, y, z = spherical2cart(params[2], params[0], params[1])
79 |
80 | dir_shd = (A[:, 0] * x) + (A[:, 1] * y) + (A[:, 2] * z)
81 | pred_shd = dir_shd + params[3]
82 |
83 | loss = torch.nn.functional.mse_loss(pred_shd.reshape(-1), b)
84 | return loss
85 |
86 |
87 | def get_light_coeffs(shd, nrm, img, mask=None):
88 | valid = (img.mean(-1) > 0.05) * (img.mean(-1) < 0.95)
89 |
90 | if mask is not None:
91 | valid *= (mask == 0)
92 |
93 | nrm = (nrm * 2.0) - 1.0
94 |
95 | A = nrm[valid == 1]
96 | A /= np.linalg.norm(A, axis=1, keepdims=True)
97 | b = shd[valid == 1]
98 |
99 | # parameters are theta, phi, and bias (c)
100 | A = torch.from_numpy(A)
101 | b = torch.from_numpy(b)
102 |
103 | min_init = 1000
104 | for t in np.arange(0, np.pi/2, 0.1):
105 | for p in np.arange(0, 2*np.pi, 0.25):
106 | params = torch.nn.Parameter(torch.tensor([t, p, 1, 0.5]))
107 | init_loss = test_init(params, A, b)
108 |
109 | if init_loss < min_init:
110 | best_init = params
111 | min_init = init_loss
112 |
113 | loss, params = run_optimization(best_init, A, b)
114 |
115 | x, y, z = spherical2cart(params[2], params[0], params[1])
116 |
117 |     # pack the optimized light direction and ambient bias into a single vector
118 |     coeffs = np.array([x.item(), y.item(), z.item(), params[3].item()])
119 | return coeffs
120 |
121 |
122 | def generate_shd(nrm, coeffs, msk, bias=True):
123 |
124 | nrm = (nrm * 2.0) - 1.0
125 |
126 | A = nrm.reshape(-1, 3)
127 | A /= np.linalg.norm(A, axis=1, keepdims=True)
128 |
129 | A_fg = nrm[msk == 1]
130 | A_fg /= np.linalg.norm(A_fg, axis=1, keepdims=True)
131 |
132 | if bias:
133 | A = np.concatenate((A, np.ones((A.shape[0], 1))), 1)
134 | A_fg = np.concatenate((A_fg, np.ones((A_fg.shape[0], 1))), 1)
135 |
136 | inf_shd = (A_fg @ coeffs)
137 | inf_shd = inf_shd.clip(0) + 0.2
138 | return inf_shd
139 |
140 |
141 | def writing_video(rgb_list, save_path: str, frame_rate: int = 30):
142 |     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
143 | h, w, _ = rgb_list[0].shape
144 | out = cv2.VideoWriter(save_path, fourcc, frame_rate, (w, h))
145 |
146 | for img in rgb_list:
147 | out.write(img)
148 |
149 | out.release()
150 | return
--------------------------------------------------------------------------------
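`get_light_coeffs` above fits a single directional light plus an ambient term, `shading ≈ n · l + c`, by grid-initializing the spherical parameters (theta, phi) and refining them with Adam over the valid (non-saturated) pixels. Once the four coefficients `[l_x, l_y, l_z, c]` are known, shading anywhere in the image is just a dot product; a small sketch of how they can be applied (the function name is illustrative, not from the repo):

```python
import numpy as np

def shade_from_normals(normals, coeffs):
    """normals: (H, W, 3) in [0, 1]; coeffs: [lx, ly, lz, c] from get_light_coeffs."""
    n = normals * 2.0 - 1.0                                  # back to [-1, 1]
    n = n / np.linalg.norm(n, axis=-1, keepdims=True)        # unit normals
    lx, ly, lz, c = coeffs
    return (n @ np.array([lx, ly, lz]) + c).clip(0)          # (H, W) shading
```

`generate_shd` follows the same formula but also adds a small 0.2 offset after clipping at zero, so foreground pixels never go fully dark.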
/requirements.txt:
--------------------------------------------------------------------------------
1 | # basic simulation requirements
2 | omegaconf
3 | pygame
4 | pymunk
5 | opencv-python
6 | imantics
7 | pillow
8 | kornia
9 | av
10 | # seine requirements
11 | numpy==1.26.4
12 | matplotlib==3.9.2
13 | torch==2.0.1
14 | torchaudio==2.0.2
15 | torchvision==0.15.2
16 | decord==0.6.0
17 | diffusers==0.15.0
18 | imageio==2.29.0
19 | transformers==4.29.2
20 | xformers==0.0.20
21 | einops
22 | tensorboard==2.15.1
23 | timm==0.9.10
24 | rotary-embedding-torch==0.3.5
25 | natsort==8.4.0
26 | openai==0.28
27 | ruamel.yaml
28 |
29 |
30 |
--------------------------------------------------------------------------------
/scripts/run_demo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME=$1
4 |
5 | cd simulation
6 | python animate.py --config ../data/${NAME}/sim.yaml --save_root ../outputs
7 |
8 | cd ../relight
9 | python relight.py --perception_input ../data/${NAME} --previous_output ../outputs/${NAME}
10 |
11 | cd ../diffusion
12 | python video_diffusion.py --perception_input ../data/${NAME} --previous_output ../outputs/${NAME}
--------------------------------------------------------------------------------
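For one of the bundled examples, e.g. `data/pool`, the three stages above can be run end to end with `bash scripts/run_demo.sh pool`; `animate.py` writes the simulation into `outputs/<cat>` (assuming the `cat` field in `sim.yaml` matches the folder name), and the relighting and diffusion stages read it back through `--previous_output`.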
/simulation/animate.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 | import argparse
4 | from omegaconf import OmegaConf
5 | import pygame
6 | import pymunk
7 | import pymunk.pygame_util
8 | import numpy as np
9 | import cv2
10 | import matplotlib.pyplot as plt
11 | from matplotlib import cm
12 | import matplotlib.colors as colors
13 | from imantics import Mask
14 |
15 | from animate_utils import AnimateObj, AnimateImgSeg
16 | from sim_utils import fit_circle_from_mask, list_to_numpy
17 |
18 |
19 |
20 | class AnimateUni(AnimateObj):
21 |
22 | def __init__(self,
23 | save_dir,
24 | mask_path,
25 | obj_info: dict,
26 |                  edge_list: Optional[np.ndarray]=None,
27 | init_velocity: Optional[dict]=None,
28 | init_acc: Optional[dict]=None,
29 | gravity=980,
30 | ground_elasticity=0,
31 | ground_friction=1,
32 | num_steps=200,
33 | size=(512, 512),
34 | display=True,
35 | save_snapshot: Optional[bool]=False,
36 | snapshot_frames: Optional[int]=16,
37 | colormap: Optional[str]="tab20") -> None:
38 |
39 |
40 | super(AnimateUni, self).__init__(save_dir, size, gravity, display)
41 | self.num_steps = num_steps
42 | self.objs = {} # key is seg_id, value is pymunk shape
43 | mask_img = cv2.imread(mask_path, 0)
44 | self.mask_img = mask_img
45 |
46 | self.save_snapshot = save_snapshot
47 | if self.save_snapshot:
48 | self.snapshot_dir = os.path.join(save_dir, "snapshot")
49 | os.makedirs(self.snapshot_dir, exist_ok=True)
50 | self.snapshot_frames = snapshot_frames
51 | self.colormap = colormap
52 |
53 | self.obj_info = obj_info
54 |
55 | for seg_id in self.obj_info:
56 | class_name = self.obj_info[seg_id]["primitive"]
57 | density = self.obj_info[seg_id]["density"]
58 | mass = self.obj_info[seg_id]["mass"]
59 | elasticity = self.obj_info[seg_id]["elasticity"]
60 | friction = self.obj_info[seg_id]["friction"]
61 | if class_name == "polygon":
62 | mask = self.mask_img == seg_id
63 | polygons = Mask(mask).polygons()
64 | if len(polygons.points) > 1: # find the largest polygon
65 | areas = []
66 | for points in polygons.points:
67 | points = points.reshape((-1, 1, 2)).astype(np.int32)
68 | img = np.zeros((self.mask_img.shape[0], self.mask_img.shape[1], 3), dtype=np.uint8)
69 | img = cv2.fillPoly(img, [points], color=[0, 255, 0])
70 | mask = img[:, :, 1] > 0
71 | area = np.count_nonzero(mask)
72 | areas.append(area)
73 | areas = np.array(areas)
74 | largest_idx = np.argmax(areas)
75 | points = polygons.points[largest_idx]
76 | else:
77 | points = polygons.points[0]
78 |
79 | points = tuple(map(tuple, points))
80 | poly = self._create_poly(density, points, elasticity, friction)
81 | self.objs[seg_id] = poly
82 |
83 | elif class_name == "circle":
84 | mask = self.mask_img == int(seg_id)
85 | mask = (mask * 255).astype(np.uint8)
86 | center, radius = fit_circle_from_mask(mask)
87 | ball = self._create_ball(mass, radius, center, elasticity, friction)
88 | self.objs[seg_id] = ball
89 |
90 | if edge_list is not None:
91 | self._create_wall_segments(edge_list, ground_elasticity, ground_friction)
92 |
93 | self.init_velocity = init_velocity
94 | self.init_acc = init_acc
95 |
96 | def _create_wall_segments(self, edge_list, elasticity=0.05, friction=0.9):
97 | """Create a number of wall segments connecting the points"""
98 | for edge in edge_list:
99 | point_1, point_2 = edge
100 | v1 = pymunk.Vec2d(point_1[0], point_1[1])
101 | v2 = pymunk.Vec2d(point_2[0], point_2[1])
102 | wall_body = pymunk.Body(body_type=pymunk.Body.STATIC)
103 | wall_shape = pymunk.Segment(wall_body, v1, v2, 0.0)
104 | wall_shape.collision_type = 0
105 | wall_shape.elasticity = elasticity
106 | wall_shape.friction = friction
107 | self._space.add(wall_body, wall_shape)
108 |
109 | def _create_ball(self, mass, radius, position, elasticity=0.5, friction=0.4):
110 | """
111 | Create a ball defined by mass, radius and position
112 | """
113 | inertia = pymunk.moment_for_circle(mass, 0, radius, (0, 0))
114 | body = pymunk.Body(mass, inertia)
115 | body.position = position
116 | shape = pymunk.Circle(body, radius, (0, 0))
117 | shape.elasticity = elasticity
118 | shape.friction = friction
119 | self._space.add(body, shape)
120 | return shape
121 |
122 | def _create_poly(self, density, points, elasticity=0.5, friction=0.4, collision_type=0):
123 | """
124 | Create a poly defined by density, points
125 | """
126 | body = pymunk.Body()
127 | shape = pymunk.Poly(body, points)
128 | shape.density = density
129 | shape.elasticity = elasticity
130 | shape.friction = friction
131 | shape.collision_type = collision_type
132 | self._space.add(body, shape)
133 | return shape
134 |
135 |
136 | def _draw_objects(self) -> None:
137 | jet = plt.get_cmap(self.colormap)
138 | cNorm = colors.Normalize(vmin=0, vmax=len(self.objs))
139 | scalarMap = cm.ScalarMappable(norm=cNorm, cmap=jet)
140 | for seg_id in self.objs:
141 | shape = self.objs[seg_id]
142 | color = scalarMap.to_rgba(int(seg_id))
143 | color = (color[0]*255, color[1]*255, color[2]*255, 255)
144 | if isinstance(shape, pymunk.Circle):
145 | self.draw_ball(shape, color)
146 | elif isinstance(shape, pymunk.Poly):
147 | self.draw_poly(shape, color)
148 | for shape in self._space.shapes:
149 | if isinstance(shape, pymunk.Segment):
150 | self.draw_wall(shape)
151 |
152 | def draw_ball(self, ball, color=(0, 0, 255, 255)):
153 | body = ball.body
154 | v = body.position + ball.offset.cpvrotate(body.rotation_vector)
155 | r = ball.radius
156 | pygame.draw.circle(self._screen, pygame.Color(color), v, int(r), 0)
157 |
158 | def draw_wall(self, wall, color=(252, 3, 169, 255), width=3):
159 | body = wall.body
160 | pv1 = body.position + wall.a.cpvrotate(body.rotation_vector)
161 | pv2 = body.position + wall.b.cpvrotate(body.rotation_vector)
162 | pygame.draw.lines(self._screen, pygame.Color(color), False, [pv1, pv2], width=width)
163 |
164 | def draw_poly(self, poly, color=(0, 255, 0, 255)):
165 | body = poly.body
166 | ps = [p.rotated(body.angle) + body.position for p in poly.get_vertices()]
167 | ps.append(ps[0])
168 | pygame.draw.polygon(self._screen, pygame.Color(color), ps)
169 |
170 | def get_transform(self):
171 |         # get each object's state (center, angle) at the current timestep
172 |         """
173 |         rotation_matrix: np.array([[math.cos(rad), -math.sin(rad)], [math.sin(rad), math.cos(rad)]])
174 |         In Kornia, the matrix created by K.geometry.transform.get_rotation_matrix2d is the transpose of this one,
175 |         which is equivalent to using -angle when the states are replayed in record()
176 | """
177 | state = {}
178 | keys = list(self.objs.keys())
179 | keys = sorted(keys, key=lambda x: int(x))
180 | for seg_id in keys:
181 | shape = self.objs[seg_id]
182 | if isinstance(shape, pymunk.Poly):
183 | ps = [p.rotated(shape.body.angle) + shape.body.position for p in shape.get_vertices()]
184 | ps = np.array(ps)
185 | center = np.mean(ps, axis=0)
186 | angle = shape.body.angle
187 |
188 | elif isinstance(shape, pymunk.Circle):
189 | center = shape.body.position
190 | angle = shape.body.angle
191 |
192 | state[seg_id] = (center, angle)
193 | return state
194 |
195 | def init_condition(self, init_velocity, init_acc):
196 | for seg_id in self.objs:
197 | shape = self.objs[seg_id]
198 | query_seg_id = seg_id
199 | if init_velocity is not None and query_seg_id in init_velocity:
200 | ins_init_vel = list(init_velocity[query_seg_id]) if not isinstance(init_velocity[query_seg_id], list) else init_velocity[query_seg_id]
201 | shape.body.velocity = ins_init_vel
202 | if init_acc is not None and query_seg_id in init_acc:
203 | ins_init_acc = init_acc[query_seg_id]
204 | shape.body.apply_impulse_at_local_point((ins_init_acc[0] * shape.body.mass, ins_init_acc[1] * shape.body.mass), (0, 0))
205 |
206 | def run(self) -> None:
207 | """
208 | The main loop of the game.
209 | :return: None
210 | """
211 | # Main loop
212 | num_steps = self.num_steps // self._physics_steps_per_frame
213 | count = 0
214 | self.init_condition(self.init_velocity, self.init_acc)
215 | if self.save_snapshot:
216 |             # save snapshots of the simulation at snapshot_frames evenly spaced steps
217 | snapshot_indices = np.arange(num_steps)[::num_steps//(self.snapshot_frames)][:self.snapshot_frames]
218 | snapshot_indices[0] = 0
219 | save_path = os.path.join(self.snapshot_dir, "snapshot_{:03d}.png")
220 | self._process_events()
221 | self._clear_screen()
222 | self._draw_objects()
223 | pygame.display.flip()
224 | pygame.image.save(self._screen, save_path.format(0))
225 |
226 | while self._running and count < num_steps:
227 | # Progress time forward
228 | count += 1
229 | for x in range(self._physics_steps_per_frame):
230 | self._space.step(self._dt)
231 | state = self.get_transform()
232 | self.history.append(state)
233 | if self.display:
234 | self._process_events()
235 | self._clear_screen()
236 | self._draw_objects()
237 | pygame.display.flip()
238 | if self.save_snapshot and count in snapshot_indices:
239 | index = np.where(snapshot_indices==count)[0].item()
240 | pygame.image.save(self._screen, save_path.format(index))
241 |
242 | self._clock.tick(50)
243 | pygame.display.set_caption("fps: " + str(self._clock.get_fps()))
244 | self.save_state()
245 |
246 |
247 | def main(args, data_root, save_root):
248 | save_dir = os.path.join(save_root, args.cat.lower())
249 | os.makedirs(save_dir, exist_ok=True)
250 |
251 | data_dir = os.path.join(data_root, args.cat)
252 | mask_path=os.path.join(data_dir, "mask.png")
253 |
254 | anim = AnimateUni(
255 | mask_path=mask_path, save_dir=save_dir,
256 | obj_info=args.obj_info, edge_list=args.edge_list,
257 | init_velocity=args.init_velocity, init_acc=args.init_acc,
258 | num_steps=args.num_steps,
259 | gravity=args.gravity,
260 | ground_elasticity=getattr(args, "ground_elasticity", 0),
261 | ground_friction=getattr(args, "ground_friction", 1),
262 | size=args.size, display=args.display, save_snapshot=args.save_snapshot, snapshot_frames=args.animation_frames)
263 | anim.run()
264 |
265 | history_path = os.path.join(save_dir, "history.pkl")
266 | animation_frames = getattr(args, "animation_frames", 16)
267 | replay = AnimateImgSeg(data_dir=data_dir, save_dir=save_dir, history_path=history_path, animate_frames=animation_frames)
268 | replay.record()
269 |
270 |
271 | if __name__ == "__main__":
272 | parser = argparse.ArgumentParser()
273 | parser.add_argument("--data_root", type=str, default="../data")
274 | parser.add_argument("--save_root", type=str, default="../outputs")
275 | parser.add_argument("--config", type=str, default="../data/pool/sim.yaml")
276 | args = parser.parse_args()
277 | config = OmegaConf.load(args.config)
278 | config = list_to_numpy(config)
279 | main(config, data_root=args.data_root, save_root=args.save_root)
--------------------------------------------------------------------------------
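`AnimateImgSeg.record` (in `animate_utils.py` below) converts each simulated pymunk state into a per-object 2x3 affine: kornia's `get_rotation_matrix2d` is called about the object's initial centroid with the negated body angle (kornia's matrix is the transpose of the pymunk rotation in this image convention), and the translation column is shifted by the centroid displacement. A minimal sketch of the same construction for one object, with illustrative names:

```python
import math
import torch
import kornia as K

def pymunk_state_to_affine(init_center, center, angle):
    """init_center, center: (x, y) in pixels; angle: pymunk body angle in radians."""
    init_c = torch.tensor([init_center], dtype=torch.float32)            # (1, 2)
    deg = torch.tensor([-angle * 180.0 / math.pi], dtype=torch.float32)  # negate for kornia
    scale = torch.ones(1, 2)
    trans = K.geometry.transform.get_rotation_matrix2d(init_c, deg, scale)   # (1, 2, 3)
    trans[:, :, 2] += torch.tensor([center], dtype=torch.float32) - init_c   # shift by centroid motion
    return trans[0]                                                          # (2, 3) affine
```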
/simulation/animate_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from typing import List
4 | import numpy as np
5 | import pickle
6 | import pygame
7 | import pymunk
8 | import pymunk.pygame_util
9 |
10 | import kornia as K
11 | from sim_utils import writing_video, prep_data, composite_trans
12 |
13 |
14 | class AnimateObj(object):
15 |
16 | def __init__(self, save_dir, size=(512,512), gravity=980, display=True) -> None:
17 | self.size = size
18 | self.save_dir = save_dir
19 | os.makedirs(self.save_dir, exist_ok=True)
20 | self.history = [] # record state
21 |
22 | # Space
23 | self._space = pymunk.Space()
24 | # Physics
25 | # Time step
26 | self._space.gravity = (0.0, gravity) # positive_y_is_down
27 | self._dt = 1.0 / 60.0
28 | # Number of physics steps per screen frame
29 | self._physics_steps_per_frame = 1 # larger, faster
30 | self.history = []
31 | # pygame
32 | self.display = display
33 | if self.display:
34 | pygame.init()
35 | self._screen = pygame.display.set_mode(size)
36 | self._clock = pygame.time.Clock()
37 | self._draw_options = pymunk.pygame_util.DrawOptions(self._screen)
38 |
39 | # Execution control and time until the next ball spawns
40 | self._running = True
41 |
42 | def run(self) -> None:
43 | """Custom implement func
44 | The main loop of the game.
45 | :return: None
46 | """
47 | # Main loop
48 | while self._running:
49 | # Progress time forward
50 | for x in range(self._physics_steps_per_frame):
51 | self._space.step(self._dt)
52 |
53 | state = self.get_transform()
54 | self.history.append(state)
55 | if self.display:
56 | self._process_events()
57 | self.draw()
58 | # self._clear_screen()
59 | # self._draw_objects()
60 | pygame.display.flip()
61 | # Delay fixed time between frames
62 | self._clock.tick(50)
63 | pygame.display.set_caption("fps: " + str(self._clock.get_fps()))
64 | self.save_state()
65 |
66 |
67 | def get_transform(self, verbose=False) -> List[int]:
68 | # custom func
69 | center = self.ball.body.position
70 | angle = self.poly.body.angle
71 | if verbose:
72 | print(center, angle)
73 | return center, angle
74 |
75 | def save_state(self):
76 | # save self.state
77 | with open(os.path.join(self.save_dir, 'history.pkl'), 'wb') as f:
78 | pickle.dump(self.history, f)
79 |
80 | def _process_events(self) -> None:
81 | """
82 | Handle game and events like keyboard input. Call once per frame only.
83 | :return: None
84 | """
85 | for event in pygame.event.get():
86 | if event.type == pygame.QUIT:
87 | self._running = False
88 | elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
89 | self._running = False
90 |
91 | def _clear_screen(self) -> None:
92 | """
93 | Clears the screen.
94 | :return: None
95 | """
96 | self._screen.fill(pygame.Color("white"))
97 |
98 | def _draw_objects(self) -> None:
99 | """
100 | Draw the objects.
101 | :return: None
102 | """
103 | self._space.debug_draw(self._draw_options)
104 |
105 |
106 | class AnimateImgSeg(object):
107 | # Animate Img from Segmentation
108 | def __init__(self, data_dir, save_dir, history_path, animate_frames=16):
109 | self.save_dir = save_dir
110 | os.makedirs(self.save_dir, exist_ok=True)
111 |
112 | self.img, self.mask_img, inpaint_img = prep_data(data_dir)
113 |
114 | with open(history_path, 'rb') as f:
115 | history = pickle.load(f)
116 | self.animate_frames = animate_frames
117 | self.history = history[::len(history)//(self.animate_frames)][:animate_frames]
118 | assert len(self.history) == self.animate_frames
119 |
120 | active_keys = list(self.history[0].keys())
121 | self.active_keys = sorted(active_keys, key=lambda x: int(x))
122 | for seg_id in np.unique(self.mask_img):
123 | if seg_id == 0: # background
124 | continue
125 | if seg_id not in active_keys:
126 | mask = self.mask_img == int(seg_id)
127 | inpaint_img[mask] = self.img[mask]
128 | self.mask_img[mask] = 0 # set as background
129 | self.inpaint_img = inpaint_img # inpaint_img is the background
130 |
131 | init_centers = []
132 | for seg_id in active_keys:
133 |             if int(seg_id) == 0:
134 | raise ValueError("seg_id should not be 0")
135 | mask = self.mask_img == int(seg_id)
136 | init_center = np.mean(np.argwhere(mask), axis=0)
137 | init_center = np.array([init_center[1], init_center[0]]) # center (y,x) to (x,y)
138 | init_centers.append(init_center)
139 | init_centers = np.stack(init_centers, axis=0) # (num_objs, 2)
140 | self.init_centers = init_centers.astype(np.float32)
141 |
142 | def record(self):
143 | H, W = self.img.shape[:2]
144 | masked_src_imgs = []
145 | for seg_id in self.active_keys:
146 | mask = self.mask_img == int(seg_id)
147 | src_img = torch.from_numpy(self.img * mask[:, :, None]).permute(2, 0, 1).float() # (3, H, W)
148 | masked_src_imgs.append(src_img)
149 | masked_src_imgs = torch.stack(masked_src_imgs, dim=0) # tensor (num_objs, 3, H, W)
150 | init_centers = torch.from_numpy(self.init_centers).float()# tensor (num_objs, 2)
151 |
152 | imgs = []
153 | msk_list = []
154 |         trans_list = []  # per-frame (num_objs, 2, 3) affine transforms
155 | for i in range(len(self.history)):
156 | if i == 0: # use original image
157 | trans = []
158 | imgs.append((self.img*255).astype(np.uint8))
159 | # active segmentation mask
160 | seg_mask = np.zeros_like(self.mask_img)
161 | for seg_id in self.active_keys:
162 | seg_mask[self.mask_img == int(seg_id)] = seg_id
163 | trans.append(torch.eye(3)[:2, :]) # (2, 3)
164 | msk_list.append(seg_mask)
165 | trans = torch.stack(trans, dim=0) # (num_objs, 2, 3)
166 | trans_list.append(trans)
167 | else:
168 | history = self.history[i] # dict of seg_id: (center, angle)
169 | centers, scales, angles = [], [], []
170 | for seg_id in self.active_keys:
171 | center, angle = history[seg_id]
172 |
173 |                     # the rotation matrix used by kornia is the transpose of the pymunk rotation, so use -angle
174 | angle = -angle
175 |
176 | center = torch.tensor([center]).float() # (1, 2)
177 | angle = torch.tensor([angle/np.pi * 180]).float() # [1]
178 | scale = torch.ones(1, 2).float()
179 | centers.append(center)
180 | scales.append(scale)
181 | angles.append(angle)
182 | centers = torch.cat(centers, dim=0) # (num_objs, 2)
183 | scales = torch.cat(scales, dim=0) # (num_objs, 2)
184 | angles = torch.cat(angles, dim=0) # (num_objs)
185 |
186 | trans = K.geometry.transform.get_rotation_matrix2d(init_centers, angles, scales)
187 | trans[:, :, 2] += centers - init_centers # (num_objs, 2, 3)
188 |
189 | active_list = list(map(int, self.active_keys))
190 | final_frame, seg_mask = composite_trans(masked_src_imgs, trans, self.inpaint_img, active_list)
191 |
192 | imgs.append(final_frame)
193 | msk_list.append(seg_mask)
194 | trans_list.append(trans) # (num_objs, 2, 3)
195 |
196 | imgs = np.stack(imgs, axis=0) # (animate_frames, H, W, 3)
197 | writing_video(imgs[..., ::-1], os.path.join(self.save_dir, 'composite.mp4'), frame_rate=7)
198 |
199 | msk_list = np.stack(msk_list, axis=0) # (animate_frames, H, W)
200 | msk_list = torch.from_numpy(msk_list)
201 | imgs = torch.from_numpy(imgs).permute(0, 3, 1, 2)
202 | trans_list = np.stack(trans_list, axis=0) # (animate_frames, num_objs, 2, 3)
203 | trans_list = torch.from_numpy(trans_list)
204 | assert len(imgs) == len(msk_list) == len(trans_list)
205 |
206 | torch.save(msk_list, os.path.join(self.save_dir,'mask_video.pt')) # foreground objs segmentation mask
207 | torch.save(imgs, os.path.join(self.save_dir,'composite.pt'))
208 | torch.save(trans_list, os.path.join(self.save_dir, "trans_list.pt"))
--------------------------------------------------------------------------------
/simulation/sim_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import torch
5 | import torch.nn.functional as F
6 | from PIL import Image
7 | import kornia as K
8 | from omegaconf import OmegaConf
9 |
10 |
11 | def get_bbox(mask):
12 | rows = np.any(mask, axis=1)
13 | cols = np.any(mask, axis=0)
14 | rmin, rmax = np.where(rows)[0][[0, -1]]
15 | cmin, cmax = np.where(cols)[0][[0, -1]]
16 |
17 | return rmin, rmax, cmin, cmax
18 |
19 |
20 | def writing_video(rgb_list, save_path: str, frame_rate: int = 30):
21 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
22 | h, w, _ = rgb_list[0].shape
23 | out = cv2.VideoWriter(save_path, fourcc, frame_rate, (w, h))
24 |
25 | for img in rgb_list:
26 | out.write(img)
27 |
28 | out.release()
29 | return
30 |
31 |
32 | def prep_data(data_dir):
33 |
34 | img_path = os.path.join(data_dir, "original.png")
35 | mask_path = os.path.join(data_dir, "mask.png")
36 | inpaint_path = os.path.join(data_dir, "inpaint.png")
37 | img = np.array(Image.open(img_path)) / 255
38 | mask_img = np.array(Image.open(mask_path))
39 | if mask_img.ndim == 3:
40 | mask_img = mask_img[:, :, 0]
41 | inpaint_img = np.array(Image.open(inpaint_path)).astype(np.float32) / 255
42 | return img, mask_img, inpaint_img
43 |
44 |
45 | def fit_circle_from_mask(mask_image):
46 | contours, _ = cv2.findContours(mask_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
47 |
48 | if len(contours) == 0:
49 | print("No contours found in the mask image.")
50 | return None
51 | max_contour = max(contours, key=cv2.contourArea)
52 | (x, y), radius = cv2.minEnclosingCircle(max_contour)
53 |
54 | center = (x, y)
55 | return center, radius
56 |
57 |
58 | def composite_trans(masked_src_imgs, trans, inpaint_img, active_seg_ids):
59 | """params:
60 | src_imgs: (N, C, H, W)
61 | src_seg: (H, W) original segmentaion mask
62 | trans: (N, 2, 3)
63 | inpaint_img: (C, H, W)
64 | active_seg_ids: (N, ) record segmentaion id
65 | thre: threshold for foreground segmentaion mask
66 | return:
67 |
68 | """
69 | H, W = inpaint_img.shape[:2]
70 | out = K.geometry.warp_affine(masked_src_imgs, trans, (H, W)) # (N, C, H, W)
71 | src_binary_seg = (masked_src_imgs.sum(dim=1, keepdim=True) > 0).float() # (N, 1, H, W)
72 |     out_binary_seg = K.geometry.warp_affine(src_binary_seg, trans, (H, W)) # (N, 1, H, W)
73 |
74 | foreground_msk = out_binary_seg.sum(dim=0).sum(dim=0) > 0 # (H, W)
75 | seg_map = torch.zeros((H, W)).long()
76 | seg_map[~foreground_msk] = 0
77 | seg_mask = out_binary_seg.sum(dim=1).argmax(dim=0) + 1 # (H, W)
78 | seg_map[foreground_msk] = seg_mask[foreground_msk] # 0 is background, 1~N is the fake segmentaion id
79 |
80 | num_classes = len(active_seg_ids) + 1
81 | binary_seg_map = F.one_hot(seg_map, num_classes=num_classes) # (H, W, N+1)
82 | binary_seg_map = binary_seg_map.permute(2, 0, 1).float() # (N+1, H, W)
83 |     binary_seg_map = binary_seg_map.unsqueeze(dim=1) # (N+1, 1, H, W)
84 |
85 | inpaint_img = torch.from_numpy(inpaint_img).permute(2, 0, 1).unsqueeze(0) # (1, C, H, W)
86 | input = torch.cat([inpaint_img, out], dim=0) # (N+1, C, H, W)
87 |
88 | composite = (binary_seg_map * input).sum(0) # (C, H, W)
89 |
90 | composite = composite.permute(1, 2, 0).numpy() # (H, W, C)
91 | final_frame = (composite*255).astype(np.uint8)
92 |
93 | final_seg_map = torch.zeros((H, W)).long()
94 | for idx, seg_id in enumerate(active_seg_ids):
95 | final_seg_map[seg_map==idx+1] = seg_id
96 |
97 | final_seg_map = final_seg_map.numpy()
98 |
99 | return final_frame, final_seg_map
100 |
101 |
102 | def list_to_numpy(data):
103 | if isinstance(data, list):
104 | try:
105 | return np.array(data)
106 | except ValueError:
107 | return data
108 | elif isinstance(data, dict) or isinstance(data, OmegaConf):
109 | for key in data:
110 | data[key] = list_to_numpy(data[key])
111 | return data
112 |     return data  # scalars and other values pass through unchanged
--------------------------------------------------------------------------------
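`AnimateImgSeg.record` saves three tensors that `relight/relight.py` later loads through `prepare_input`; a quick shape check for a run saved under, say, `../outputs/pool` would look like:

```python
import torch

out = "../outputs/pool"                       # wherever --save_root/<cat> pointed
comp = torch.load(f"{out}/composite.pt")      # (frames, 3, H, W), uint8 in [0, 255]
masks = torch.load(f"{out}/mask_video.pt")    # (frames, H, W), foreground seg ids
trans = torch.load(f"{out}/trans_list.pt")    # (frames, num_objs, 2, 3) affines
assert comp.shape[0] == masks.shape[0] == trans.shape[0]
```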