├── .gitignore
├── prompts
│   ├── single_instruction.txt
│   ├── similar_noun
│   │   ├── synonym_bowl.txt
│   │   ├── synonym_pot.txt
│   │   ├── synonym_laptop.txt
│   │   ├── synonym_knife.txt
│   │   ├── synonym_earphone.txt
│   │   ├── synonym_cap.txt
│   │   ├── synonym_jar.txt
│   │   ├── synonym_faucet.txt
│   │   ├── synonym_telephone.txt
│   │   ├── synonym_mug.txt
│   │   └── synonym_bag.txt
│   ├── phone_on_base.txt
│   ├── mug_40.txt
│   ├── general_tasks_40.txt
│   ├── rand_verb_40.txt
│   ├── hamlet.txt
│   ├── similar_verb_40.txt
│   ├── rand_noun
│   │   ├── synonym_bag.txt
│   │   ├── synonym_cap.txt
│   │   ├── synonym_jar.txt
│   │   ├── synonym_mug.txt
│   │   ├── synonym_pot.txt
│   │   ├── synonym_bowl.txt
│   │   ├── synonym_earphone.txt
│   │   ├── synonym_faucet.txt
│   │   ├── synonym_knife.txt
│   │   ├── synonym_laptop.txt
│   │   └── synonym_telephone.txt
│   ├── human_mug_statement_40.txt
│   ├── jar.txt
│   ├── robot_mug_statement_40.txt
│   └── mug.txt
├── method.png
├── common
│   ├── synonym_remote.txt
│   ├── assets
│   │   └── textures
│   │       ├── floor
│   │       │   ├── carpet.png
│   │       │   ├── forest.png
│   │       │   └── rubber_mats.png
│   │       ├── table
│   │       │   ├── puzzle.png
│   │       │   ├── stove.png
│   │       │   ├── stove2.png
│   │       │   └── stove3.png
│   │       └── wall
│   │           ├── cabinets.png
│   │           ├── fridge.png
│   │           ├── posters.png
│   │           ├── wardrobe.png
│   │           ├── white_wall.png
│   │           ├── kitchen_hanger.png
│   │           ├── kitchen_hanger2.png
│   │           └── kitchen_hanger3.png
│   ├── __init__.py
│   ├── counter.py
│   ├── when.py
│   ├── rnd.py
│   ├── driver.py
│   ├── mae_utils.py
│   ├── flags.py
│   ├── base_envs.py
│   ├── rlbench_utils.py
│   ├── expl.py
│   ├── dists.py
│   ├── logger.py
│   ├── tfutils.py
│   ├── config.py
│   ├── kitchen.py
│   ├── other.py
│   ├── replay.py
│   ├── nets.py
│   └── envs.py
├── LICENSE
├── r3mreward.py
├── README.md
├── env.yml
├── configs.yaml
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 |
--------------------------------------------------------------------------------
/prompts/single_instruction.txt:
--------------------------------------------------------------------------------
1 | Pick up the [NOUN]
--------------------------------------------------------------------------------
/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/method.png
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_bowl.txt:
--------------------------------------------------------------------------------
1 | bowl
2 | dish
3 | soup bowl
4 | wooden bowl
5 | plastic bowl
--------------------------------------------------------------------------------
/common/synonym_remote.txt:
--------------------------------------------------------------------------------
1 | remote control
2 | remote
3 | clicker
4 | controller
5 | TV remote
6 | video remote
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_pot.txt:
--------------------------------------------------------------------------------
1 | pot
2 | flowerpot
3 | vase
4 | container
5 | plant pot
6 | flower holder
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_laptop.txt:
--------------------------------------------------------------------------------
1 | laptop
2 | notebook
3 | ultrabook
4 | netbook
5 | tablet
6 | chromebook
7 | computer
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_knife.txt:
--------------------------------------------------------------------------------
1 | knife
2 | blade
3 | cutter
4 | chopper
5 | cleaver
6 | pocket knife
7 | Swiss army knife
--------------------------------------------------------------------------------
/common/assets/textures/floor/carpet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/floor/carpet.png
--------------------------------------------------------------------------------
/common/assets/textures/floor/forest.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/floor/forest.png
--------------------------------------------------------------------------------
/common/assets/textures/table/puzzle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/table/puzzle.png
--------------------------------------------------------------------------------
/common/assets/textures/table/stove.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/table/stove.png
--------------------------------------------------------------------------------
/common/assets/textures/table/stove2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/table/stove2.png
--------------------------------------------------------------------------------
/common/assets/textures/table/stove3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/table/stove3.png
--------------------------------------------------------------------------------
/common/assets/textures/wall/cabinets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/cabinets.png
--------------------------------------------------------------------------------
/common/assets/textures/wall/fridge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/fridge.png
--------------------------------------------------------------------------------
/common/assets/textures/wall/posters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/posters.png
--------------------------------------------------------------------------------
/common/assets/textures/wall/wardrobe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/wardrobe.png
--------------------------------------------------------------------------------
/common/assets/textures/wall/white_wall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/white_wall.png
--------------------------------------------------------------------------------
/common/assets/textures/floor/rubber_mats.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/floor/rubber_mats.png
--------------------------------------------------------------------------------
/common/assets/textures/wall/kitchen_hanger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/kitchen_hanger.png
--------------------------------------------------------------------------------
/common/assets/textures/wall/kitchen_hanger2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/kitchen_hanger2.png
--------------------------------------------------------------------------------
/common/assets/textures/wall/kitchen_hanger3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/kitchen_hanger3.png
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_earphone.txt:
--------------------------------------------------------------------------------
1 | earphone
2 | headphone
3 | earbud
4 | in-ear monitor
5 | earpiece
6 | earphone set
7 | earphone accessory
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_cap.txt:
--------------------------------------------------------------------------------
1 | cap
2 | hat
3 | snapback
4 | baseball cap
5 | visor
6 | bucket hat
7 | fedora
8 | cowboy hat
9 | knit cap
10 | nightcap
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_jar.txt:
--------------------------------------------------------------------------------
1 | jar
2 | vase
3 | container
4 | bottle
5 | canister
6 | pot
7 | urn
8 | jug
9 | flask
10 | pitcher
11 | bucket
12 | pail
13 | basket
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_faucet.txt:
--------------------------------------------------------------------------------
1 | faucet
2 | tap
3 | valve
4 | spout
5 | handle
6 | lever
7 | knob
8 | spigot
9 | water dispenser
10 | sink faucet
11 | bathtub faucet
12 | drinking fountain
13 | water faucet
14 | mixing faucet
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_telephone.txt:
--------------------------------------------------------------------------------
1 | telephone
2 | phone
3 | mobile phone
4 | cell phone
5 | landline
6 | cordless phone
7 | smart phone
8 | flip phone
9 | touchscreen phone
10 | wireless phone
11 | home phone
12 | office phone
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_mug.txt:
--------------------------------------------------------------------------------
1 | mug
2 | cup
3 | tumbler
4 | teacup
5 | coffee cup
6 | espresso cup
7 | beer mug
8 | wine glass
9 | goblet
10 | flute
11 | martini glass
12 | punch cup
13 | thermal mug
14 | plastic cup
15 | styrofoam cup
16 | paper cup
--------------------------------------------------------------------------------
/prompts/similar_noun/synonym_bag.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
--------------------------------------------------------------------------------
/prompts/phone_on_base.txt:
--------------------------------------------------------------------------------
1 | phone on base
2 | pick up the phone and place it on the base
3 | the robot gripper picks up the phone and places it on the base
4 | reach toward the phone and move it to the base
5 | the robot arm grasps the phone and sets it down
6 | put the phone on the base
7 | the robot is placing the phone on the box
8 | seize the phone and drop it on the box
9 | robot arm grasped the phone and dropped it on the square base
10 | the arm is picking up the phone and placing it on the base
11 |
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
1 | # General tools.
2 | from .config import *
3 | from .counter import *
4 | from .flags import *
5 | from .logger import *
6 | from .when import *
7 |
8 | # RL tools.
9 | from .other import *
10 | from .driver import *
11 | from .rlbench_utils import *
12 | from .envs import *
13 | # from .kitchen import *
14 | from .replay import *
15 |
16 | # TensorFlow tools.
17 | from .tfutils import *
18 | from .dists import *
19 |
20 | from .nets import *
21 | from .mae_utils import *
22 | from .mae import *
23 | from .expl import *
24 | from .rnd import *
--------------------------------------------------------------------------------
/common/counter.py:
--------------------------------------------------------------------------------
1 | import functools
2 | 
3 | 
4 | @functools.total_ordering
5 | class Counter:
6 |     def __init__(self, initial=0):
7 |         self.value = initial
8 | 
9 |     def __int__(self):
10 |         return int(self.value)
11 | 
12 |     def __eq__(self, other):
13 |         return int(self) == other
14 | 
15 |     def __ne__(self, other):
16 |         return int(self) != other
17 | 
18 |     def __lt__(self, other):
19 |         return int(self) < other
20 | 
21 |     def __add__(self, other):
22 |         return int(self) + other
23 | 
24 |     def increment(self, amount=1):
25 |         self.value += amount
26 | 
--------------------------------------------------------------------------------
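A minimal usage sketch for `Counter` (not part of the repo; assumes the repo root is on `PYTHONPATH`). The operator overloads above let a `Counter` stand in for a plain int in comparisons and arithmetic, while `increment` mutates it in place:

```python
# Hypothetical usage of common/counter.py (not repository code).
from common.counter import Counter

step = Counter()
step.increment(500)
assert step == 500        # __eq__ compares through int(self)
assert step < 1000        # __lt__; @functools.total_ordering derives <=, >, >=
assert step + 10 == 510   # __add__ returns a plain int, not a Counter
print(int(step))          # -> 500
```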
/prompts/mug_40.txt:
--------------------------------------------------------------------------------
1 | Grasp the mug
2 | Extend your hand for the mug
3 | Retrieve the mug
4 | Take hold of the mug
5 | Stretch out for the mug
6 | Clutch the mug
7 | Seize the mug
8 | Lay your hands on the mug
9 | Pick up the mug
10 | Get a grip on the mug
11 | Reach out for the cup
12 | Grip the cup
13 | Hold the mug
14 | Stretch for the mug
15 | Snatch the mug
16 | Embrace the mug
17 | Catch the mug
18 | Grapple for the mug
19 | Take the mug
20 | Secure the mug
21 | Cling to the mug
22 | Obtain the mug
23 | Acquire the mug
24 | Clasp the mug
25 | Nudge the mug
26 | Bring the mug closer
27 | Draw the mug nearer
28 | Gather the mug
29 | Gain possession of the mug
30 | Lay hold of the mug
31 | Retrieve the cup
32 | Clutch the cup
33 | Hold onto the cup
34 | Reach out for the beverage holder
35 | Extend your arm for the coffee mug
36 | Grab the mug
37 | Obtain the coffee vessel
38 | Snag the mug
39 | Get a hold of the mug
40 | Hook the mug
--------------------------------------------------------------------------------
/common/when.py:
--------------------------------------------------------------------------------
1 | class Every:
2 |     def __init__(self, every):
3 |         self._every = every
4 |         self._last = None
5 | 
6 |     def __call__(self, step):
7 |         step = int(step)
8 |         if not self._every:
9 |             return False
10 |         if self._last is None:
11 |             self._last = step
12 |             return True
13 |         if step >= self._last + self._every:
14 |             self._last += self._every
15 |             return True
16 |         return False
17 | 
18 | 
19 | class Once:
20 |     def __init__(self):
21 |         self._once = True
22 | 
23 |     def __call__(self):
24 |         if self._once:
25 |             self._once = False
26 |             return True
27 |         return False
28 | 
29 | 
30 | class Until:
31 |     def __init__(self, until):
32 |         self._until = until
33 | 
34 |     def __call__(self, step):
35 |         step = int(step)
36 |         if not self._until:
37 |             return True
38 |         return step < self._until
39 | 
--------------------------------------------------------------------------------
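`Every`, `Once`, and `Until` gate periodic work off a monotonically increasing step counter. A minimal sketch of assumed usage (the real call sites are in train.py, which is not included in this dump):

```python
# Hypothetical usage of common/when.py (not repository code).
from common.when import Every, Once, Until

should_log = Every(1000)     # True on the first call, then every 1000 steps
should_setup = Once()        # True exactly once
should_train = Until(5000)   # True while step < 5000 (always True if until is falsy)

for step in range(0, 6000, 500):
    if should_setup():
        print("one-time setup")
    if should_log(step):
        print(f"logging at step {step}")
    if not should_train(step):
        print(f"stopping at step {step}")
        break
```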
/prompts/general_tasks_40.txt:
--------------------------------------------------------------------------------
1 | Close the window blinds
2 | Set the kitchen timer
3 | Sweep the kitchen floor
4 | Turn off the stove
5 | Dust the bookshelf
6 | Empty the kitchen trash can
7 | Arrange the spices alphabetically
8 | Turn on the kitchen fan
9 | Polish the kitchen sink faucet
10 | Put away clean dishes
11 | Take out the recycling
12 | Fluff the couch pillows
13 | Straighten the curtains
14 | Check the expiration date on the milk
15 | Replace the hand soap
16 | Sweep the front porch
17 | Take a sip of water
18 | Put away your shoes
19 | Change the tablecloth
20 | Feed the cat
21 | Take a deep breath
22 | Wipe down the kitchen counters
23 | Replace the batteries in the smoke detector
24 | Water the indoor plants
25 | Wash your hands
26 | Straighten the picture frames
27 | Refill the salt shaker
28 | Unload the dishwasher
29 | Turn on the kitchen lights
30 | Organize the fridge contents
31 | Take a quick snack break
32 | Tidy up the entryway
33 | Replace the dish towel
34 | Dust the lampshade
35 | Sweep the outdoor patio
36 | Put away the cereal box
37 | Check the weather forecast
38 | Test the oven temperature
39 | Wipe the mirror clean
40 | Take a moment to stretch
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023 Ademi Adeniji, Amber Xie, Carmelo Sferrazza, Younggyo Seo, Stephen James, Pieter Abbeel
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining
4 | a copy of this software and associated documentation files (the
5 | "Software"), to deal in the Software without restriction, including
6 | without limitation the rights to use, copy, modify, merge, publish,
7 | distribute, sublicense, and/or sell copies of the Software, and to
8 | permit persons to whom the Software is furnished to do so, subject to
9 | the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be
12 | included in all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/prompts/rand_verb_40.txt:
--------------------------------------------------------------------------------
1 | The [NOUN] is seized
2 | The [NOUN] is clutched
3 | The [NOUN] is gripped
4 | The [NOUN] is firmly grasped
5 | The [NOUN] is tightly held
6 | The [NOUN] is firmly caught
7 | The [NOUN] is securely clasped
8 | The [NOUN] is rotated
9 | The [NOUN] has been flipped
10 | The [NOUN] has been knotted
11 | The [NOUN] has been folded
12 | The [NOUN] has been rinsed
13 | The [NOUN] has been filled
14 | The [NOUN] is shaken
15 | The [NOUN] has been scooped
16 | The [NOUN] is poured
17 | The [NOUN] has been scrubbed
18 | The [NOUN] is tilted
19 | The [NOUN] has been heated
20 | Reach for the [NOUN]
21 | Grasp at the [NOUN]
22 | Stretch out to touch the [NOUN]
23 | Move your arm towards the [NOUN]
24 | Use the gripper to rinse the [NOUN]
25 | Position the end effector to fold the [NOUN]
26 | Reach out the robotic arm to wipe the [NOUN]
27 | Utilize the gripper to seize the [NOUN]
28 | Guide the robotic arm to obtain the [NOUN]
29 | Maneuver the end effector to lift up the [NOUN]
30 | Extend your hand towards the [NOUN]
31 | Reach out your hand to acquire the [NOUN]
32 | Guide your arm to rotate the [NOUN]
33 | Maneuver your hand to shake up the [NOUN]
34 | Flip the [NOUN]
35 | Tap the [NOUN]
36 | Fold the [NOUN]
37 | Rotate the [NOUN]
38 | Brush the [NOUN]
39 | Twist the [NOUN]
40 | Wipe the [NOUN]
--------------------------------------------------------------------------------
/prompts/hamlet.txt:
--------------------------------------------------------------------------------
1 | Holla, Barnardo.
2 | BARNARDO Say, what, is Horatio there?
3 | HORATIO A piece of him.
4 | BARNARDO
5 | Welcome, Horatio.—Welcome, good Marcellus.
6 | HORATIO
7 | What, has this thing appeared again tonight?
8 | BARNARDO I have seen nothing.
9 | MARCELLUS
10 | Horatio says ’tis but our fantasy
11 | And will not let belief take hold of him
12 | Touching this dreaded sight twice seen of us.
13 | Therefore I have entreated him along
14 | With us to watch the minutes of this night,
15 | That, if again this apparition come,
16 | He may approve our eyes and speak to it.
17 | Tush, tush, ’twill not appear.
18 | BARNARDO Sit down awhile,
19 | How now, Horatio, you tremble and look pale.
20 | Is not this something more than fantasy?
21 | What think you on ’t?
22 | At least the whisper goes so: our last king,
23 | Whose image even but now appeared to us,
24 | Was, as you know, by Fortinbras of Norway,
25 | Thereto pricked on by a most emulate pride,
26 | Dared to the combat; in which our valiant Hamlet
27 | (For so this side of our known world esteemed him)
28 | Did slay this Fortinbras, who by a sealed compact,
29 | Well ratified by law and heraldry,
30 | Did forfeit, with his life, all those his lands
31 | Which he stood seized of, to the conqueror.
32 | Against the which a moiety competent
33 | Was gagèd by our king, which had returned
34 | To the inheritance of Fortinbras
35 | Had he been vanquisher, as, by the same comart
36 | And carriage of the article designed,
37 | His fell to Hamlet. Now, sir, young Fortinbras,
38 | Of unimprovèd mettle hot and full,
39 | Hath in the skirts of Norway here and there
40 | Sharked up a list of lawless resolutes
--------------------------------------------------------------------------------
/prompts/similar_verb_40.txt:
--------------------------------------------------------------------------------
1 | Pick up the [NOUN]
2 | Lift the [NOUN] with your hands
3 | Hold the [NOUN] in your grasp
4 | Take hold of the [NOUN] and raise it
5 | Grasp the [NOUN] firmly and lift it up
6 | Raise the [NOUN] by picking it up
7 | Retrieve the [NOUN] and hold it up
8 | Lift the [NOUN] by gripping it
9 | Seize the [NOUN] and raise it off the surface
10 | Hold onto the [NOUN] and lift it up
11 | The [NOUN] is lifted up
12 | The [NOUN] is picked up off the ground
13 | The [NOUN] is raised up by hand
14 | The [NOUN] is grasped and lifted up
15 | The [NOUN] is taken up by hand
16 | The [NOUN] is retrieved and lifted up
17 | The [NOUN] is lifted off its surface
18 | The [NOUN] is elevated by being picked up
19 | The [NOUN] is hoisted up by hand
20 | The [NOUN] is scooped up and lifted
21 | The [NOUN] is lifted by the hand
22 | The [NOUN] is grasped and picked up
23 | The [NOUN] is raised by the palm
24 | The [NOUN] is taken up by the fingers
25 | The [NOUN] is held and lifted up
26 | The [NOUN] is lifted off the surface by the arm
27 | The [NOUN] is picked up and held by the wrist
28 | The [NOUN] is scooped up by the palm and lifted
29 | The [NOUN] is elevated by the hand
30 | The [NOUN] is taken up by the fingers of the hand
31 | The [NOUN] is grasped and raised
32 | The [NOUN] is lifted by the gripper
33 | The end effector picks up the [NOUN]
34 | The arm lifts the [NOUN]
35 | The [NOUN] is held aloft by the robotic hand
36 | The robotic gripper secures the [NOUN]
37 | The [NOUN] is lifted off the surface by the robotic arm
38 | The robotic manipulator seizes and elevates the [NOUN]
39 | The robotic end effector clasps and hoists the [NOUN]
40 | The [NOUN] is taken up by the robotic gripper
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_bag.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_cap.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_jar.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_mug.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_pot.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_bowl.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_earphone.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_faucet.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_knife.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_laptop.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/rand_noun/synonym_telephone.txt:
--------------------------------------------------------------------------------
1 | bag
2 | handbag
3 | purse
4 | clutch
5 | tote
6 | backpack
7 | knapsack
8 | satchel
9 | shoulder bag
10 | duffel bag
11 | messenger bag
12 | grip
13 | briefcase
14 | pouch
15 | fanny pack
16 | drawstring bag
17 | beach bag
18 | grocery bag
19 | shopping bag
20 | gift bag
21 | lunch bag
22 | laptop bag
23 | travel bag
24 | bowl
25 | dish
26 | soup bowl
27 | wooden bowl
28 | plastic bowl
29 | cap
30 | hat
31 | snapback
32 | baseball cap
33 | visor
34 | bucket hat
35 | fedora
36 | cowboy hat
37 | knit cap
38 | nightcap
39 | earphone
40 | headphone
41 | earbud
42 | in-ear monitor
43 | earpiece
44 | earphone set
45 | earphone accessory
46 | faucet
47 | tap
48 | valve
49 | spout
50 | handle
51 | lever
52 | knob
53 | spigot
54 | water dispenser
55 | sink faucet
56 | bathtub faucet
57 | drinking fountain
58 | water faucet
59 | mixing faucet
60 | jar
61 | vase
62 | container
63 | bottle
64 | canister
65 | pot
66 | urn
67 | jug
68 | flask
69 | pitcher
70 | bucket
71 | pail
72 | basket
73 | knife
74 | blade
75 | cutter
76 | chopper
77 | cleaver
78 | pocket knife
79 | Swiss army knife
80 | laptop
81 | notebook
82 | ultrabook
83 | netbook
84 | tablet
85 | chromebook
86 | computer
87 | mug
88 | cup
89 | tumbler
90 | teacup
91 | coffee cup
92 | espresso cup
93 | beer mug
94 | wine glass
95 | goblet
96 | flute
97 | martini glass
98 | punch cup
99 | thermal mug
100 | plastic cup
101 | styrofoam cup
102 | paper cup
103 | pot
104 | flowerpot
105 | vase
106 | container
107 | plant pot
108 | flower holder
109 | telephone
110 | phone
111 | mobile phone
112 | cell phone
113 | landline
114 | cordless phone
115 | smart phone
116 | flip phone
117 | touchscreen phone
118 | wireless phone
119 | home phone
120 | office phone
--------------------------------------------------------------------------------
/prompts/human_mug_statement_40.txt:
--------------------------------------------------------------------------------
1 | The fingers are stretching towards the mug
2 | The palm is extending towards the mug
3 | The hand is grasping for the mug
4 | The arm is trying to retrieve the mug
5 | The person is moving towards the mug with their hand
6 | The human is inching towards the mug
7 | The fingers are approaching the mug with their grip
8 | The palm is leaning towards the mug
9 | The hand is aiming for the mug with its fingers
10 | The arm is going after the mug
11 | The fingers are making a move for the mug
12 | The palm is making a reach for the mug
13 | The hand is making an effort to reach the mug
14 | The arm is making a play for the mug
15 | The fingers are stretching out to grab the mug
16 | The palm is angling for the mug
17 | The hand is snatching for the mug
18 | The arm is closing in on the mug with its hand
19 | The fingers are clenching towards the mug
20 | The hand is coveting the mug
21 | The palm is yearning for the mug
22 | The arm is desiring to hold the mug
23 | The fingers are eagerly trying to take the mug
24 | The hand is avidly reaching for the mug
25 | The palm is enthusiastically extending towards the mug
26 | The arm is zealously attempting to reach the mug
27 | The fingers are ambitiously grasping for the mug
28 | The hand is fiercely reaching for the mug
29 | The palm is vigorously attempting to take the mug
30 | The arm is robustly reaching for the mug
31 | The fingers are ardently going after the mug
32 | The hand is intently reaching for the mug
33 | The palm is doggedly grasping for the mug
34 | The arm is persistently attempting to take the mug
35 | The fingers are tenaciously reaching for the mug
36 | The hand is determined to grasp the mug
37 | The palm is unwaveringly reaching for the mug
38 | The arm is steadfastly trying to take the mug
39 | The fingers are hungrily grasping for the mug
40 | The hand is greedily reaching for the mug
--------------------------------------------------------------------------------
/prompts/jar.txt:
--------------------------------------------------------------------------------
1 | Grasp the jar
2 | Extend your arm for the jar
3 | Stretch to grab the jar
4 | Take hold of the jar
5 | Get the jar within your reach
6 | Snatch the jar
7 | Reach out for the jar
8 | Reach forward for the jar
9 | Reach up for the jar
10 | Go for the jar
11 | Aim for the jar
12 | Clutch the jar
13 | Lay your hand on the jar
14 | Nab the jar
15 | Catch the jar
16 | Hook the jar
17 | Secure the jar
18 | Seize the jar
19 | Grip the jar
20 | Snag the jar
21 | Acquire the jar
22 | Take possession of the jar
23 | Hold the jar
24 | Clasp the jar
25 | Embrace the jar
26 | Enclose the jar
27 | Envelop the jar
28 | Gather the jar
29 | Collect the jar
30 | Retrieve the jar
31 | Access the jar
32 | Attain the jar
33 | Capture the jar
34 | Draw in the jar
35 | Bring the jar closer
36 | Take the jar
37 | Take a hold of the jar
38 | Take the jar in your hand
39 | Bring the jar to yourself
40 | Take possession of the jar
41 | Lay hands on the jar
42 | Clench the jar
43 | Grab the jar
44 | Obtain the jar
45 | Take up the jar
46 | Pluck the jar
47 | Snatch up the jar
48 | Scoop up the jar
49 | Gather up the jar
50 | Close in on the jar
51 | Ensnare the jar
52 | Embrace the jar tightly
53 | Hold onto the jar
54 | Embody the jar
55 | Hold onto the jar tightly
56 | Take a firm grip of the jar
57 | Get your hands on the jar
58 | Embrace the jar with both hands
59 | Grasp onto the jar tightly
60 | Hold the jar with a firm grip
61 | Securely hold the jar
62 | Keep the jar within reach
63 | Hold the jar close
64 | Draw the jar closer
65 | Haul in the jar
66 | Take hold of the jar with purpose
67 | Take the jar with confidence
68 | Reach out and secure the jar
69 | Get a firm grasp on the jar
70 | Pick up the jar
71 | Take the jar into your hand
72 | Grip the jar firmly
73 | Tightly clasp the jar
74 | Hold the jar securely
75 | Bring the jar towards you
76 | Take control of the jar
77 | Hold onto the jar with determination
78 | Retrieve the jar with care
79 | Grasp the jar with assurance
80 | Keep the jar within arm's reach
--------------------------------------------------------------------------------
/prompts/robot_mug_statement_40.txt:
--------------------------------------------------------------------------------
1 | The automaton is extending its arm towards the cup
2 | The machine is grasping for the mug
3 | The android is stretching towards the mug
4 | The robotic arm is attempting to take the mug
5 | The gripper is trying to retrieve the mug
6 | The mechanical arm is moving towards the mug with its hand
7 | The cyborg is inching towards the mug
8 | The robotic fingers are approaching the mug with its grip
9 | The machine is leaning towards the mug
10 | The robot is aiming for the cup with its gripper
11 | The automaton is going after the mug
12 | The cyborg is making a move for the mug
13 | The robotic arm is making a reach for the mug
14 | The machine is making an effort to reach the mug
15 | The android is making a play for the mug
16 | The robotic fingers are stretching out to grab the mug
17 | The gripper is angling for the mug
18 | The robot is snatching for the mug
19 | The machine is closing in on the mug with its hand
20 | The robotic fingers are clenching towards the mug
21 | The automaton is coveting the mug
22 | The android is yearning for the mug
23 | The machine is desiring to hold the mug
24 | The robot is yearning to possess the mug
25 | The cyborg is eagerly trying to take the mug
26 | The robotic arm is avidly reaching for the mug
27 | The machine is enthusiastically extending its arm for the mug
28 | The android is zealously attempting to reach the mug
29 | The robot is ambitiously grasping for the mug
30 | The automaton is fiercely reaching for the mug
31 | The cyborg is vigorously attempting to take the mug
32 | The robotic arm is robustly reaching for the mug
33 | The machine is ardently going after the mug
34 | The android is intently reaching for the mug
35 | The gripper is doggedly grasping for the mug
36 | The robot is persistently attempting to take the mug
37 | The mechanical arm is tenaciously reaching for the mug
38 | The machine is determined to grasp the mug
39 | The robotic fingers are unwaveringly reaching for the mug
40 | The android is steadfastly trying to take the mug
--------------------------------------------------------------------------------
/common/rnd.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | class Predictor(tf.keras.Model):
4 |     def __init__(self, input_shape, hidden_dim, output_dim):
5 |         super(Predictor, self).__init__()
6 |         self.conv = tf.keras.Sequential([
7 |             tf.keras.layers.Conv2D(32, kernel_size=8, strides=4, activation='relu'),
8 |             tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, activation='relu'),
9 |             tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, activation='relu'),
10 |             tf.keras.layers.Flatten()
11 |         ])
12 |         conv_output_size = self._get_conv_output(input_shape)  # builds conv weights; size unused (Dense infers input dim)
13 |         self.fc = tf.keras.Sequential([
14 |             tf.keras.layers.Dense(hidden_dim, activation='relu'),
15 |             tf.keras.layers.Dense(output_dim)
16 |         ])
17 | 
18 |     def call(self, x):
19 |         x = self.conv(x)
20 |         x = self.fc(x)
21 |         return x
22 | 
23 |     def _get_conv_output(self, shape):
24 |         x = tf.zeros((1, *shape))  # dummy forward pass to measure the flattened conv output
25 |         x = self.conv(x)
26 |         return int(tf.reduce_prod(x.shape[1:]))
27 | 
28 | class RandomNet(tf.keras.Model):
29 |     def __init__(self, input_shape, output_dim):
30 |         super(RandomNet, self).__init__()
31 |         self.conv = tf.keras.Sequential([
32 |             tf.keras.layers.Conv2D(32, kernel_size=8, strides=4, activation='relu'),
33 |             tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, activation='relu'),
34 |             tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, activation='relu'),
35 |             tf.keras.layers.Flatten()
36 |         ])
37 |         conv_output_size = self._get_conv_output(input_shape)  # builds conv weights; size unused (Dense infers input dim)
38 |         self.fc = tf.keras.Sequential([
39 |             tf.keras.layers.Dense(output_dim)
40 |         ])
41 | 
42 |     def call(self, x):
43 |         x = self.conv(x)
44 |         x = self.fc(x)
45 |         return x
46 | 
47 |     def _get_conv_output(self, shape):
48 |         x = tf.zeros((1, *shape))  # dummy forward pass to measure the flattened conv output
49 |         x = self.conv(x)
50 |         return int(tf.reduce_prod(x.shape[1:]))
51 | 
--------------------------------------------------------------------------------
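`Predictor` and `RandomNet` are the two halves of Random Network Distillation: a frozen, randomly initialized target network and a trained predictor whose prediction error serves as an exploration bonus, since states the predictor has not yet fit score high. A minimal sketch of that wiring, with assumed shapes and hyperparameters (the repo's actual exploration code lives elsewhere, presumably common/expl.py):

```python
# Hypothetical RND wiring (not repository code); shapes and learning rate are assumptions.
import tensorflow as tf
from common.rnd import Predictor, RandomNet

obs_shape = (64, 64, 3)
target = RandomNet(obs_shape, output_dim=128)                     # stays frozen
predictor = Predictor(obs_shape, hidden_dim=256, output_dim=128)  # gets trained
opt = tf.keras.optimizers.Adam(1e-4)

def intrinsic_reward(obs):
    # Per-observation squared prediction error; novel states predict poorly.
    return tf.reduce_mean((predictor(obs) - tf.stop_gradient(target(obs))) ** 2, -1)

def train_step(obs):
    # Only the predictor is updated; the target keeps its random init.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(intrinsic_reward(obs))
    grads = tape.gradient(loss, predictor.trainable_variables)
    opt.apply_gradients(zip(grads, predictor.trainable_variables))
    return loss

batch = tf.random.uniform((16, *obs_shape))
print(train_step(batch))
```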
/r3mreward.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | import torch
3 | import numpy as np
4 | 
5 | class R3MReward(object):
6 |     def __init__(self, embed_model, sentences, standardize_rewards, queue_size, update_stats_steps, num_top_images, use_lang_embeddings):
7 |         self.embed_model = embed_model
8 |         self.reward_model = self.embed_model.get_reward
9 |         self.sentences = sentences
10 |         self.standardize_rewards = standardize_rewards
11 |         self.use_sharding_r3m = False
12 |         self.log_top_images = True
13 |         self.r3m_reward_bonus = 10.0
14 |         self.update_stats_steps = update_stats_steps
15 |         self.use_lang_embeddings = use_lang_embeddings
16 |         if self.standardize_rewards:
17 |             self.stats = {}
18 |             for t in self.sentences:
19 |                 self.stats[t] = deque(maxlen=queue_size)
20 |         if self.log_top_images:
21 |             self.top_images = {}
22 |             for t in self.sentences:
23 |                 self.top_images[t] = {"images": [], "rewards": []}
24 |         self.num_top_images = num_top_images
25 | 
26 |     def get_lang_encoding(self, lang):
27 |         return self.embed_model.lang_enc(lang)
28 | 
29 |     def get_reward(self, init, curr, lang, lang_emb, step=None):
30 |         if isinstance(lang, int):
31 |             lang_strings = [self.sentences[lang]]
32 |             init_image = torch.unsqueeze(init, 0)
33 |             curr_image = torch.unsqueeze(curr, 0)
34 |         else:
35 |             lang_strings = [self.sentences[i] for i in lang[:, 0]]
36 |             init_image = init
37 |             curr_image = curr
38 | 
39 | 
40 |         init = self.embed_model(init_image[:, -3:, :, :])
41 |         curr = self.embed_model(curr_image[:, -3:, :, :])
42 |         if self.use_lang_embeddings:
43 |             reward = self.embed_model.get_reward_le(init, curr, lang_emb)[0].unsqueeze(-1)
44 |         else:
45 |             reward = self.reward_model(init, curr, lang_strings)[0].unsqueeze(-1)
46 | 
47 |         return reward, None, None
--------------------------------------------------------------------------------
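A sketch of how `R3MReward` could be queried for a single frame pair. Everything below is an assumption rather than the repo's actual call site: `load_r3m` is the base R3M loader (the r3m_lamp fork is presumed to add the `get_reward` head this class calls), and the frame shapes are illustrative:

```python
# Hypothetical call into r3mreward.py; loader, shapes, and arguments are assumptions.
import torch
from r3m import load_r3m          # assumed entry point from the r3m_lamp fork
from r3mreward import R3MReward

embed_model = load_r3m("resnet50")
reward_fn = R3MReward(
    embed_model=embed_model,
    sentences=open("prompts/mug_40.txt").read().splitlines(),
    standardize_rewards=False,
    queue_size=1000,
    update_stats_steps=100,
    num_top_images=8,
    use_lang_embeddings=False,
)

init_frame = torch.rand(3, 224, 224)   # first frame of the episode
curr_frame = torch.rand(3, 224, 224)   # current frame
# An int lang indexes into sentences; the batch dim is added internally.
reward, _, _ = reward_fn.get_reward(init_frame, curr_frame, lang=0, lang_emb=None)
```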
/prompts/mug.txt:
--------------------------------------------------------------------------------
1 | Grasp the mug
2 | Extend your hand for the mug
3 | Retrieve the mug
4 | Take hold of the mug
5 | Stretch out for the mug
6 | Clutch the mug
7 | Seize the mug
8 | Lay your hands on the mug
9 | Pick up the mug
10 | Get a grip on the mug
11 | Reach out for the cup
12 | Grip the cup
13 | Hold the mug
14 | Stretch for the mug
15 | Snatch the mug
16 | Embrace the mug
17 | Catch the mug
18 | Grapple for the mug
19 | Take the mug
20 | Secure the mug
21 | Cling to the mug
22 | Obtain the mug
23 | Acquire the mug
24 | Clasp the mug
25 | Nudge the mug
26 | Bring the mug closer
27 | Draw the mug nearer
28 | Gather the mug
29 | Gain possession of the mug
30 | Lay hold of the mug
31 | Retrieve the cup
32 | Clutch the cup
33 | Hold onto the cup
34 | Reach out for the beverage holder
35 | Extend your arm for the coffee mug
36 | Grab the mug
37 | Obtain the coffee vessel
38 | Snag the mug
39 | Get a hold of the mug
40 | Hook the mug
41 | Make a grab for the mug
42 | Stretch your fingers towards the mug
43 | Wrap your hand around the mug
44 | Clasp the warm mug
45 | Hold the handle of the mug
46 | Reach out and touch the mug
47 | Take the mug in your hands
48 | Grasp onto the mug's handle
49 | Retrieve the mug with ease
50 | Snatch the mug quickly
51 | Catch the mug before it falls
52 | Pull the mug towards you
53 | Reach out and grab the mug
54 | Secure the mug with your hand
55 | Get a grip on the coffee mug
56 | Hold onto the mug tightly
57 | Cling onto the mug's handle
58 | Acquire the mug without hesitation
59 | Lay your palm on the mug
60 | Nudge the mug slightly closer
61 | Draw the mug nearer to you
62 | Bring the mug within reach
63 | Obtain the mug in one swift motion
64 | Take hold of the mug firmly
65 | Grapple for the mug's handle
66 | Hook your fingers around the mug
67 | Embrace the mug lovingly
68 | Snuggle your hand around the mug
69 | Hold the mug with both hands
70 | Take the mug in a gentle grip
71 | Retrieve the cup of coffee
72 | Clutch the coffee mug tightly
73 | Hold onto the coffee cup
74 | Extend your arm and take the mug
75 | Grab onto the coffee mug's handle
76 | Obtain the warm mug with care
77 | Snag the mug from the table
78 | Hook the mug and lift it up
79 | Make a move for the mug's handle
80 | Reach for mug
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LAMP
2 |
3 | [**LA**nguage **M**odulated **P**retraining](https://arxiv.org/abs/2308.12270) (LAMP💡) is a method for pretraining a general RL agent for accelerated downstream learning. It augments unsupervised RL rewards with extrinsic rewards parameterized by a Video-Language Model (VLM).
4 |
5 |
6 |
7 | ## Installation
8 | To create a conda environment called `lamp`:
9 | ```bash
10 | conda env create -f env.yml
11 | conda activate lamp
12 | ```
13 |
14 | Then, follow the RLBench installation instructions from [this fork](https://github.com/ademiadeniji/RLBench_lamp) that implements shaped rewards and the domain-randomized pretraining environment.
15 |
16 | Finally, install [our fork](https://github.com/ademiadeniji/r3m_lamp) of the R3M module that enables computing video-language alignment scores.
17 |
18 | ```bash
19 | git clone https://github.com/ademiadeniji/r3m_lamp
20 | pip install -e r3m_lamp
21 | ```
22 |
23 | ## Training
24 | To pretrain your LAMP agent, run:
25 |
26 | ```bash
27 | TF_CPP_MIN_LOG_LEVEL=0 CUDA_VISIBLE_DEVICES=0 TF_XLA_FLAGS=--tf_xla_auto_jit=2 vglrun -d :0.0 python train.py --logdir /YOUR/LOGDIR/HERE --task pick_shapenet_objects --seed 1 --use_r3m_reward True --device cuda:0 --vidlang_model_device cuda:0 --use_lang_embeddings True --configs front_wrist vlsp --curriculum.objects 'bag,bowl,cap,earphone,faucet,jar,knife,laptop,mug,pot,telephone' --curriculum.num_unique_per_class '-1' --curriculum.num_objects '3' --curriculum.lang_prompt 'prompts/similar_verb_40.txt' --curriculum.synonym_folder prompts/similar_noun --curriculum.num_episodes '20000' --randomize True --expl_intr_scale 0.9 --expl_extr_scale 0.1 --plan2explore True
28 | ```
29 |
30 | To finetune your pretrained LAMP agent on the `take_lid_off_saucepan` task, run:
31 |
32 | ```bash
33 | TF_CPP_MIN_LOG_LEVEL=0 CUDA_VISIBLE_DEVICES=0 TF_XLA_FLAGS=--tf_xla_auto_jit=2 vglrun -d :0.0 python train.py --logdir /YOUR/LOGDIR/HERE --task take_lid_off_saucepan --seed 0 --device cuda:0 --vidlang_model_device cuda:0 --use_lang_embeddings True --configs front_wrist vlsp --curriculum.use False --critic_linear_probe True --loaddir [LOADDIR] --ts [NUM_STEPS_PRETRAINED] --plan2explore True --expl_intr_scale 0 --expl_extr_scale 1 --shaped_rewards True
34 | ```
35 | ## Citations
36 | If you use this code for your research, please cite our paper:
37 | ```bibtex
38 | @misc{adeniji2023language,
39 | title={Language Reward Modulation for Pretraining Reinforcement Learning},
40 | author={Ademi Adeniji and Amber Xie and Carmelo Sferrazza and Younggyo Seo and Stephen James and Pieter Abbeel},
41 | year={2023},
42 | eprint={2308.12270},
43 | archivePrefix={arXiv},
44 | primaryClass={cs.LG}
45 | }
46 | ```
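47 |
48 | ## Note on reward composition
49 |
50 | During pretraining, the Plan2Explore intrinsic (ensemble-disagreement) reward and
51 | the VLM alignment reward are each stream-normalized and mixed according to the
52 | `--expl_intr_scale` / `--expl_extr_scale` flags (0.9 / 0.1 in the pretraining
53 | command above). A minimal sketch of the mixing step, with hypothetical variable names:
54 |
55 | ```python
56 | # Sketch only; the actual logic lives in common/expl.py (Plan2Explore._intr_reward).
57 | def mixed_reward(r_intrinsic, r_vlm, intr_scale=0.9, extr_scale=0.1):
58 |     # Both terms are stream-normalized before mixing in the real code.
59 |     return intr_scale * r_intrinsic + extr_scale * r_vlm
60 | ```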
--------------------------------------------------------------------------------
/common/driver.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Driver:
5 | def __init__(self, envs, **kwargs):
6 | self._envs = envs
7 | self._kwargs = kwargs
8 | self._on_steps = []
9 | self._on_resets = []
10 | self._on_episodes = []
11 | self._act_spaces = [env.act_space for env in envs]
12 | self.reset()
13 |
14 | def on_step(self, callback):
15 | self._on_steps.append(callback)
16 |
17 | def on_reset(self, callback):
18 | self._on_resets.append(callback)
19 |
20 | def on_episode(self, callback):
21 | self._on_episodes.append(callback)
22 |
23 | def reset(self):
24 | self._obs = [None] * len(self._envs)
25 | self._eps = [None] * len(self._envs)
26 | self._state = None
27 |
28 | def __call__(self, policy, steps=0, episodes=0):
29 | step, episode = 0, 0
30 | while step < steps or episode < episodes:
31 | obs = {
32 | i: self._envs[i].reset()
33 | for i, ob in enumerate(self._obs)
34 | if ob is None or ob["is_last"]
35 | }
36 | for i, ob in obs.items():
37 | self._obs[i] = ob() if callable(ob) else ob
38 | ob = self._obs[i]
39 | act = {k: np.zeros(v.shape) for k, v in self._act_spaces[i].items()}
40 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()}
41 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_resets]
42 | self._eps[i] = [tran]
43 | obs = {k: np.stack([o[k] for o in self._obs]) for k in self._obs[0]}
44 | if len(self._kwargs.keys()) == 0:
45 | actions, self._state = policy(obs, self._state)
46 | else:
47 | actions, self._state = policy(obs, self._state, **self._kwargs)
48 | actions = [
49 | {k: np.array(actions[k][i]) for k in actions}
50 | for i in range(len(self._envs))
51 | ]
52 | assert len(actions) == len(self._envs)
53 | obs = [e.step(a) for e, a in zip(self._envs, actions)]
54 | obs = [ob() if callable(ob) else ob for ob in obs]
55 | for i, (act, ob) in enumerate(zip(actions, obs)):
56 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()}
57 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_steps]
58 | self._eps[i].append(tran)
59 | step += 1
60 | if ob["is_last"]:
61 | ep = self._eps[i]
62 | ep = {k: self._convert([t[k] for t in ep]) for k in ep[0]}
63 | [fn(ep, **self._kwargs) for fn in self._on_episodes]
64 | episode += 1
65 | self._obs = obs
66 |
67 | def _convert(self, value):
68 | value = np.array(value)
69 | if np.issubdtype(value.dtype, np.floating):
70 | return value.astype(np.float32)
71 | elif np.issubdtype(value.dtype, np.signedinteger):
72 | return value.astype(np.int32)
73 | elif np.issubdtype(value.dtype, np.uint8):
74 | return value.astype(np.uint8)
75 | return value
76 |
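77 | # Usage sketch (hypothetical wiring; the training script supplies real envs,
78 | # a replay buffer, and a policy):
79 | #
80 | #     driver = Driver([env])
81 | #     driver.on_step(lambda tran, worker: replay.add_step(tran))
82 | #     driver.on_episode(lambda ep: print(float(ep["reward"].sum())))
83 | #     driver(policy, episodes=1)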
--------------------------------------------------------------------------------
/common/mae_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.keras.layers as tfkl
3 | from tensorflow.keras import mixed_precision as prec
4 | import numpy as np
5 |
6 |
7 | class Token(tf.keras.layers.Layer):
8 | def __init__(
9 | self,
10 | name,
11 | embed_dim,
12 | **kwargs,
13 | ):
14 | super().__init__(**kwargs)
15 | self._name = name
16 | self.embed_dim = embed_dim
17 | self.mask_token = None
18 |
19 | def build(self, input_shape):
20 | self.mask_token = self.add_weight(
21 | name=f"{self._name}_token",
22 | shape=(1, 1, self.embed_dim),
23 | initializer=tf.random_normal_initializer(stddev=0.02),
24 | trainable=True,
25 | )
26 |
27 | def call(self, x):
28 | return self.mask_token
29 |
30 |
31 | def get_2d_sincos_pos_embed(embed_dim, grid_h_size, grid_w_size):
32 | """
33 | grid_h_size, grid_w_size: ints giving the grid height and width
34 | return:
35 | pos_embed: [grid_h_size*grid_w_size, embed_dim]
36 | """
37 | H, W = grid_h_size, grid_w_size
38 |
39 | grid_h = np.arange(grid_h_size, dtype=np.float32)
40 | grid_w = np.arange(grid_w_size, dtype=np.float32)
41 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
42 | grid = np.stack(grid, axis=0)
43 |
44 | grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
45 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
46 | return pos_embed
47 |
48 |
49 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
50 | assert embed_dim % 2 == 0
51 |
52 | # use half of dimensions to encode grid_h
53 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
54 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
55 |
56 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
57 | return emb
58 |
59 |
60 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
61 | """
62 | embed_dim: output dimension for each position
63 | pos: a list of positions to be encoded: size (M,)
64 | out: (M, D)
65 | """
66 | assert embed_dim % 2 == 0
67 | omega = np.arange(embed_dim // 2, dtype=np.float32)
68 | omega /= embed_dim / 2.0
69 | omega = 1.0 / 10000 ** omega # (D/2,)
70 |
71 | pos = pos.reshape(-1) # (M,)
72 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
73 |
74 | emb_sin = np.sin(out) # (M, D/2)
75 | emb_cos = np.cos(out) # (M, D/2)
76 |
77 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
78 | return emb
79 |
80 |
81 | def get_1d_sincos_pos_embed(embed_dim, grid_size):
82 | """
83 | grid_size: int, number of temporal positions to encode
84 | return:
85 | pos_embed: [grid_size, embed_dim]
86 | """
87 | assert embed_dim % 2 == 0
88 | grid_t = np.arange(grid_size, dtype=np.float32)
89 | grid = np.meshgrid(grid_t)
90 | grid = np.stack(grid, axis=0)
91 |
92 | grid = grid[0]
93 |
94 | omega = np.arange(embed_dim / 2, dtype=np.float32)
95 | omega /= embed_dim / 2
96 | omega = 1.0 / 10000 ** omega
97 | pos = grid
98 | out = np.einsum("m,d->md", pos, omega)
99 |
100 | emb_sin = np.sin(out)
101 | emb_cos = np.cos(out)
102 |
103 | pos_emb = np.concatenate([emb_sin, emb_cos], axis=1)
104 | return pos_emb
105 |
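106 | # Shape check (runs as-is): get_1d_sincos_pos_embed_from_grid(4, np.arange(3))
107 | # returns a (3, 4) array whose columns are [sin(p*w0), sin(p*w1), cos(p*w0),
108 | # cos(p*w1)] for each position p, with frequencies w0 = 1 and w1 = 1/100.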
--------------------------------------------------------------------------------
/common/flags.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 |
4 |
5 | class Flags:
6 | def __init__(self, *args, **kwargs):
7 | from .config import Config
8 |
9 | self._config = Config(*args, **kwargs)
10 |
11 | def parse(self, argv=None, known_only=False, help_exists=None):
12 | if help_exists is None:
13 | help_exists = not known_only
14 | if argv is None:
15 | argv = sys.argv[1:]
16 | if "--help" in argv:
17 | print("\nHelp:")
18 | lines = str(self._config).split("\n")[2:]
19 | print("\n".join("--" + re.sub(r"[:,\[\]]", "", x) for x in lines))
20 | help_exists and sys.exit()
21 | parsed = {}
22 | remaining = []
23 | key = None
24 | vals = None
25 | for arg in argv:
26 | if arg.startswith("--"):
27 | if key:
28 | self._submit_entry(key, vals, parsed, remaining)
29 | if "=" in arg:
30 | key, val = arg.split("=", 1)
31 | vals = [val]
32 | else:
33 | key, vals = arg, []
34 | else:
35 | if key:
36 | vals.append(arg)
37 | else:
38 | remaining.append(arg)
39 | self._submit_entry(key, vals, parsed, remaining)
40 | parsed = self._config.update(parsed)
41 | if known_only:
42 | return parsed, remaining
43 | else:
44 | for flag in remaining:
45 | if flag.startswith("--"):
46 | raise ValueError(f"Flag '{flag}' did not match any config keys.")
47 | assert not remaining, remaining
48 | return parsed
49 |
50 | def _submit_entry(self, key, vals, parsed, remaining):
51 | if not key and not vals:
52 | return
53 | if not key:
54 | vals = ", ".join(f"'{x}'" for x in vals)
55 | raise ValueError(f"Values {vals} were not preceded by any flag.")
56 | name = key[len("--") :]
57 | if "=" in name:
58 | remaining.extend([key] + vals)
59 | return
60 | if self._config.IS_PATTERN.match(name):
61 | pattern = re.compile(name)
62 | keys = {k for k in self._config.flat if pattern.match(k)}
63 | elif name in self._config:
64 | keys = [name]
65 | else:
66 | keys = []
67 | if not keys:
68 | remaining.extend([key] + vals)
69 | return
70 | if not vals:
71 | raise ValueError(f"Flag '{key}' was not followed by any values.")
72 | for key in keys:
73 | parsed[key] = self._parse_flag_value(self._config[key], vals, key)
74 |
75 | def _parse_flag_value(self, default, value, key):
76 | value = value if isinstance(value, (tuple, list)) else (value,)
77 | if isinstance(default, (tuple, list)):
78 | if len(value) == 1 and "," in value[0]:
79 | value = value[0].split(",")
80 | return tuple(self._parse_flag_value(default[0], [x], key) for x in value)
81 | assert len(value) == 1, value
82 | value = str(value[0])
83 | if default is None:
84 | return value
85 | if isinstance(default, bool):
86 | try:
87 | return bool(["False", "True"].index(value))
88 | except ValueError:
89 | message = f"Expected bool but got '{value}' for key '{key}'."
90 | raise TypeError(message)
91 | if isinstance(default, int):
92 | value = float(value) # Allow scientific notation for integers.
93 | if float(int(value)) != value:
94 | message = f"Expected int but got float '{value}' for key '{key}'."
95 | raise TypeError(message)
96 | return int(value)
97 | return type(default)(value)
98 |
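99 | # Usage sketch (hypothetical keys): with defaults Config({"seed": 0, "rssm": {"deter": 1024}}),
100 | #
101 | #     Flags(defaults).parse(["--seed", "1", "--rssm.deter", "512"])
102 | #
103 | # returns an updated Config; a regex key such as "--.*\.deter" would update every
104 | # flat config key matching the pattern at once.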
--------------------------------------------------------------------------------
/common/base_envs.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import os
3 | import re
4 |
5 | import gym
6 | import numpy as np
7 | import pickle
8 | from d4rl.kitchen.adept_envs.simulation.renderer import DMRenderer, MjPyRenderer
9 | from d4rl.kitchen.adept_envs.simulation.sim_robot import RenderMode
10 |
11 |
12 | def get_device_id():
13 | return int(os.environ.get('GL_DEVICE_ID', 0))
14 |
15 |
16 | class DmBenchEnv():
17 |
18 | def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
19 | domain, task = name.split('_', 1)
20 | if domain == 'cup': # Only domain with multiple words.
21 | domain = 'ball_in_cup'
22 | if isinstance(domain, str):
23 | from dm_control import suite
24 | self._env = suite.load(domain, task)
25 | else:
26 | assert task is None
27 | self._env = domain()
28 | self._action_repeat = action_repeat
29 | self._size = size
30 | if camera is None:
31 | camera = dict(quadruped=2).get(domain, 0)
32 | self._camera = camera
33 |
34 | @property
35 | def observation_space(self):
36 | spaces = {}
37 | for key, value in self._env.observation_spec().items():
38 | spaces[key] = gym.spaces.Box(
39 | -np.inf, np.inf, value.shape, dtype=np.float32)
40 | spaces['image'] = gym.spaces.Box(
41 | 0, 255, self._size + (3,), dtype=np.uint8)
42 | return gym.spaces.Dict(spaces)
43 |
44 | def _update_obs(self, obs):
45 | return obs
46 |
47 | @property
48 | def action_space(self):
49 | spec = self._env.action_spec()
50 | return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
51 |
52 | def step(self, action):
53 | assert np.isfinite(action).all(), action
54 | reward = 0
55 | for _ in range(self._action_repeat):
56 | time_step = self._env.step(action)
57 | reward += time_step.reward or 0
58 | if time_step.last():
59 | break
60 | obs = dict(time_step.observation)
61 | obs['image'] = self.render()
62 | done = time_step.last()
63 | info = {'discount': np.array(time_step.discount, np.float32)}
64 |
65 | return self._update_obs(obs), reward, done, info
66 |
67 | def reset(self):
68 | time_step = self._env.reset()
69 | obs = dict(time_step.observation)
70 | obs['image'] = self.render()
71 | return self._update_obs(obs)
72 |
73 | def render(self, *args, **kwargs):
74 | if kwargs.get('mode', 'rgb_array') != 'rgb_array':
75 | raise ValueError("Only render mode 'rgb_array' is supported.")
76 | return self._env.physics.render(*self._size, camera_id=self._camera)
77 |
78 |
79 | class BenchEnv():
80 | LOCK = threading.Lock()
81 |
82 | def __init__(self, action_repeat, width=64):
83 | self._action_repeat = action_repeat
84 | self._width = width
85 | self._size = (self._width, self._width)
86 |
87 | @property
88 | def observation_space(self):
89 | shape = self._size + (3,)
90 | # space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8)
91 | spaces = {"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
92 | }
93 | return spaces
94 | # return gym.spaces.Dict({'image': space})
95 |
96 | @property
97 | def action_space(self):
98 | return self._env.action_space
99 |
100 | def close(self):
101 | return self._env.close()
102 |
103 | def reset(self):
104 | with self.LOCK:
105 | state = self._env.reset()
106 | return self._get_obs(state)
107 |
108 | def step(self, action):
109 | total_reward = 0.0
110 | for step in range(self._action_repeat):
111 | state, reward, done, info = self._env.step(action)
112 | total_reward += reward
113 | if done:
114 | break
115 | obs = self._get_obs(state)
116 | return obs, total_reward, done, info
117 |
118 | def render(self, mode):
119 | return self._env.render(mode, self._width, self._width)
120 |
121 | def render_offscreen(self):
122 | img = self.renderer.render_offscreen(
123 | self._width, self._width, mode=RenderMode.RGB, camera_id=-1)
124 | return np.flipud(np.fliplr(img))
125 |
126 | def _get_obs(self, state):
127 | return {'image': self.render_offscreen(), 'state': state}
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: lamp
2 | channels:
3 | - nvidia
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - ca-certificates=2022.12.7=ha878542_0
10 | - certifi=2022.12.7=pyhd8ed1ab_0
11 | - cudatoolkit=11.1.74=h6bb024c_0
12 | - cudnn=8.2.1.32=h86fa8c9_0
13 | - ld_impl_linux-64=2.38=h1181459_1
14 | - libblas=3.9.0=15_linux64_openblas
15 | - libcblas=3.9.0=15_linux64_openblas
16 | - libffi=3.4.2=h6a678d5_6
17 | - libgcc-ng=11.2.0=h1234567_1
18 | - libgfortran-ng=12.2.0=h69a702a_19
19 | - libgfortran5=12.2.0=h337968e_19
20 | - libgomp=11.2.0=h1234567_1
21 | - liblapack=3.9.0=15_linux64_openblas
22 | - libopenblas=0.3.20=pthreads_h78a6416_0
23 | - libstdcxx-ng=11.2.0=h1234567_1
24 | - ncurses=6.4=h6a678d5_0
25 | - numpy
26 | - openssl=1.1.1t=h7f8727e_0
27 | - pip==22.3.1
28 | - python=3.8.16=h7a1cb2a_2
29 | - python_abi=3.8=2_cp38
30 | - readline=8.2=h5eee18b_0
31 | - sqlite=3.40.1=h5082296_0
32 | - tk=8.6.12=h1ccaba5_0
33 | - wheel=0.37.1=pyhd3eb1b0_0
34 | - xz=5.2.10=h5eee18b_1
35 | - zlib=1.2.13=h5eee18b_0
36 | - pip:
37 | - pip==22.3.1
38 | - absl-py==0.15.0
39 | - albumentations==1.2.1
40 | - antlr4-python3-runtime==4.8
41 | - astunparse==1.6.3
42 | - atari-py==0.2.9
43 | - beautifulsoup4==4.11.2
44 | - cached-property==1.5.2
45 | - cachetools==5.2.0
46 | - cffi==1.14.2
47 | - charset-normalizer==3.0.1
48 | - clang==5.0
49 | - click==8.1.3
50 | - cloudpickle==2.1.0
51 | - cycler==0.11.0
52 | - cython==0.29.32
53 | - decorator==5.1.1
54 | - dm-control==1.0.10
55 | - dm-env==1.5
56 | - dm-tree==0.1.7
57 | - fasteners==0.17.3
58 | - filelock==3.8.0
59 | - flatbuffers==1.12
60 | - fonttools==4.38.0
61 | - ftfy==6.1.1
62 | - gast==0.4.0
63 | - gdown==4.4.0
64 | - glfw==2.5.4
65 | - google-auth==2.11.0
66 | - google-auth-oauthlib==0.4.6
67 | - google-pasta==0.2.0
68 | - grpcio==1.48.1
69 | - html-testrunner==1.2.1
70 | - hydra-core==1.1.1
71 | - idna==3.4
72 | - imageio==2.21.2
73 | - importlib-metadata==4.12.0
74 | - importlib-resources==5.10.2
75 | - jinja2==3.1.2
76 | - joblib==1.1.0
77 | - keras==2.6.0
78 | - keras-preprocessing==1.1.2
79 | - llvmlite==0.39.1
80 | - lxml==4.9.1
81 | - markdown==3.4.1
82 | - markupsafe==2.1.1
83 | - matplotlib
84 | - natsort==8.2.0
85 | - networkx==2.6.3
86 | - numpy-quaternion==2022.4.3
87 | - nvidia-cublas-cu11==11.10.3.66
88 | - nvidia-cuda-nvrtc-cu11==11.7.99
89 | - nvidia-cuda-runtime-cu11==11.7.99
90 | - nvidia-cudnn-cu11==8.5.0.96
91 | - oauthlib==3.2.0
92 | - omegaconf==2.1.1
93 | - opencv-python==4.1.2.30
94 | - opencv-python-headless==4.6.0.66
95 | - opensimplex==0.3
96 | - opt-einsum==3.3.0
97 | - packaging==21.3
98 | - pandas==1.4.0
99 | - patchelf==0.15.0.0
100 | - pillow==9.4.0
101 | - protobuf==3.19.4
102 | - pyasn1==0.4.8
103 | - pyasn1-modules==0.2.8
104 | - pybullet==3.2.5
105 | - pycparser==2.21
106 | - pyopengl==3.1.6
107 | - pyparsing==3.0.9
108 | - pyquaternion==0.9.9
109 | - python-dateutil==2.8.2
110 | - pytz==2022.7.1
111 | - pywavelets==1.3.0
112 | - pyyaml==6.0
113 | - regex==2022.9.13
114 | - requests==2.28.2
115 | - requests-oauthlib==1.3.1
116 | - rsa==4.9
117 | - ruamel-yaml==0.17.21
118 | - ruamel-yaml-clib==0.2.6
119 | - scikit-image
120 | - scikit-learn
121 | - scipy==1.7.3
122 | - six
123 | - soupsieve==2.3.2.post1
124 | - tensorboard==2.9.0
125 | - tensorboard-data-server==0.6.1
126 | - tensorboard-plugin-wit==1.8.1
127 | - tensorflow==2.6.1
128 | - tensorflow-addons==0.18.0
129 | - tensorflow-estimator==2.6.0
130 | - tensorflow-gpu==2.6.0
131 | - tensorflow-hub==0.12.0
132 | - tensorflow-probability==0.14.1
133 | - tensorflow-text==2.6.0
134 | - termcolor==1.1.0
135 | - tfimm==0.2.6
136 | - threadpoolctl==3.1.0
137 | - tifffile==2021.11.2
138 | - timm==0.6.13
139 | - tokenizers==0.13.1
140 | - torch==1.13.1
141 | - torchvision==0.14.1
142 | - tqdm==4.64.1
143 | - transformers==4.23.1
144 | - transforms3d==0.4.1
145 | - typeguard==2.13.3
146 | - urllib3==1.26.14
147 | - wcwidth==0.2.6
148 | - werkzeug==2.2.2
149 | - wrapt==1.12.1
150 | - zipp==3.8.1
151 | - psutil==5.7.0
152 | - setuptools==65.5.0
153 | - gym==0.21.0
154 | - git+https://github.com/openai/CLIP.git
155 |
--------------------------------------------------------------------------------
/common/rlbench_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import re
3 | import copy
4 | import numpy as np
5 |
6 |
7 | def collect_demo(
8 | env,
9 | replay,
10 | num_demos,
11 | camera_keys,
12 | shaped_rewards=False,
13 | task_name_to_num=None,
14 | ):
15 | transitions = []
16 | for _ in range(num_demos):
17 | success = False
18 | while not success:
19 | try:
20 | demo = env._task.get_demos(1, live_demos=True)[0]
21 | success = True
22 | except Exception:
23 | pass  # demo collection can fail transiently; retry until it succeeds
24 | transitions.extend(extract_from_demo(demo, shaped_rewards, camera_keys, env.task_name, task_name_to_num))
25 |
26 | # Restrict translation space by min_max
27 | actions = []
28 |
29 | for obs in transitions:
30 | if obs["is_first"]:
31 | continue
32 | action = obs["action"]
33 | actions.append(action)
34 |
35 | low, high = np.min(actions, 0)[:3], np.max(actions, 0)[:3]
36 | low -= 0.2 * np.fabs(low)  # widen the demo translation range by 20% on each side
37 | high += 0.2 * np.fabs(high)  # before actions are normalized to [-1, 1] below
38 |
39 | for obs in transitions:
40 | if obs["is_first"]:
41 | # for first action, let's just label with zero action
42 | obs["action"] = np.zeros(3 + 1)
43 | else:
44 | action = obs["action"]
45 | updated_action = []
46 |
47 | pose = action[:3]
48 | norm_pose = 2 * ((pose - low) / (high - low)) - 1
49 | updated_action.append(norm_pose)
50 |
51 | gripper = action[3:4]
52 | norm_gripper = gripper * 2 - 1.0
53 | updated_action.append(norm_gripper)
54 |
55 | obs["action"] = np.hstack(updated_action)
56 |
57 | replay.add_step(obs)
58 |
59 | print(f"Position min/max: {low}/{high}")
60 | actions_min_max = low, high
61 |
62 | return actions_min_max
63 |
64 |
65 | def get_action(prev_obs, obs):
66 | prev_pose = prev_obs.gripper_pose[:3]
67 | cur_pose = obs.gripper_pose[:3]
68 | pose = cur_pose - prev_pose
69 | gripper_action = float(obs.gripper_open)
70 | prev_action = np.hstack([pose, gripper_action])
71 | return prev_action
72 |
73 |
74 | def extract_from_demo(demo, shaped_rewards, camera_keys, task_name=None, task_name_to_num=None):
75 | transitions = []
76 | if task_name_to_num is not None:
77 | init_image, init_state = None, None
78 | for k, obs in enumerate(demo):
79 | if k == 0:
80 | prev_action = None
81 | else:
82 | prev_obs = demo[k - 1]
83 | prev_action = get_action(prev_obs, obs)
84 |
85 | terminal = k == len(demo) - 1
86 | first = k == 0
87 | success = terminal
88 |
89 | if shaped_rewards:
90 | reward = obs.task_low_dim_state[0]
91 | else:
92 | reward = float(success)
93 |
94 | # Not to override obs
95 | _obs = copy.deepcopy(obs)
96 | _obs.joint_velocities = None
97 | _obs.joint_positions = None
98 | _obs.task_low_dim_state = None
99 |
100 | transition = {
101 | "reward": reward,
102 | "is_first": first,
103 | "is_last": False,
104 | "is_terminal": False,
105 | "success": success,
106 | "action": prev_action,
107 | "state": _obs.get_low_dim_data(),
108 | }
109 |
110 | keys = get_camera_keys(camera_keys)
111 | images = []
112 | for key in keys:
113 | if key == "image_front":
114 | images.append(_obs.front_rgb)
115 | if key == "image_wrist":
116 | images.append(_obs.wrist_rgb)
117 | transition["image"] = np.concatenate(images, axis=-2)
118 | if task_name_to_num is not None:
119 | if k == 0:
120 | init_image = transition["image"]
121 | init_state = transition["state"]
122 | transition["init_image"] = init_image
123 | transition["init_state"] = init_state
124 | transition['task_num'] = task_name_to_num[task_name]
125 | transitions.append(transition)
126 |
127 | if len(transitions) % 50 == 0:
128 | time_limit = len(transitions)
129 | else:
130 | time_limit = 50 * (1 + (len(transitions) // 50))
131 | while len(transitions) < time_limit:  # pad episode to a multiple of 50 with copies of the last step
132 | transitions.append(copy.deepcopy(transition))
133 | transitions[-1]["is_last"] = True
134 | return transitions
135 |
136 |
137 | def get_camera_keys(keys):
138 | camera_keys = keys.split("|")
139 | return camera_keys
140 |
--------------------------------------------------------------------------------
/common/expl.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow_probability import distributions as tfd
3 |
4 | import agent
5 | import common
6 |
7 |
8 | class Random(common.Module):
9 |
10 | def __init__(self, config, act_space, wm, tfstep, reward):
11 | self.config = config
12 | self.act_space = act_space
13 |
14 | def actor(self, feat):
15 | shape = feat.shape[:-1] + self.act_space.shape
16 | if self.config.actor.dist == 'onehot':
17 | return common.OneHotDist(tf.zeros(shape))
18 | else:
19 | dist = tfd.Uniform(-tf.ones(shape), tf.ones(shape))
20 | return tfd.Independent(dist, 1)
21 |
22 | def train(self, start, context, data):
23 | return None, {}
24 |
25 |
26 | class Plan2Explore(common.Module):
27 |
28 | def __init__(self, config, act_space, wm, tfstep, reward):
29 | self.config = config
30 | self.reward = reward
31 | self.wm = wm
32 | self.ac = agent.ActorCritic(config, act_space, tfstep)
33 | self.actor = self.ac.actor
34 | stoch_size = config.rssm.stoch
35 | if config.rssm.discrete:
36 | stoch_size *= config.rssm.discrete
37 | size = {
38 | 'stoch': stoch_size,
39 | 'deter': config.rssm.deter,
40 | 'feat': config.rssm.stoch + config.rssm.deter,
41 | }[self.config.disag_target]
42 | self._networks = [
43 | common.MLP(size, **config.expl_head)
44 | for _ in range(config.disag_models)]
45 | self.opt = common.Optimizer('expl', **config.expl_opt)
46 | self.extr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm)
47 | self.intr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm)
48 |
49 | def train(self, start, context, data):
50 | metrics = {}
51 | stoch = start['stoch']
52 | if self.config.rssm.discrete:
53 | stoch = tf.reshape(
54 | stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1]))
55 | target = {
56 | 'embed': context['embed'],
57 | 'stoch': stoch,
58 | 'deter': start['deter'],
59 | 'feat': context['feat'],
60 | }[self.config.disag_target]
61 | inputs = context['feat']
62 | if self.config.disag_action_cond:
63 | action = tf.cast(data['action'], inputs.dtype)
64 | inputs = tf.concat([inputs, action], -1)
65 | metrics.update(self._train_ensemble(inputs, target))
66 | metrics.update(self.ac.train(
67 | self.wm, start, data['is_terminal'], self._intr_reward))
68 | return None, metrics
69 |
70 | def _intr_reward(self, seq):
71 | inputs = seq['feat']
72 | if self.config.disag_action_cond:
73 | action = tf.cast(seq['action'], inputs.dtype)
74 | inputs = tf.concat([inputs, action], -1)
75 | preds = [head(inputs).mode() for head in self._networks]
76 | disag = tf.tensor(preds).std(0).mean(-1)
77 | if self.config.disag_log:
78 | disag = tf.math.log(disag)
79 | reward = self.config.expl_intr_scale * self.intr_rewnorm(disag)[0]
80 | if self.config.expl_extr_scale:
81 | reward += self.config.expl_extr_scale * self.extr_rewnorm(
82 | self.reward(seq))[0]
83 | return reward
84 |
85 | def _train_ensemble(self, inputs, targets):
86 | if self.config.disag_offset:
87 | targets = targets[:, self.config.disag_offset:]
88 | inputs = inputs[:, :-self.config.disag_offset]
89 | targets = tf.stop_gradient(targets)
90 | inputs = tf.stop_gradient(inputs)
91 | with tf.GradientTape() as tape:
92 | preds = [head(inputs) for head in self._networks]
93 | loss = -sum([pred.log_prob(targets).mean() for pred in preds])
94 | metrics = self.opt(tape, loss, self._networks)
95 | return metrics
96 |
97 |
98 | class ModelLoss(common.Module):
99 |
100 | def __init__(self, config, act_space, wm, tfstep, reward):
101 | self.config = config
102 | self.reward = reward
103 | self.wm = wm
104 | self.ac = agent.ActorCritic(config, act_space, tfstep)
105 | self.actor = self.ac.actor
106 | self.head = common.MLP([], **self.config.expl_head)
107 | self.opt = common.Optimizer('expl', **self.config.expl_opt)
108 |
109 | def train(self, start, context, data):
110 | metrics = {}
111 | target = tf.cast(context[self.config.expl_model_loss], tf.float32)
112 | with tf.GradientTape() as tape:
113 | loss = -self.head(context['feat']).log_prob(target).mean()
114 | metrics.update(self.opt(tape, loss, self.head))
115 | metrics.update(self.ac.train(
116 | self.wm, start, data['is_terminal'], self._intr_reward))
117 | return None, metrics
118 |
119 | def _intr_reward(self, seq):
120 | reward = self.config.expl_intr_scale * self.head(seq['feat']).mode()
121 | if self.config.expl_extr_scale:
122 | reward += self.config.expl_extr_scale * self.reward(seq)
123 | return reward
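124 |
125 | # Note on Plan2Explore._intr_reward above: the intrinsic term is the ensemble's
126 | # epistemic disagreement. Each of the `disag_models` MLP heads predicts the latent
127 | # target selected by `disag_target`, and the reward is the std-dev across heads,
128 | # averaged over dimensions (optionally log-scaled via `disag_log`).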
--------------------------------------------------------------------------------
/common/dists.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow_probability as tfp
4 | from tensorflow_probability import distributions as tfd
5 |
6 | import common
7 |
8 | # Patch to ignore seed to avoid synchronization across GPUs.
9 | _orig_random_categorical = tf.random.categorical
10 |
11 |
12 | def random_categorical(*args, **kwargs):
13 | kwargs["seed"] = None
14 | return _orig_random_categorical(*args, **kwargs)
15 |
16 |
17 | tf.random.categorical = random_categorical
18 |
19 | # Patch to ignore seed to avoid synchronization across GPUs.
20 | _orig_random_normal = tf.random.normal
21 |
22 |
23 | def random_normal(*args, **kwargs):
24 | kwargs["seed"] = None
25 | return _orig_random_normal(*args, **kwargs)
26 |
27 |
28 | tf.random.normal = random_normal
29 |
30 |
31 | class SampleDist:
32 | def __init__(self, dist, samples=100):
33 | self._dist = dist
34 | self._samples = samples
35 |
36 | @property
37 | def name(self):
38 | return "SampleDist"
39 |
40 | def __getattr__(self, name):
41 | return getattr(self._dist, name)
42 |
43 | def mean(self):
44 | samples = self._dist.sample(self._samples)
45 | return samples.mean(0)
46 |
47 | def mode(self):
48 | sample = self._dist.sample(self._samples)
49 | logprob = self._dist.log_prob(sample)
50 | return tf.gather(sample, tf.argmax(logprob))[0]
51 |
52 | def entropy(self):
53 | sample = self._dist.sample(self._samples)
54 | logprob = self.log_prob(sample)
55 | return -logprob.mean(0)
56 |
57 |
58 | class OneHotDist(tfd.OneHotCategorical):
59 | def __init__(self, logits=None, probs=None, dtype=None):
60 | self._sample_dtype = dtype or tf.float32
61 | super().__init__(logits=logits, probs=probs)
62 |
63 | def mode(self):
64 | return tf.cast(super().mode(), self._sample_dtype)
65 |
66 | def sample(self, sample_shape=(), seed=None):
67 | # Straight through biased gradient estimator.
68 | sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype)
69 | probs = self._pad(super().probs_parameter(), sample.shape)
70 | sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype)
71 | return sample
72 |
73 | def _pad(self, tensor, shape):
74 | tensor = super().probs_parameter()
75 | while len(tensor.shape) < len(shape):
76 | tensor = tensor[None]
77 | return tensor
78 |
79 |
80 | class TruncNormalDist(tfd.TruncatedNormal):
81 | def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
82 | super().__init__(loc, scale, low, high)
83 | self._clip = clip
84 | self._mult = mult
85 |
86 | def sample(self, *args, **kwargs):
87 | event = super().sample(*args, **kwargs)
88 | if self._clip:
89 | clipped = tf.clip_by_value(
90 | event, self.low + self._clip, self.high - self._clip
91 | )
92 | event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped)
93 | if self._mult:
94 | event *= self._mult
95 | return event
96 |
97 |
98 | class TanhBijector(tfp.bijectors.Bijector):
99 | def __init__(self, validate_args=False, name="tanh"):
100 | super().__init__(
101 | forward_min_event_ndims=0, validate_args=validate_args, name=name
102 | )
103 |
104 | def _forward(self, x):
105 | return tf.nn.tanh(x)
106 |
107 | def _inverse(self, y):
108 | dtype = y.dtype
109 | y = tf.cast(y, tf.float32)
110 | y = tf.where(
111 | tf.less_equal(tf.abs(y), 1.0),
112 | tf.clip_by_value(y, -0.99999997, 0.99999997),
113 | y,
114 | )
115 | y = tf.atanh(y)
116 | y = tf.cast(y, dtype)
117 | return y
118 |
119 | def _forward_log_det_jacobian(self, x):
120 | log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype))
121 | return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x))
122 |
123 |
124 | class MSEDist:
125 | def __init__(self, mode, dims, agg="sum"):
126 | self._mode = mode
127 | self._dims = tuple([-x for x in range(1, dims + 1)])
128 | self._agg = agg
129 | self.batch_shape = mode.shape[: len(mode.shape) - dims]
130 | self.event_shape = mode.shape[len(mode.shape) - dims :]
131 |
132 | def mode(self):
133 | return self._mode
134 |
135 | def mean(self):
136 | return self._mode
137 |
138 | def log_prob(self, value):
139 | assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
140 | distance = (self._mode - value) ** 2
141 | if self._agg == "mean":
142 | loss = distance.mean(self._dims)
143 | elif self._agg == "sum":
144 | loss = distance.sum(self._dims)
145 | else:
146 | raise NotImplementedError(self._agg)
147 | return -loss
148 |
149 |
150 | class SymlogDist:
151 | def __init__(self, mode, dims, agg="sum"):
152 | self._mode = mode
153 | self._dims = tuple([-x for x in range(1, dims + 1)])
154 | self._agg = agg
155 | self.batch_shape = mode.shape[: len(mode.shape) - dims]
156 | self.event_shape = mode.shape[len(mode.shape) - dims :]
157 |
158 | def mode(self):
159 | return symexp(self._mode)
160 |
161 | def mean(self):
162 | return symexp(self._mode)
163 |
164 | def log_prob(self, value):
165 | assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
166 | distance = (self._mode - symlog(value)) ** 2
167 | if self._agg == "mean":
168 | loss = distance.mean(self._dims)
169 | elif self._agg == "sum":
170 | loss = distance.sum(self._dims)
171 | else:
172 | raise NotImplementedError(self._agg)
173 | return -loss
174 |
175 | def symlog(x):
176 | return tf.sign(x) * tf.math.log(1 + tf.abs(x))
177 |
178 |
179 | def symexp(x):
180 | return tf.sign(x) * (tf.math.exp(tf.abs(x)) - 1)
181 |
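182 | # Note: symlog and symexp are exact inverses (symexp(symlog(x)) == x for all real x),
183 | # so targets regressed in symlog space can be decoded back to their raw scale.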
--------------------------------------------------------------------------------
/common/logger.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pathlib
4 | import time
5 |
6 | import numpy as np
7 |
8 |
9 | class Logger:
10 | def __init__(self, step, outputs, multiplier=1):
11 | self._step = step
12 | self._outputs = outputs
13 | self._multiplier = multiplier
14 | self._last_step = None
15 | self._last_time = None
16 | self._metrics = []
17 |
18 | def add(self, mapping, prefix=None):
19 | step = int(self._step) * self._multiplier
20 | for name, value in dict(mapping).items():
21 | name = f"{prefix}_{name}" if prefix else name
22 | value = np.array(value)
23 | if len(value.shape) not in (0, 2, 3, 4):
24 | raise ValueError(
25 | f"Shape {value.shape} for name '{name}' cannot be "
26 | "interpreted as scalar, image, or video."
27 | )
28 | self._metrics.append((step, name, value))
29 |
30 | def scalar(self, name, value):
31 | self.add({name: value})
32 |
33 | def image(self, name, value):
34 | self.add({name: value})
35 |
36 | def video(self, name, value):
37 | self.add({name: value})
38 |
39 | def write(self, fps=False):
40 | fps and self.scalar("fps", self._compute_fps())
41 | if not self._metrics:
42 | return
43 | for output in self._outputs:
44 | output(self._metrics)
45 | self._metrics.clear()
46 |
47 | def _compute_fps(self):
48 | step = int(self._step) * self._multiplier
49 | if self._last_step is None:
50 | self._last_time = time.time()
51 | self._last_step = step
52 | return 0
53 | steps = step - self._last_step
54 | duration = time.time() - self._last_time
55 | self._last_time += duration
56 | self._last_step = step
57 | return steps / duration
58 |
59 |
60 | class TerminalOutput:
61 | def __call__(self, summaries):
62 | step = max(s for s, _, _ in summaries)
63 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0}
64 | formatted = {k: self._format_value(v) for k, v in scalars.items()}
65 | print(f"[{step}]", " / ".join(f"{k} {v}" for k, v in formatted.items()))
66 |
67 | def _format_value(self, value):
68 | if value == 0:
69 | return "0"
70 | elif 0.01 < abs(value) < 10000:
71 | value = f"{value:.2f}"
72 | value = value.rstrip("0")
73 | value = value.rstrip("0")
74 | value = value.rstrip(".")
75 | return value
76 | else:
77 | value = f"{value:.1e}"
78 | value = value.replace(".0e", "e")
79 | value = value.replace("+0", "")
80 | value = value.replace("+", "")
81 | value = value.replace("-0", "-")
82 | return value
83 |
84 |
85 | class JSONLOutput:
86 | def __init__(self, logdir):
87 | self._logdir = pathlib.Path(logdir).expanduser()
88 |
89 | def __call__(self, summaries):
90 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0}
91 | step = max(s for s, _, _ in summaries)
92 | with (self._logdir / "metrics.jsonl").open("a") as f:
93 | f.write(json.dumps({"step": step, **scalars}) + "\n")
94 |
95 |
96 | class TensorBoardOutput:
97 | def __init__(self, logdir, fps=20):
98 | # The TensorFlow summary writer supports file protocols like gs://. We use
99 | # os.path over pathlib here to preserve those prefixes.
100 | self._logdir = os.path.expanduser(logdir)
101 | self._writer = None
102 | self._fps = fps
103 |
104 | def __call__(self, summaries):
105 | import tensorflow as tf
106 |
107 | self._ensure_writer()
108 | self._writer.set_as_default()
109 | for step, name, value in summaries:
110 | if len(value.shape) == 0:
111 | tf.summary.scalar("scalars/" + name, value, step)
112 | elif len(value.shape) == 2:
113 | tf.summary.image(name, value, step)
114 | elif len(value.shape) == 3:
115 | tf.summary.image(name, value, step)
116 | elif len(value.shape) == 4:
117 | self._video_summary(name, value, step)
118 | self._writer.flush()
119 |
120 | def _ensure_writer(self):
121 | if not self._writer:
122 | import tensorflow as tf
123 |
124 | self._writer = tf.summary.create_file_writer(self._logdir, max_queue=1000)
125 |
126 | def _video_summary(self, name, video, step):
127 | import tensorflow as tf
128 | import tensorflow.compat.v1 as tf1
129 |
130 | name = name if isinstance(name, str) else name.decode("utf-8")
131 | if np.issubdtype(video.dtype, np.floating):
132 | video = np.clip(255 * video, 0, 255).astype(np.uint8)
133 | try:
134 | T, H, W, C = video.shape
135 | summary = tf1.Summary()
136 | image = tf1.Summary.Image(height=H, width=W, colorspace=C)
137 | image.encoded_image_string = encode_gif(video, self._fps)
138 | summary.value.add(tag=name, image=image)
139 | tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
140 | except (IOError, OSError) as e:
141 | print("GIF summaries require ffmpeg in $PATH.", e)
142 | tf.summary.image(name, video, step)
143 |
144 |
145 | def encode_gif(frames, fps):
146 | from subprocess import Popen, PIPE
147 |
148 | h, w, c = frames[0].shape
149 | pxfmt = {1: "gray", 3: "rgb24"}[c]
150 | cmd = " ".join(
151 | [
152 | "ffmpeg -y -f rawvideo -vcodec rawvideo",
153 | f"-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex",
154 | "[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse",
155 | f"-r {fps:.02f} -f gif -",
156 | ]
157 | )
158 | proc = Popen(cmd.split(" "), stdin=PIPE, stdout=PIPE, stderr=PIPE)
159 | for image in frames:
160 | proc.stdin.write(image.tobytes())
161 | out, err = proc.communicate()
162 | if proc.returncode:
163 | raise IOError("\n".join([" ".join(cmd), err.decode("utf8")]))
164 | del proc
165 | return out
166 |
--------------------------------------------------------------------------------
/configs.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |
3 | # Train Script
4 | logdir: /dev/null
5 | loaddir: ''
6 | ts: ''
7 | seed: 0
8 | task: 'take_lid_off_saucepan'
9 | envs: 1
10 | envs_parallel: process
11 | render_size: [128, 128]
12 | time_limit: 150
13 | steps: 1002000
14 | log_every: 1e3
15 | eval_every: 1e4
16 | eval_eps: 10
17 | prefill: 0
18 | pretrain: 1000
19 | mae_pretrain: 10000
20 | train_every: 2
21 | train_mae_every: 2
22 | train_steps: 1
23 | train_mae_steps: 1
24 | replay: {capacity: 2e6, minlen: 1, maxlen: 50, prioritize_ends: True}
25 | dataset: {batch: 50, length: 50}
26 | mae_dataset: {batch: 32, length: 32}
27 | log_keys_video: ['image']
28 | log_keys_sum: '^$'
29 | log_keys_mean: '^$'
30 | log_keys_max: '^$'
31 | precision: 16
32 | jit: True
33 | action_repeat: 1
34 | device: 'cuda:0'
35 | vidlang_model_device: 'cuda:1'
36 | actor_linear_probe: False
37 | critic_linear_probe: False
38 | scripted_corner: 'top_left'
39 | randomize: False
40 | tune_instruction: False
41 | num_tune_instructions: 10
42 | instructions_file: ''
43 |
44 | # Env
45 | eval_noise: 0.0
46 | expl_noise: 0.0
47 | franka_kitchen: False
48 |
49 | # Agent
50 | clip_rewards: identity
51 |
52 | # Foundation Model
53 | use_r3m_reward: False
54 | use_internvideo_reward: False
55 | use_clip_reward: False
56 | multi_vidlang: False
57 | multi_task_vidlang: False
58 | internvideo_load_dir: "InternVideo/Pretrain/Multi-Modalities-Pretraining/models/InternVideo-MM-B-16.ckpt"
59 | standardize_rewards: False
60 | queue_size: 100000
61 | update_stats_steps: 2000
62 | num_top_images: 2
63 | use_zero_rewards: False
64 | boundary_reward_penalty: False
65 | use_lang_embeddings: False
66 |
67 | # Demo
68 | num_demos: 100
69 | shaped_rewards: True
70 |
71 | # MAE
72 | camera_keys: 'image_front|image_wrist'
73 | mask_ratio: 0.95
74 | mae: {img_h_size: 128, img_w_size: 128, patch_size: 16, embed_dim: 256, depth: 8, num_heads: 4, decoder_embed_dim: 256, decoder_depth: 6, decoder_num_heads: 4, reward_pred: True, early_conv: True, state_pred: True, in_chans: 3, ncams: 0, state_dim: 10, view_masking: True, control_input: 'front_wrist'}
75 | wm_flat_vit: {img_h_size: 8, img_w_size: 8, patch_size: 1, embed_dim: 128, depth: 2, num_heads: 4, decoder_embed_dim: 128, decoder_depth: 2, decoder_num_heads: 4, in_chans: 256, state_pred: False}
76 | image_t_size: 4
77 | mae_chunk: 1
78 | mae_avg: False
79 |
80 | # World Model
81 | grad_heads: [reward, discount]
82 | pred_discount: True
83 | rssm: {action_free: False, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1}
84 | reward_head: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: symlog}
85 | discount_head: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: binary}
86 | loss_scales: {feature: 1.0, kl: 1.0, reward: 1.0, discount: 1.0, proprio: 1.0, mae_reward: 1.0}
87 | wmkl: {scale: 1.0}
88 | wmkl_minloss: 0.1
89 | wmkl_balance: 0.8
90 | model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, wd_pattern: 'kernel', warmup: 0}
91 | mae_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, warmup: 2500}
92 |
93 | # Actor Critic
94 | actor: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: trunc_normal, min_std: 0.1}
95 | critic: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: mse}
96 | actor_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, wd_pattern: 'kernel', warmup: 0}
97 | critic_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, wd_pattern: 'kernel', warmup: 0}
98 | discount: 0.99
99 | discount_lambda: 0.95
100 | imag_horizon: 15
101 | actor_grad: dynamics
102 | actor_grad_mix: 0.1
103 | aent: {scale: 1e-4}
104 | slow_target: True
105 | slow_target_update: 100
106 | slow_target_fraction: 1
107 | slow_baseline: True
108 | reward_norm: {momentum: 0.99, scale: 1.0, eps: 1e-8}
109 | curriculum: {use: True, num_episodes: '100|100|100', neg_lang_prompt: 'reach away from bowl', lang_prompt: 'reach_for_bowl|reach_for_mug|reach_for_jar', objects: 'bowl|mug|jar', synonym_folder: null, num_objects: "3|3|3", num_unique_per_class: "-1|-1|-1"}
110 |
111 | # Plan2Explore
112 | plan2explore: False
113 | expl_intr_scale: 0.5
114 | expl_extr_scale: 0.5
115 | expl_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6}
116 | expl_head: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: mse}
117 | expl_reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8}
118 | disag_target: stoch
119 | disag_log: False
120 | disag_models: 10
121 | disag_offset: 1
122 | disag_action_cond: True
123 | expl_model_loss: kl
124 |
125 | # Rnd
126 | rnd: False
127 | rnd_embedding_dim: 512
128 | rnd_hidden_dim: 256
129 | rnd_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, warmup: 2500}
130 |
131 | front:
132 | camera_keys: image_front
133 | mae.control_input: front
134 |
135 | wrist:
136 | camera_keys: image_wrist
137 | mae.control_input: wrist
138 |
139 | overhead:
140 | camera_keys: image_overhead
141 | mae.control_input: overhead
142 |
143 | front_wrist:
144 | camera_keys: image_front|image_wrist
145 | mae.control_input: front_wrist
146 |
147 | overhead_wrist:
148 | camera_keys: image_overhead|image_wrist
149 | mae.control_input: overhead_wrist
150 |
151 | front_wrist_to_front:
152 | camera_keys: image_front|image_wrist
153 | mae.control_input: front
154 |
155 | front_wrist_to_wrist:
156 | camera_keys: image_front|image_wrist
157 | mae.control_input: wrist
158 |
159 | vlsp:
160 | image_t_size: 1
161 | use_imagenet_mae: False
162 | mae.view_masking: False
163 | mae.depth: 3
164 | mae.decoder_depth: 2
165 | prefill: 200
166 | num_demos: 0
167 | mae_pretrain: 0
168 | pretrain: 0
169 |
170 | ptmae:
171 | mae.img_w_size: 224
172 | mae.img_h_size: 224
173 | wm_flat_vit.img_h_size: 7
174 | wm_flat_vit.img_w_size: 7
175 | mae.state_pred: False
176 | wm_flat_vit.in_chans: 768
177 | wm_flat_vit.embed_dim: 128
178 | mae_avg: True
179 |
180 | debug:
181 | eval_eps: 1
182 | dataset.batch: 8
183 | dataset.length: 10
184 | mae_dataset.batch: 4
185 | mae_dataset.length: 8
186 | mae.depth: 1
187 | mae.decoder_depth: 1
188 | pretrain: 1
189 | mae_pretrain: 1
190 | num_demos: 1
191 | rssm.hidden: 64
192 | rssm.deter: 64
193 | rssm.stoch: 4
194 | rssm.discrete: 4
195 | imag_horizon: 3
196 | jit: False
197 | log_every: 100
198 |
--------------------------------------------------------------------------------
/common/tfutils.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import pickle
3 | import re
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from tensorflow.keras import mixed_precision as prec
8 |
9 | try:
10 | from tensorflow.python.distribute import values
11 | except Exception:
12 | from google3.third_party.tensorflow.python.distribute import values
13 |
14 | tf.tensor = tf.convert_to_tensor
15 | for base in (tf.Tensor, tf.Variable, values.PerReplica):
16 | base.mean = tf.math.reduce_mean
17 | base.std = tf.math.reduce_std
18 | base.var = tf.math.reduce_variance
19 | base.sum = tf.math.reduce_sum
20 | base.any = tf.math.reduce_any
21 | base.all = tf.math.reduce_all
22 | base.min = tf.math.reduce_min
23 | base.max = tf.math.reduce_max
24 | base.abs = tf.math.abs
25 | base.logsumexp = tf.math.reduce_logsumexp
26 | base.transpose = tf.transpose
27 | base.reshape = tf.reshape
28 | base.astype = tf.cast
29 |
30 |
31 | # values.PerReplica.dtype = property(lambda self: self.values[0].dtype)
32 |
33 | # tf.TensorHandle.__repr__ = lambda x: ''
34 | # tf.TensorHandle.__str__ = lambda x: ''
35 | # np.set_printoptions(threshold=5, edgeitems=0)
36 |
37 |
38 | class Module(tf.Module):
39 | def save(self, filename, verbose=True):
40 | values = tf.nest.map_structure(lambda x: x.numpy(), self.variables)
41 | amount = len(tf.nest.flatten(values))
42 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
43 | if verbose:
44 | print(f"Save checkpoint with {amount} tensors and {count} parameters.")
45 | with pathlib.Path(filename).open("wb") as f:
46 | pickle.dump(values, f)
47 |
48 | def load(self, filename, verbose=True):
49 | with pathlib.Path(filename).open("rb") as f:
50 | values = pickle.load(f)
51 | amount = len(tf.nest.flatten(values))
52 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
53 | if verbose:
54 | print(f"Load checkpoint with {amount} tensors and {count} parameters.")
55 | for i in range(len(values)):
56 | # print(f"{i} {self.variables[i].name} self.variables: {self.variables[i].shape}, {values[i].shape}. {values[i].shape == self.variables[i].shape}")
57 | if values[i].shape != self.variables[i].shape:
58 | print(f"{i} {self.variables[i].name} self.variables: {self.variables[i].shape}, {values[i].shape}")
59 |
60 | tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values)
61 |
62 | def get(self, name, ctor, *args, **kwargs):
63 | # Create or get layer by name to avoid mentioning it in the constructor.
64 | if not hasattr(self, "_modules"):
65 | self._modules = {}
66 | if name not in self._modules:
67 | self._modules[name] = ctor(*args, **kwargs)
68 | return self._modules[name]
69 |
70 |
71 | class Optimizer(tf.Module):
72 | def __init__(
73 | self,
74 | name,
75 | lr,
76 | eps=1e-4,
77 | clip=None,
78 | wd=None,
79 | opt="adam",
80 | warmup=0,
81 | wd_pattern=r".*",
82 | ):
83 | assert wd is None or 0 <= wd < 1
84 | assert not clip or 1 <= clip
85 | self._name = name
86 | self._clip = clip
87 | self._wd = wd
88 | self._wd_pattern = wd_pattern
89 | self._updates = tf.Variable(0, trainable=False, dtype=tf.int64)
90 | self._lr = lr
91 | if warmup:
92 | self._lr = lambda: lr * tf.clip_by_value(
93 | self._updates.astype(tf.float32) / warmup, 0.0, 1.0
94 | )
95 | self._opt = {
96 | "adam": lambda: tf.optimizers.Adam(self._lr, epsilon=eps),
97 | "nadam": lambda: tf.optimizers.Nadam(self._lr, epsilon=eps),
98 | "adamax": lambda: tf.optimizers.Adamax(self._lr, epsilon=eps),
99 | "sgd": lambda: tf.optimizers.SGD(self._lr),
100 | "momentum": lambda: tf.optimizers.SGD(self._lr, 0.9),
101 | }[opt]()
102 | self._mixed = prec.global_policy().compute_dtype == tf.float16
103 | if self._mixed:
104 | self._opt = prec.LossScaleOptimizer(self._opt, dynamic=True)
105 | self._once = True
106 |
107 | @property
108 | def variables(self):
109 | return self._opt.variables()
110 |
111 | def __call__(self, tape, loss, modules):
112 | assert loss.dtype is tf.float32, (self._name, loss.dtype)
113 | assert len(loss.shape) == 0, (self._name, loss.shape)
114 | metrics = {}
115 |
116 | # Find variables.
117 | modules = modules if hasattr(modules, "__len__") else (modules,)
118 | varibs = tf.nest.flatten([module.variables for module in modules])
119 | count = sum(np.prod(x.shape) for x in varibs)
120 | if self._once:
121 | print(f"Found {count} {self._name} parameters.")
122 | self._once = False
123 |
124 | # Check loss.
125 | tf.debugging.check_numerics(loss, self._name + "_loss")
126 | metrics[f"{self._name}_loss"] = loss
127 |
128 | # Compute scaled gradient.
129 | if self._mixed:
130 | with tape:
131 | loss = self._opt.get_scaled_loss(loss)
132 | grads = tape.gradient(loss, varibs)
133 | if self._mixed:
134 | grads = self._opt.get_unscaled_gradients(grads)
135 | if self._mixed:
136 | metrics[f"{self._name}_loss_scale"] = self._opt.loss_scale
137 |
138 | # Distributed sync.
139 | context = tf.distribute.get_replica_context()
140 | if context:
141 | grads = context.all_reduce("mean", grads)
142 |
143 | # Gradient clipping.
144 | norm = tf.linalg.global_norm(grads)
145 | if not self._mixed:
146 | tf.debugging.check_numerics(norm, self._name + "_norm")
147 | if self._clip:
148 | grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
149 | metrics[f"{self._name}_grad_norm"] = norm
150 |
151 | # Weight decay.
152 | if self._wd:
153 | self._apply_weight_decay(varibs)
154 |
155 | # Apply gradients.
156 | self._opt.apply_gradients(
157 | zip(grads, varibs), experimental_aggregate_gradients=False
158 | )
159 | self._updates.assign_add(1)
160 |
161 | return metrics
162 |
163 | def _apply_weight_decay(self, varibs):
164 | nontrivial = self._wd_pattern != r".*"
165 | # if nontrivial:
166 | # print("Applied weight decay to variables:")
167 | for var in varibs:
168 | if re.search(self._wd_pattern, self._name + "/" + var.name):
169 | # if nontrivial:
170 | # print("- " + self._name + "/" + var.name)
171 | var.assign((1 - self._wd) * var)
172 |
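173 | # Usage pattern (mirrors common/expl.py): construct once, then call per update step.
174 | # `compute_loss` is a stand-in for whatever produces a scalar float32 loss:
175 | #
176 | #     opt = Optimizer("expl", **config.expl_opt)
177 | #     with tf.GradientTape() as tape:
178 | #         loss = compute_loss()
179 | #     metrics = opt(tape, loss, modules)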
--------------------------------------------------------------------------------
/common/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pathlib
3 | import re
4 |
5 |
6 | class Config(dict):
7 |
8 | SEP = "."
9 | IS_PATTERN = re.compile(r".*[^A-Za-z0-9_.-].*")
10 |
11 | def __init__(self, *args, **kwargs):
12 | mapping = dict(*args, **kwargs)
13 | mapping = self._flatten(mapping)
14 | mapping = self._ensure_keys(mapping)
15 | mapping = self._ensure_values(mapping)
16 | self._flat = mapping
17 | self._nested = self._nest(mapping)
18 | # Need to assign the values to the base class dictionary so that
19 | # conversion to dict does not lose the content.
20 | super().__init__(self._nested)
21 |
22 | @property
23 | def flat(self):
24 | return self._flat.copy()
25 |
26 | def save(self, filename):
27 | filename = pathlib.Path(filename)
28 | if filename.suffix == ".json":
29 | filename.write_text(json.dumps(dict(self)))
30 | elif filename.suffix in (".yml", ".yaml"):
31 | import ruamel.yaml as yaml
32 |
33 | with filename.open("w") as f:
34 | yaml.safe_dump(dict(self), f)
35 | else:
36 | raise NotImplementedError(filename.suffix)
37 |
38 | @classmethod
39 | def load(cls, filename):
40 | filename = pathlib.Path(filename)
41 | if filename.suffix == ".json":
42 | return cls(json.loads(filename.read_text()))
43 | elif filename.suffix in (".yml", ".yaml"):
44 | import ruamel.yaml as yaml
45 |
46 | return cls(yaml.safe_load(filename.read_text()))
47 | else:
48 | raise NotImplementedError(filename.suffix)
49 |
50 | def parse_flags(self, argv=None, known_only=False, help_exists=None):
51 | from . import flags
52 |
53 | return flags.Flags(self).parse(argv, known_only, help_exists)
54 |
55 | def __contains__(self, name):
56 | try:
57 | self[name]
58 | return True
59 | except KeyError:
60 | return False
61 |
62 | def __getattr__(self, name):
63 | if name.startswith("_"):
64 | return super().__getattr__(name)
65 | try:
66 | return self[name]
67 | except KeyError:
68 | raise AttributeError(name)
69 |
70 | def __getitem__(self, name):
71 | result = self._nested
72 | for part in name.split(self.SEP):
73 | result = result[part]
74 | if isinstance(result, dict):
75 | result = type(self)(result)
76 | return result
77 |
78 | def __setattr__(self, key, value):
79 | if key.startswith("_"):
80 | return super().__setattr__(key, value)
81 | message = f"Tried to set key '{key}' on immutable config. Use update()."
82 | raise AttributeError(message)
83 |
84 | def __setitem__(self, key, value):
85 | if key.startswith("_"):
86 | return super().__setitem__(key, value)
87 | message = f"Tried to set key '{key}' on immutable config. Use update()."
88 | raise AttributeError(message)
89 |
90 | def __reduce__(self):
91 | return (type(self), (dict(self),))
92 |
93 | def __str__(self):
94 | lines = ["\nConfig:"]
95 | keys, vals, typs = [], [], []
96 | for key, val in self.flat.items():
97 | keys.append(key + ":")
98 | vals.append(self._format_value(val))
99 | typs.append(self._format_type(val))
100 | max_key = max(len(k) for k in keys) if keys else 0
101 | max_val = max(len(v) for v in vals) if vals else 0
102 | for key, val, typ in zip(keys, vals, typs):
103 | key = key.ljust(max_key)
104 | val = val.ljust(max_val)
105 | lines.append(f"{key} {val} ({typ})")
106 | return "\n".join(lines)
107 |
108 | def update(self, *args, **kwargs):
109 | result = self._flat.copy()
110 | inputs = self._flatten(dict(*args, **kwargs))
111 | for key, new in inputs.items():
112 | if self.IS_PATTERN.match(key):
113 | pattern = re.compile(key)
114 | keys = {k for k in result if pattern.match(k)}
115 | else:
116 | keys = [key]
117 | if not keys:
118 | raise KeyError(f"Unknown key or pattern {key}.")
119 | for key in keys:
120 | if key not in result:
121 | result[key] = new
122 | continue
123 | old = result[key]
124 | try:
125 | if isinstance(old, int) and isinstance(new, float):
126 | if float(int(new)) != new:
127 | message = f"Cannot convert fractional float {new} to int."
128 | raise ValueError(message)
129 | if old is None:
130 | result[key] = str(new)
131 | else:
132 | result[key] = type(old)(new)
133 | except (ValueError, TypeError):
134 | raise TypeError(
135 | f"Cannot convert '{new}' to type '{type(old).__name__}' "
136 | + f"of value '{old}' for key '{key}'."
137 | )
138 | return type(self)(result)
139 |
140 | def _flatten(self, mapping):
141 | result = {}
142 | for key, value in mapping.items():
143 | if isinstance(value, dict):
144 | for k, v in self._flatten(value).items():
145 | if self.IS_PATTERN.match(key) or self.IS_PATTERN.match(k):
146 | combined = f"{key}\\{self.SEP}{k}"
147 | else:
148 | combined = f"{key}{self.SEP}{k}"
149 | result[combined] = v
150 | else:
151 | result[key] = value
152 | return result
153 |
154 | def _nest(self, mapping):
155 | result = {}
156 | for key, value in mapping.items():
157 | parts = key.split(self.SEP)
158 | node = result
159 | for part in parts[:-1]:
160 | if part not in node:
161 | node[part] = {}
162 | node = node[part]
163 | node[parts[-1]] = value
164 | return result
165 |
166 | def _ensure_keys(self, mapping):
167 | for key in mapping:
168 | assert not self.IS_PATTERN.match(key), key
169 | return mapping
170 |
171 | def _ensure_values(self, mapping):
172 | result = json.loads(json.dumps(mapping))
173 | for key, value in result.items():
174 | if isinstance(value, list):
175 | value = tuple(value)
176 | if isinstance(value, tuple):
177 | if len(value) == 0:
178 | message = (
179 | "Empty lists are disallowed because their type is unclear."
180 | )
181 | raise TypeError(message)
182 | if not isinstance(value[0], (str, float, int, bool)):
183 | message = "Lists can only contain strings, floats, ints, bools"
184 | message += f" but not {type(value[0])}"
185 | raise TypeError(message)
186 | if not all(isinstance(x, type(value[0])) for x in value[1:]):
187 | message = "Elements of a list must all be of the same type."
188 | raise TypeError(message)
189 | result[key] = value
190 | return result
191 |
192 | def _format_value(self, value):
193 | if isinstance(value, (list, tuple)):
194 | return "[" + ", ".join(self._format_value(x) for x in value) + "]"
195 | return str(value)
196 |
197 | def _format_type(self, value):
198 | if isinstance(value, (list, tuple)):
199 | assert len(value) > 0, value
200 | return self._format_type(value[0]) + "s"
201 | return str(type(value).__name__)
202 |
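
# Usage sketch for the Config class above (illustrative keys; assumes the
# SEP = "." separator defined at the top of the class).
if __name__ == "__main__":
    config = Config({"model": {"lr": 1e-4, "layers": 3}, "seed": 0})
    assert config.model.lr == 1e-4                # attribute access on nested keys
    assert config["model.layers"] == 3            # flat, dot-separated access
    config = config.update({"model.lr": 3e-4})    # update() returns a new Config
    config = config.update({r"model\..*": 1})     # regex key hits lr and layers
    # Direct assignment (config.seed = 1) raises AttributeError by design.
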
--------------------------------------------------------------------------------
/common/kitchen.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gym
3 | import random
4 | import itertools
5 | from itertools import combinations
6 | from common.base_envs import BenchEnv
7 | from d4rl.kitchen.kitchen_envs import KitchenMicrowaveKettleBottomBurnerLightV0
8 | # import d4rl.kitchen.kitchen_envs
9 |
10 |
11 | class KitchenEnv(BenchEnv):
12 | def __init__(self, task, action_repeat=1, use_goal_idx=False, log_per_goal=False, control_mode='end_effector', width=64):
13 | # currently, task is not used
14 | super().__init__(action_repeat, width)
15 | self.use_goal_idx = use_goal_idx
16 | self.log_per_goal = log_per_goal
17 | with self.LOCK:
18 | self._env = KitchenMicrowaveKettleBottomBurnerLightV0()
19 |
20 | self._env.sim_robot.renderer._camera_settings = dict(
21 | distance=1.86, lookat=[-0.3, .5, 2.], azimuth=90, elevation=-60)
22 |
23 | self.rendered_goal = False
24 | self._env.reset()
25 | self.init_qpos = self._env.sim.data.qpos.copy()
26 | self.goal_idx = 0
27 | self.obs_element_goals, self.obs_element_indices, self.goal_configs = get_kitchen_benchmark_goals()
28 | self.goals = list(range(len(self.obs_element_goals)))
29 |
30 | @property
31 | def act_space(self):
32 | return {"action": self.action_space}
33 |
34 | @property
35 | def obs_space(self):
36 | return self.observation_space
37 |
38 | def set_goal_idx(self, idx):
39 | self.goal_idx = idx
40 |
41 | def get_goal_idx(self):
42 | return self.goal_idx
43 |
44 | def get_goals(self):
45 | return self.goals
46 |
47 | def _get_obs(self, state):
48 | image = self._env.render('rgb_array')
49 | obs = {'image': image, 'state': state, 'image_goal': self.render_goal(), 'goal': self.goal}
50 | if self.log_per_goal:
51 | for i, goal_idx in enumerate(self.goals):
52 | # add rewards for all goals
53 | task_rel_success, all_obj_success = self.compute_success(goal_idx)
54 | obs['metric_success_task_relevant/goal_'+str(goal_idx)] = task_rel_success
55 | obs['metric_success_all_objects/goal_'+str(goal_idx)] = all_obj_success
56 | if self.use_goal_idx:
57 | task_rel_success, all_obj_success = self.compute_success(self.goal_idx)
58 | obs['metric_success_task_relevant/goal_'+str(self.goal_idx)] = task_rel_success
59 | obs['metric_success_all_objects/goal_'+str(self.goal_idx)] = all_obj_success
60 |
61 | return obs
62 |
63 | def reset(self):
64 |
65 | with self.LOCK:
66 | state = self._env.reset()
67 | if not self.use_goal_idx:
68 | self.goal_idx = np.random.randint(len(self.goals))
69 | self.goal = self.goals[self.goal_idx]
70 | self.rendered_goal = False
71 | obs = self._get_obs(state)
72 | obs['state'] = obs['state'].astype(np.float32)
73 | obs["is_last"] = False
74 | obs["is_first"] = True
75 | obs["reward"] = 0 # not sure if this is good
76 | obs["success"] = 1 # hard-code for now?
77 | obs["is_terminal"] = False
78 | obs["lang_num"] = 0 # change in future for multitask
79 | self._init_vidlang_time_step = obs
80 | obs['init_image'] = obs['image']
81 | obs['init_state'] = obs['state']
82 | # print(type(obs['state']), obs['state'].dtype)
83 | # breakpoint()
84 | return obs
85 |
86 | def step(self, action):
87 | total_reward = 0.0
88 | for step in range(self._action_repeat):
89 | state, reward, done, info = self._env.step(action['action'])
90 | reward = self.compute_reward()
91 | total_reward += reward
92 | if done:
93 | break
94 | obs = self._get_obs(state)
95 | for k, v in obs.items():
96 | if 'metric_' in k:
97 | info[k] = v
98 | obs["is_last"] = done
99 | obs["is_terminal"] = done
100 | obs["is_first"] = False
101 | obs["reward"] = total_reward
102 | obs["success"] = 1 # hard-code for now?
103 | obs["lang_num"] = 0 # change in future for multitask
104 | obs['state'] = obs['state'].astype(np.float32)
105 | obs['init_image'] = self._init_vidlang_time_step['init_image']
106 | obs['init_state'] = self._init_vidlang_time_step['init_state']
107 | return obs #, total_reward, done, info
108 |
109 | def compute_reward(self, goal=None):
110 | if goal is None:
111 | goal = self.goal
112 | qpos = self._env.sim.data.qpos.copy()
113 |
114 |         if len(self.obs_element_indices[goal]) > 9:
115 | return -np.linalg.norm(qpos[self.obs_element_indices[goal]][9:] - self.obs_element_goals[goal][9:])
116 | else:
117 | return -np.linalg.norm(qpos[self.obs_element_indices[goal]] - self.obs_element_goals[goal])
118 |
119 | def compute_success(self, goal = None):
120 |
121 | if goal is None:
122 | goal = self.goal
123 | qpos = self._env.sim.data.qpos.copy()
124 |
125 | goal_qpos = self.init_qpos.copy()
126 | goal_qpos[self.obs_element_indices[goal]] = self.obs_element_goals[goal]
127 |
128 | per_obj_success = {
129 | 'bottom_burner' : ((qpos[9]<-0.38) and (goal_qpos[9]<-0.38)) or ((qpos[9]>-0.38) and (goal_qpos[9]>-0.38)),
130 | 'top_burner': ((qpos[13]<-0.38) and (goal_qpos[13]<-0.38)) or ((qpos[13]>-0.38) and (goal_qpos[13]>-0.38)),
131 | 'light_switch': ((qpos[17]<-0.25) and (goal_qpos[17]<-0.25)) or ((qpos[17]>-0.25) and (goal_qpos[17]>-0.25)),
132 | 'slide_cabinet' : abs(qpos[19] - goal_qpos[19])<0.1,
133 | 'hinge_cabinet' : abs(qpos[21] - goal_qpos[21])<0.2,
134 | 'microwave' : abs(qpos[22] - goal_qpos[22])<0.2,
135 | 'kettle' : np.linalg.norm(qpos[23:25] - goal_qpos[23:25]) < 0.2
136 | }
137 | task_objects = self.goal_configs[goal]
138 |
139 | task_rel_success = 1
140 | for _obj in task_objects:
141 | task_rel_success *= per_obj_success[_obj]
142 |
143 | all_obj_success = 1
144 | for _obj in per_obj_success:
145 | all_obj_success *= per_obj_success[_obj]
146 |
147 | return int(task_rel_success), int(all_obj_success)
148 |
149 | def render_goal(self):
150 | if self.rendered_goal:
151 | return self.rendered_goal_obj
152 |
153 | # random.sample(list(obs_element_goals), 1)[0]
154 | backup_qpos = self._env.sim.data.qpos.copy()
155 | backup_qvel = self._env.sim.data.qvel.copy()
156 |
157 | qpos = self.init_qpos.copy()
158 | qpos[self.obs_element_indices[self.goal]] = self.obs_element_goals[self.goal]
159 | self._env.set_state(qpos, np.zeros(len(self._env.init_qvel)))
160 |
161 | goal_obs = self._env.render('rgb_array')
162 |
163 | self._env.set_state(backup_qpos, backup_qvel)
164 |
165 | self.rendered_goal = True
166 | self.rendered_goal_obj = goal_obs
167 | return goal_obs
168 |
169 | def get_kitchen_benchmark_goals():
170 |
171 | object_goal_vals = {'bottom_burner' : [-0.88, -0.01],
172 | 'light_switch' : [ -0.69, -0.05],
173 | 'slide_cabinet': [0.37],
174 | 'hinge_cabinet': [0., 0.5],
175 | 'microwave' : [-0.5],
176 | 'kettle' : [-0.23, 0.75, 1.62]}
177 |
178 | object_goal_idxs = {'bottom_burner' : [9, 10],
179 | 'light_switch' : [17, 18],
180 | 'slide_cabinet': [19],
181 | 'hinge_cabinet': [20, 21],
182 | 'microwave' : [22],
183 | 'kettle' : [23, 24, 25]}
184 |
185 | base_task_names = [ 'bottom_burner', 'light_switch', 'slide_cabinet',
186 | 'hinge_cabinet', 'microwave', 'kettle' ]
187 |
188 |
189 | goal_configs = []
190 | #single task
191 | for i in range(6):
192 | goal_configs.append( [base_task_names[i]])
193 |
194 | #two tasks
195 | for i,j in combinations([1,2,3,5], 2) :
196 | goal_configs.append( [base_task_names[i], base_task_names[j]] )
197 |
198 | obs_element_goals = [] ; obs_element_indices = []
199 | for objects in goal_configs:
200 | _goal = np.concatenate([object_goal_vals[obj] for obj in objects])
201 | _goal_idxs = np.concatenate([object_goal_idxs[obj] for obj in objects])
202 |
203 | obs_element_goals.append(_goal)
204 | obs_element_indices.append(_goal_idxs)
205 |
206 | return obs_element_goals, obs_element_indices, goal_configs
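
# Sanity-check sketch for the goal set above: 6 single-object goals followed
# by the 6 pairs drawn from {light_switch, slide_cabinet, hinge_cabinet, kettle}.
if __name__ == "__main__":
    goals, idxs, configs = get_kitchen_benchmark_goals()
    assert len(configs) == 12
    assert configs[0] == ['bottom_burner']                  # single-object stage
    assert configs[6] == ['light_switch', 'slide_cabinet']  # first two-object pair
    assert list(idxs[6]) == [17, 18, 19]                    # their qpos indices
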
--------------------------------------------------------------------------------
/common/other.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import contextlib
3 | import re
4 | import time
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 | from tensorflow_probability import distributions as tfd
9 |
10 | from . import dists
11 | from . import tfutils
12 |
13 |
14 | class ScriptedAgent:
15 | def __init__(self, act_space, eval_env, corner = 'top_right'):
16 | self.act_space = act_space["action"]
17 | offset_xy = 0.15
18 | offset_z = 0.05
19 | if corner == 'top_left':
20 | self.target = [offset_xy,offset_xy,offset_z]
21 | elif corner == 'top_right':
22 | self.target = [offset_xy,-offset_xy,offset_z]
23 | elif corner == 'bottom_left':
24 | self.target = [-offset_xy,offset_xy,offset_z]
25 | elif corner == 'bottom_right':
26 | self.target = [-offset_xy,-offset_xy,offset_z]
27 | self.eval_env = eval_env
28 | print("low: ", self.act_space.low)
29 | print("high: ", self.act_space.high)
30 |
31 | def __call__(self, obs, state=None, mode=None):
32 | print(obs['state'])
33 | action = np.zeros(self.act_space.shape)[None]
34 | action[:,3] = 1. # Keep gripper open
35 |
36 | container_pos = self.eval_env.access("container_pos")()
37 | print("container_pos: ", container_pos)
38 |
39 | distance = container_pos + self.target - obs['state'][0][1:4]
40 | print("distance: ", distance)
41 | action[:,:3] = (distance - self.act_space.low[:3])/(self.act_space.high[:3] - self.act_space.low[:3]) * 2 - 1
42 | action = np.clip(action, -np.ones_like(self.act_space.low), np.ones_like(self.act_space.high))
43 |
44 | output = {"action": action}
45 |
46 | print("action: ", action)
47 |
48 | return output, None
49 |
50 | class RandomAgent:
51 | def __init__(self, act_space, logprob=False):
52 | self.act_space = act_space["action"]
53 | self.logprob = logprob
54 | if hasattr(self.act_space, "n"):
55 | self._dist = dists.OneHotDist(tf.zeros(self.act_space.n))
56 | else:
57 | dist = tfd.Uniform(self.act_space.low, self.act_space.high)
58 | self._dist = tfd.Independent(dist, 1)
59 |
60 | def __call__(self, obs, state=None, mode=None):
61 | action = self._dist.sample(len(obs["is_first"]))
62 | output = {"action": action}
63 | if self.logprob:
64 | output["logprob"] = self._dist.log_prob(action)
65 | return output, None
66 |
67 |
68 | def static_scan(fn, inputs, start, reverse=False):
69 | last = start
70 | outputs = [[] for _ in tf.nest.flatten(start)]
71 | indices = range(tf.nest.flatten(inputs)[0].shape[0])
72 | if reverse:
73 | indices = reversed(indices)
74 | for index in indices:
75 | inp = tf.nest.map_structure(lambda x: x[index], inputs)
76 | last = fn(last, inp)
77 | [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))]
78 | if reverse:
79 | outputs = [list(reversed(x)) for x in outputs]
80 | outputs = [tf.stack(x, 0) for x in outputs]
81 | return tf.nest.pack_sequence_as(start, outputs)
82 |
83 |
84 | def schedule(string, step):
85 | try:
86 | return float(string)
87 | except ValueError:
88 | step = tf.cast(step, tf.float32)
89 | match = re.match(r"linear\((.+),(.+),(.+)\)", string)
90 | if match:
91 | initial, final, duration = [float(group) for group in match.groups()]
92 | mix = tf.clip_by_value(step / duration, 0, 1)
93 | return (1 - mix) * initial + mix * final
94 | match = re.match(r"warmup\((.+),(.+)\)", string)
95 | if match:
96 | warmup, value = [float(group) for group in match.groups()]
97 | scale = tf.clip_by_value(step / warmup, 0, 1)
98 | return scale * value
99 | match = re.match(r"exp\((.+),(.+),(.+)\)", string)
100 | if match:
101 | initial, final, halflife = [float(group) for group in match.groups()]
102 | return (initial - final) * 0.5 ** (step / halflife) + final
103 | match = re.match(r"horizon\((.+),(.+),(.+)\)", string)
104 | if match:
105 | initial, final, duration = [float(group) for group in match.groups()]
106 | mix = tf.clip_by_value(step / duration, 0, 1)
107 | horizon = (1 - mix) * initial + mix * final
108 | return 1 - 1 / horizon
109 | raise NotImplementedError(string)
110 |
111 |
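# A quick illustration of the schedule() grammar below; apart from the
# constant case these return float32 TF scalars (values shown are what the
# branches compute):
#
#   schedule("0.3", 0)                    -> 0.3
#   schedule("linear(1.0,0.1,100)", 50)   -> 0.55  (halfway along the ramp)
#   schedule("warmup(100,2.0)", 50)       -> 1.0   (half of the target value)
#   schedule("exp(1.0,0.0,10)", 10)       -> 0.5   (one half-life)
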
112 | def lambda_return(reward, value, pcont, bootstrap, lambda_, axis):
113 | # Setting lambda=1 gives a discounted Monte Carlo return.
114 | # Setting lambda=0 gives a fixed 1-step return.
115 | assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape)
116 | if isinstance(pcont, (int, float)):
117 | pcont = pcont * tf.ones_like(reward)
118 | dims = list(range(reward.shape.ndims))
119 | dims = [axis] + dims[1:axis] + [0] + dims[axis + 1 :]
120 | if axis != 0:
121 | reward = tf.transpose(reward, dims)
122 | value = tf.transpose(value, dims)
123 | pcont = tf.transpose(pcont, dims)
124 | if bootstrap is None:
125 | bootstrap = tf.zeros_like(value[-1])
126 | next_values = tf.concat([value[1:], bootstrap[None]], 0)
127 | inputs = reward + pcont * next_values * (1 - lambda_)
128 | returns = static_scan(
129 | lambda agg, cur: cur[0] + cur[1] * lambda_ * agg,
130 | (inputs, pcont),
131 | bootstrap,
132 | reverse=True,
133 | )
134 | if axis != 0:
135 | returns = tf.transpose(returns, dims)
136 | return returns
137 |
138 |
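# Plain-Python reference for lambda_return above (time-major 1-D inputs,
# scalar pcont), kept only as a readability sketch; not used elsewhere:
# R[t] = r[t] + pcont * ((1 - lambda) * V[t+1] + lambda * R[t+1]), with the
# bootstrap standing in for V[T].
def _lambda_return_reference(reward, value, pcont, bootstrap, lambda_):
    ret, out = bootstrap, []
    for t in reversed(range(len(reward))):
        nxt = value[t + 1] if t + 1 < len(value) else bootstrap
        ret = reward[t] + pcont * ((1 - lambda_) * nxt + lambda_ * ret)
        out.insert(0, ret)
    return out
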
139 | def action_noise(action, amount, act_space):
140 | if amount == 0:
141 | return action
142 | amount = tf.cast(amount, action.dtype)
143 | if hasattr(act_space, "n"):
144 | probs = amount / action.shape[-1] + (1 - amount) * action
145 | return dists.OneHotDist(probs=probs).sample()
146 | else:
147 | return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1)
148 |
149 |
150 | class StreamNorm(tfutils.Module):
151 | def __init__(self, shape=(), momentum=0.99, scale=1.0, eps=1e-8):
152 | # Momentum of 0 normalizes only based on the current batch.
153 | # Momentum of 1 disables normalization.
154 | self._shape = tuple(shape)
155 | self._momentum = momentum
156 | self._scale = scale
157 | self._eps = eps
158 | self.mag = tf.Variable(tf.ones(shape, tf.float64), False)
159 |
160 | def __call__(self, inputs):
161 | metrics = {}
162 | self.update(inputs)
163 | metrics["mean"] = inputs.mean()
164 | metrics["std"] = inputs.std()
165 | outputs = self.transform(inputs)
166 | metrics["normed_mean"] = outputs.mean()
167 | metrics["normed_std"] = outputs.std()
168 | return outputs, metrics
169 |
170 | def reset(self):
171 | self.mag.assign(tf.ones_like(self.mag))
172 |
173 | def update(self, inputs):
174 | batch = inputs.reshape((-1,) + self._shape)
175 | mag = tf.abs(batch).mean(0).astype(tf.float64)
176 | self.mag.assign(self._momentum * self.mag + (1 - self._momentum) * mag)
177 |
178 | def transform(self, inputs):
179 | values = inputs.reshape((-1,) + self._shape)
180 | values /= self.mag.astype(inputs.dtype)[None] + self._eps
181 | values *= self._scale
182 | return values.reshape(inputs.shape)
183 |
184 |
185 | class Timer:
186 | def __init__(self):
187 | self._indurs = collections.defaultdict(list)
188 | self._outdurs = collections.defaultdict(list)
189 | self._start_times = {}
190 | self._end_times = {}
191 |
192 | @contextlib.contextmanager
193 | def section(self, name):
194 | self.start(name)
195 | yield
196 | self.end(name)
197 |
198 | def wrap(self, function, name):
199 | def wrapped(*args, **kwargs):
200 | with self.section(name):
201 | return function(*args, **kwargs)
202 |
203 | return wrapped
204 |
205 | def start(self, name):
206 | now = time.time()
207 | self._start_times[name] = now
208 | if name in self._end_times:
209 | last = self._end_times[name]
210 | self._outdurs[name].append(now - last)
211 |
212 | def end(self, name):
213 | now = time.time()
214 | self._end_times[name] = now
215 | self._indurs[name].append(now - self._start_times[name])
216 |
217 | def result(self):
218 | metrics = {}
219 | for key in self._indurs:
220 | indurs = self._indurs[key]
221 | outdurs = self._outdurs[key]
222 | metrics[f"timer_count_{key}"] = len(indurs)
223 | metrics[f"timer_inside_{key}"] = np.sum(indurs)
224 | metrics[f"timer_outside_{key}"] = np.sum(outdurs)
225 | indurs.clear()
226 | outdurs.clear()
227 | return metrics
228 |
229 |
230 | class CarryOverState:
231 | def __init__(self, fn):
232 | self._fn = fn
233 | self._state = None
234 |
235 | def __call__(self, *args):
236 | self._state, out = self._fn(*args, self._state)
237 | return out
238 |
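# For reference, a plain-NumPy sketch of the StreamNorm magnitude tracking
# above, specialized to the scalar (shape=()) case; not used elsewhere.
class NumpyStreamNormSketch:
    def __init__(self, momentum=0.99, scale=1.0, eps=1e-8):
        self.momentum, self.scale, self.eps = momentum, scale, eps
        self.mag = 1.0  # EMA of the mean absolute input magnitude
    def __call__(self, x):
        x = np.asarray(x, np.float64)
        self.mag = self.momentum * self.mag + (1 - self.momentum) * np.abs(x).mean()
        return self.scale * x / (self.mag + self.eps)
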
--------------------------------------------------------------------------------
/common/replay.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import datetime
3 | import io
4 | import pathlib
5 | import uuid
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 |
10 | import albumentations as A
11 |
12 |
13 | class Replay:
14 | def __init__(
15 | self,
16 | directory,
17 | load_directory=None,
18 | capacity=0,
19 | minlen=1,
20 | maxlen=0,
21 | prioritize_ends=False,
22 | ):
23 | self._directory = pathlib.Path(directory).expanduser()
24 | self._directory.mkdir(parents=True, exist_ok=True)
25 | self._capacity = capacity
26 | self._minlen = minlen
27 | self._maxlen = maxlen
28 | self._prioritize_ends = prioritize_ends
29 | self._random = np.random.RandomState()
30 |
31 | self.load_directory = load_directory
32 | if load_directory is None:
33 | load_directory = self._directory
34 | else:
35 | load_directory = pathlib.Path(load_directory).expanduser()
36 | # filename -> key -> value_sequence
37 | self._complete_eps = load_episodes(load_directory, capacity, minlen)
38 | if len(self._complete_eps) != 0:
39 | self._eps_keys = list(self._complete_eps.keys())
40 | idxs = np.sort(
41 | np.argsort(np.random.uniform(size=len(self._eps_keys)))[
42 | : int(np.ceil(len(self._eps_keys) / 2))
43 | ]
44 | )
45 | else:
46 | self._eps_keys = []
47 | self._eps_masks = dict()
48 | # worker -> key -> value_sequence
49 | self._ongoing_eps = collections.defaultdict(
50 | lambda: collections.defaultdict(list)
51 | )
52 | self._total_episodes, self._total_steps = count_episodes(directory)
53 | self._loaded_episodes = len(self._complete_eps)
54 | self._loaded_steps = sum(eplen(x) for x in self._complete_eps.values())
55 | self.reward_func = None
56 |
57 | @property
58 | def stats(self):
59 | return {
60 | "total_steps": self._total_steps,
61 | "total_episodes": self._total_episodes,
62 | "loaded_steps": self._loaded_steps,
63 | "loaded_episodes": self._loaded_episodes,
64 | }
65 |
66 | def reward_relabel(self, reward_func):
67 | self.reward_func = reward_func
68 |
69 | def add_step(self, transition, worker=0):
70 | episode = self._ongoing_eps[worker]
71 | for key, value in transition.items():
72 | episode[key].append(value)
73 | if transition["is_last"]:
74 | self.add_episode(episode)
75 | episode.clear()
76 |
77 | def add_episode(self, episode):
78 | if 'lang_num' not in episode.keys() or len(episode['lang_num']) != len(episode['reward']):
79 |             # Pure-RL runs without multi-task learning may produce episodes
80 |             # that lack lang_num; pad it with zeros as a workaround.
81 | episode['lang_num'] = [0] * len(episode['reward'])
82 | length = eplen(episode)
83 | if length < self._minlen:
84 | print(f"Skipping short episode of length {length}.")
85 | return
86 | self._total_steps += length
87 | self._loaded_steps += length
88 | self._total_episodes += 1
89 | self._loaded_episodes += 1
90 | episode = {key: convert(value) for key, value in episode.items()}
91 | if self.reward_func is not None:
92 | vidlang_reward = self.reward_func(episode)
93 | episode["reward"] = vidlang_reward
94 | filename = save_episode(self._directory, episode)
95 | self._complete_eps[str(filename)] = episode
96 |
97 | self._eps_keys.append(str(filename))
98 | self._enforce_limit()
99 |
100 | def add_demo_episode(self, episode):
101 | success = np.sum(episode["success"]) >= 1.0
102 | if success:
103 | self.add_episode(episode)
104 |
105 | def dataset(self, batch, length):
106 | example = next(iter(self._generate_chunks(length)))
107 | dataset = tf.data.Dataset.from_generator(
108 | lambda: self._generate_chunks(length),
109 | {k: v.dtype for k, v in example.items()},
110 | {k: v.shape for k, v in example.items()},
111 | )
112 | dataset = dataset.batch(batch, drop_remainder=True)
113 | dataset = dataset.prefetch(5)
114 | return dataset
115 |
116 | def _generate_chunks(self, length):
117 | sequence = self._sample_sequence()
118 | while True:
119 | chunk = collections.defaultdict(list)
120 | added = 0
121 | while added < length:
122 | needed = length - added
123 | adding = {k: v[:needed] for k, v in sequence.items()}
124 | sequence = {k: v[needed:] for k, v in sequence.items()}
125 | for key, value in adding.items():
126 | chunk[key].append(value)
127 | added += len(adding["reward"])
128 | if len(sequence["reward"]) < 1:
129 | sequence = self._sample_sequence()
130 | chunk = {k: np.concatenate(v) for k, v in chunk.items()}
131 | yield chunk
132 |
133 | def _sample_sequence(self):
134 | eps_keys = self._eps_keys
135 | L = len(eps_keys)
136 | i = np.random.randint(0, L)
137 | episode_key = eps_keys[i]
138 | episode = self._complete_eps[episode_key]
139 |
140 | total = len(episode["reward"])
141 | length = total
142 | if self._maxlen:
143 | length = min(length, self._maxlen)
144 | # Randomize length to avoid all chunks ending at the same time in case the
145 | # episodes are all of the same length.
146 | length -= np.random.randint(self._minlen)
147 | length = max(self._minlen, length)
148 | upper = total - length + 1
149 | if self._prioritize_ends:
150 | upper += self._minlen
151 | index = min(self._random.randint(upper), total - length)
152 | sequence = {
153 | k: convert(v[index : index + length])
154 | for k, v in episode.items()
155 | if not k.startswith("log_")
156 | }
157 | sequence["is_first"] = np.zeros(len(sequence["reward"]), bool)
158 | sequence["is_first"][0] = True
159 | if self._maxlen:
160 | assert self._minlen <= len(sequence["reward"]) <= self._maxlen
161 | return sequence
162 |
163 | def _enforce_limit(self):
164 | if not self._capacity:
165 | return
166 | while self._loaded_episodes > 1 and self._loaded_steps > self._capacity:
167 | # Relying on Python preserving the insertion order of dicts.
168 | oldest, episode = next(iter(self._complete_eps.items()))
169 | self._loaded_steps -= eplen(episode)
170 | self._loaded_episodes -= 1
171 | del self._complete_eps[oldest]
172 |
173 |
174 | def count_episodes(directory):
175 | filenames = list(directory.glob("*.npz"))
176 | num_episodes = len(filenames)
177 | num_steps = sum(int(str(n).split("-")[-1][:-4]) - 1 for n in filenames)
178 | return num_episodes, num_steps
179 |
180 |
181 | def save_episode(directory, episode):
182 | timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
183 | identifier = str(uuid.uuid4().hex)
184 | length = eplen(episode)
185 | filename = directory / f"{timestamp}-{identifier}-{length}.npz"
186 | with io.BytesIO() as f1:
187 | np.savez_compressed(f1, **episode)
188 | f1.seek(0)
189 | with filename.open("wb") as f2:
190 | f2.write(f1.read())
191 | return filename
192 |
193 |
194 | def load_episodes(directory, capacity=None, minlen=1):
195 |     # The returned dictionary from filenames to episodes is guaranteed to be in
196 | # temporally sorted order.
197 | filenames = sorted(directory.glob("*.npz"))
198 | if capacity:
199 | num_steps = 0
200 | num_episodes = 0
201 | for filename in reversed(filenames):
202 | length = int(str(filename).split("-")[-1][:-4])
203 | num_steps += length
204 | num_episodes += 1
205 | if num_steps >= capacity:
206 | break
207 | filenames = filenames[-num_episodes:]
208 | episodes = {}
209 | for filename in filenames:
210 | try:
211 | with filename.open("rb") as f:
212 | episode = np.load(f)
213 | episode = {k: episode[k] for k in episode.keys()}
214 | except Exception as e:
215 | print(f"Could not load episode {str(filename)}: {e}")
216 | continue
217 | episodes[str(filename)] = episode
218 | return episodes
219 |
220 |
221 | def convert(value):
222 | value = np.array(value)
223 | if np.issubdtype(value.dtype, np.floating):
224 | return value.astype(np.float32)
225 | elif np.issubdtype(value.dtype, np.signedinteger):
226 | return value.astype(np.int32)
227 | elif np.issubdtype(value.dtype, np.uint8):
228 | return value.astype(np.uint8)
229 | return value
230 |
231 |
232 | def eplen(episode):
233 | return len(episode["reward"]) - 1
234 |
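
# Round-trip sketch for the helpers above: episodes are dicts of equal-length
# arrays, and the filename encodes eplen = len(reward) - 1 (transitions, not
# frames). Uses a throwaway temp directory.
if __name__ == "__main__":
    import tempfile
    directory = pathlib.Path(tempfile.mkdtemp())
    episode = {
        "reward": np.zeros(11, np.float32),
        "is_last": np.array([False] * 10 + [True]),
    }
    filename = save_episode(directory, {k: convert(v) for k, v in episode.items()})
    assert filename.name.endswith("-10.npz")    # eplen = 11 - 1
    assert count_episodes(directory) == (1, 9)  # count subtracts one more per file
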
--------------------------------------------------------------------------------
/common/nets.py:
--------------------------------------------------------------------------------
1 | import re
2 | import functools
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 | from tensorflow.keras import layers as tfkl
7 | from tensorflow.keras import initializers as tfki
8 | from tensorflow_probability import distributions as tfd
9 | from tensorflow.keras.mixed_precision import experimental as prec
10 |
11 | import common
12 |
13 |
14 | class RSSM(common.Module):
15 | def __init__(
16 | self,
17 | action_free=False,
18 | stoch=30,
19 | deter=200,
20 | hidden=200,
21 | discrete=False,
22 | act="elu",
23 | norm="none",
24 | std_act="softplus",
25 | min_std=0.1,
26 | ):
27 | super().__init__()
28 | self._action_free = action_free
29 | self._stoch = stoch
30 | self._deter = deter
31 | self._hidden = hidden
32 | self._discrete = discrete
33 | self._act = get_act(act)
34 | self._norm = norm
35 | self._std_act = std_act
36 | self._min_std = min_std
37 | self._cell = GRUCell(self._deter, norm=True)
38 | self._cast = lambda x: tf.cast(x, prec.global_policy().compute_dtype)
39 |
40 | def initial(self, batch_size):
41 | dtype = prec.global_policy().compute_dtype
42 | if self._discrete:
43 | state = dict(
44 | logit=tf.zeros([batch_size, self._stoch, self._discrete], dtype),
45 | stoch=tf.zeros([batch_size, self._stoch, self._discrete], dtype),
46 | deter=self._cell.get_initial_state(None, batch_size, dtype),
47 | )
48 | else:
49 | state = dict(
50 | mean=tf.zeros([batch_size, self._stoch], dtype),
51 | std=tf.zeros([batch_size, self._stoch], dtype),
52 | stoch=tf.zeros([batch_size, self._stoch], dtype),
53 | deter=self._cell.get_initial_state(None, batch_size, dtype),
54 | )
55 | return state
56 |
57 | def fill_action_with_zero(self, action):
58 | # action: [B, action]
59 | B, D = action.shape[0], action.shape[1]
60 | if self._action_free:
61 | return self._cast(tf.zeros([B, 50]))
62 | else:
63 | zeros = self._cast(tf.zeros([B, 50 - D]))
64 | return tf.concat([action, zeros], axis=1)
65 |
66 | @tf.function
67 | def observe(self, embed, action, is_first, state=None):
68 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape))))
69 | if state is None:
70 | state = self.initial(tf.shape(action)[0])
71 | post, prior = common.static_scan(
72 | lambda prev, inputs: self.obs_step(prev[0], *inputs),
73 | (swap(action), swap(embed), swap(is_first)),
74 | (state, state),
75 | )
76 | post = {k: swap(v) for k, v in post.items()}
77 | prior = {k: swap(v) for k, v in prior.items()}
78 | return post, prior
79 |
80 | @tf.function
81 | def imagine(self, action, state=None):
82 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape))))
83 | if state is None:
84 | state = self.initial(tf.shape(action)[0])
85 | assert isinstance(state, dict), state
86 | action = swap(action)
87 | prior = common.static_scan(self.img_step, action, state)
88 | prior = {k: swap(v) for k, v in prior.items()}
89 | return prior
90 |
91 | def get_feat(self, state):
92 | stoch = self._cast(state["stoch"])
93 | if self._discrete:
94 | shape = stoch.shape[:-2] + [self._stoch * self._discrete]
95 | stoch = tf.reshape(stoch, shape)
96 |         return tf.concat([stoch, state["deter"]], -1)
97 |
98 | def get_dist(self, state):
99 | if self._discrete:
100 | logit = state["logit"]
101 | logit = tf.cast(logit, tf.float32)
102 | dist = tfd.Independent(common.OneHotDist(logit), 1)
103 | else:
104 | mean, std = state["mean"], state["std"]
105 | mean = tf.cast(mean, tf.float32)
106 | std = tf.cast(std, tf.float32)
107 | dist = tfd.MultivariateNormalDiag(mean, std)
108 | return dist
109 |
110 | @tf.function
111 | def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
112 | # if is_first.any():
113 | prev_state, prev_action = tf.nest.map_structure(
114 | lambda x: tf.einsum("b,b...->b...", 1.0 - is_first.astype(x.dtype), x),
115 | (prev_state, prev_action),
116 | )
117 | prior = self.img_step(prev_state, prev_action, sample)
118 |         x = tf.concat([prior["deter"], embed], -1)
119 | x = self.get("obs_out", tfkl.Dense, self._hidden)(x)
120 | x = self.get("obs_out_norm", NormLayer, self._norm)(x)
121 | x = self._act(x)
122 | stats = self._suff_stats_layer("obs_dist", x)
123 | dist = self.get_dist(stats)
124 | stoch = dist.sample() if sample else dist.mode()
125 |         post = {"stoch": stoch, "deter": prior["deter"], **stats}
126 | return post, prior
127 |
128 | @tf.function
129 | def img_step(self, prev_state, prev_action, sample=True):
130 | prev_stoch = self._cast(prev_state["stoch"])
131 | prev_action = self._cast(prev_action)
132 | if self._discrete:
133 | shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete]
134 | prev_stoch = tf.reshape(prev_stoch, shape)
135 | x = tf.concat([prev_stoch, self.fill_action_with_zero(prev_action)], -1)
136 | x = self.get("img_in", tfkl.Dense, self._hidden)(x)
137 | x = self.get("img_in_norm", NormLayer, self._norm)(x)
138 |         x, deter = self._cell(x, [prev_state["deter"]])
139 | deter = deter[0]
140 | x = self.get("img_out", tfkl.Dense, self._hidden)(x)
141 | x = self.get("img_out_norm", NormLayer, self._norm)(x)
142 | x = self._act(x)
143 |         stats = self._suff_stats_layer("img_dist", x)
144 | dist = self.get_dist(stats)
145 | stoch = dist.sample() if sample else dist.mode()
146 | prior = {"stoch": stoch, "deter": deter, **stats}
147 | return prior
148 |
149 | def _suff_stats_layer(self, name, x):
150 | if self._discrete:
151 | x = self.get(name, tfkl.Dense, self._stoch * self._discrete, None)(x)
152 | logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete])
153 | return {"logit": logit}
154 | else:
155 | x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x)
156 | mean, std = tf.split(x, 2, -1)
157 | std = {
158 | "softplus": lambda: tf.nn.softplus(std),
159 | "sigmoid": lambda: tf.nn.sigmoid(std),
160 | "sigmoid2": lambda: 2 * tf.nn.sigmoid(std / 2),
161 | }[self._std_act]()
162 | std = std + self._min_std
163 | return {"mean": mean, "std": std}
164 |
165 | def kl_loss(self, post, prior, balance=0.8):
166 | post_const = tf.nest.map_structure(tf.stop_gradient, post)
167 | prior_const = tf.nest.map_structure(tf.stop_gradient, prior)
168 | lhs = tfd.kl_divergence(self.get_dist(post_const), self.get_dist(prior))
169 | rhs = tfd.kl_divergence(self.get_dist(post), self.get_dist(prior_const))
170 | return balance * lhs + (1 - balance) * rhs
171 |
172 |
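# kl_loss above implements KL balancing: writing sg(.) for stop_gradient,
#
#     loss = balance * KL(sg(post) || prior) + (1 - balance) * KL(post || sg(prior))
#
# so with balance = 0.8 most of the gradient pulls the prior toward the
# posterior, while a smaller share regularizes the posterior toward the prior.
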
173 | class MLP(common.Module):
174 | def __init__(
175 | self, shape, linear_probe=False, layers=[512, 512, 512, 512], act="elu", norm="none", **out
176 | ):
177 | self._shape = (shape,) if isinstance(shape, int) else shape
178 | self._layers = layers
179 | self._norm = norm
180 | self._act = get_act(act)
181 | self._out = out
182 | self._linear_probe = linear_probe
183 |
184 | def __call__(self, features):
185 | x = tf.cast(features, prec.global_policy().compute_dtype)
186 | x = x.reshape([-1, x.shape[-1]])
187 | for index, unit in enumerate(self._layers):
188 | x = self.get(f"dense{index}", tfkl.Dense, unit, trainable=(not self._linear_probe))(x)
189 | x = self.get(f"norm{index}", NormLayer, self._norm, trainable=(not self._linear_probe))(x)
190 | x = self._act(x)
191 | if self._linear_probe:
192 | x = x.reshape(features.shape[:-1] + [x.shape[-1]])
193 | return self.get("out_probe", DistLayer, self._shape, **self._out)(x)
194 | x = x.reshape(features.shape[:-1] + [x.shape[-1]])
195 | return self.get("out", DistLayer, self._shape, **self._out)(x)
196 |
197 |
198 | class GRUCell(tf.keras.layers.AbstractRNNCell):
199 | def __init__(self, size, norm=True, act="tanh", update_bias=-1, **kwargs):
200 | super().__init__()
201 | self._size = size
202 | self._act = get_act(act)
203 | self._update_bias = update_bias
204 | self._layer = tfkl.Dense(3 * size, **kwargs)
205 | if norm:
206 | self._norm = NormLayer("layer")
207 | else:
208 | self._norm = NormLayer("none")
209 |
210 | @property
211 | def state_size(self):
212 | return self._size
213 |
214 | @tf.function
215 | def call(self, inputs, state):
216 | state = state[0] # Keras wraps the state in a list.
217 | parts = self._layer(tf.concat([inputs, state], -1))
218 | parts = self._norm(parts)
219 | reset, cand, update = tf.split(parts, 3, -1)
220 | reset = tf.nn.sigmoid(reset)
221 | cand = self._act(reset * cand)
222 | update = tf.nn.sigmoid(update + self._update_bias)
223 | output = update * cand + (1 - update) * state
224 | return output, [output]
225 |
226 |
227 | class DistLayer(common.Module):
228 | def __init__(
229 | self,
230 | shape,
231 | dist="mse",
232 | outscale=0.1,
233 | min_std=0.1,
234 | max_std=1.0,
235 | ):
236 | self._shape = shape
237 | self._dist = dist
238 | self._min_std = min_std
239 | self._max_std = max_std
240 | self._outscale = outscale
241 |
242 | def __call__(self, inputs):
243 | kw = {}
244 | if self._outscale == 0.0:
245 | kw["kernel_initializer"] = tfki.Zeros()
246 | else:
247 | kw["kernel_initializer"] = tfki.VarianceScaling(
248 | self._outscale, "fan_avg", "uniform"
249 | )
250 | out = self.get("out", tfkl.Dense, np.prod(self._shape), **kw)(inputs)
251 | out = tf.reshape(out, tf.concat([tf.shape(inputs)[:-1], self._shape], 0))
252 | out = tf.cast(out, tf.float32)
253 | if self._dist in ("normal", "trunc_normal"):
254 | std = self.get("std", tfkl.Dense, np.prod(self._shape))(inputs)
255 | std = tf.reshape(std, tf.concat([tf.shape(inputs)[:-1], self._shape], 0))
256 | std = tf.cast(std, tf.float32)
257 | if self._dist == "mse":
258 | return common.MSEDist(out, len(self._shape), "sum")
259 | if self._dist == "symlog":
260 | return common.SymlogDist(out, len(self._shape), "sum")
261 | if self._dist == "nmse":
262 | return common.NormalizedMSEDist(out, len(self._shape), "sum")
263 | if self._dist == "normal":
264 | lo, hi = self._min_std, self._max_std
265 | std = (hi - lo) * tf.nn.sigmoid(std) + lo
266 | dist = tfd.Normal(tf.tanh(out), std)
267 | dist = tfd.Independent(dist, len(self._shape))
268 | dist.minent = np.prod(self._shape) * tfd.Normal(0.0, lo).entropy()
269 | dist.maxent = np.prod(self._shape) * tfd.Normal(0.0, hi).entropy()
270 | return dist
271 | if self._dist == "binary":
272 | dist = tfd.Bernoulli(out)
273 | return tfd.Independent(dist, len(self._shape))
274 | if self._dist == "trunc_normal":
275 | lo, hi = self._min_std, self._max_std
276 | std = (hi - lo) * tf.nn.sigmoid(std) + lo
277 | dist = tfd.TruncatedNormal(tf.tanh(out), std, -1, 1)
278 | dist = tfd.Independent(dist, 1)
279 | dist.minent = np.prod(self._shape) * tfd.Normal(0.99, lo).entropy()
280 | dist.maxent = np.prod(self._shape) * tfd.Normal(0.0, hi).entropy()
281 | return dist
282 | if self._dist == "onehot":
283 | dist = common.OneHotDist(out)
284 | if len(self._shape) > 1:
285 | dist = tfd.Independent(dist, len(self._shape) - 1)
286 | dist.minent = 0.0
287 | dist.maxent = np.prod(self._shape[:-1]) * np.log(self._shape[-1])
288 | return dist
289 | raise NotImplementedError(self._dist)
290 |
291 |
292 | class NormLayer(common.Module, tf.keras.layers.Layer):
293 | def __init__(self, impl, trainable=True):
294 | super().__init__()
295 | self._impl = impl
296 | self._trainable = trainable
297 |
298 | def build(self, input_shape):
299 | if self._impl == "keras":
300 | self.layer = tfkl.LayerNormalization(trainable=self._trainable)
301 | self.layer.build(input_shape)
302 | elif self._impl == "layer":
303 | self.scale = self.add_weight("scale", input_shape[-1], tf.float32, "Ones", trainable=self._trainable)
304 | self.offset = self.add_weight(
305 | "offset", input_shape[-1], tf.float32, "Zeros", trainable=self._trainable
306 | )
307 |
308 | def call(self, x):
309 | if self._impl == "none":
310 | return x
311 | elif self._impl == "keras":
312 | return self.layer(x)
313 | elif self._impl == "layer":
314 | mean, var = tf.nn.moments(x, -1, keepdims=True)
315 | return tf.nn.batch_normalization(
316 | x, mean, var, self.offset, self.scale, 1e-3
317 | )
318 | else:
319 | raise NotImplementedError(self._impl)
320 |
321 |
322 | class MLPEncoder(common.Module):
323 | def __init__(
324 | self, act="elu", norm="none", layers=[512, 512, 512, 512], batchnorm=False
325 | ):
326 | self._act = get_act(act)
327 | self._layers = layers
328 | self._norm = norm
329 | self._batchnorm = batchnorm
330 |
331 | @tf.function
332 | def __call__(self, x, training=False):
333 | x = x.astype(prec.global_policy().compute_dtype)
334 | if self._batchnorm:
335 |             x = self.get("batchnorm", tfkl.BatchNormalization)(x, training=training)
336 | for i, unit in enumerate(self._layers):
337 | x = self.get(f"dense{i}", tfkl.Dense, unit)(x)
338 | x = self.get(f"densenorm{i}", NormLayer, self._norm)(x)
339 | x = self._act(x)
340 | return x
341 |
342 |
343 | class CNNEncoder(common.Module):
344 | def __init__(
345 | self,
346 | cnn_depth=64,
347 | cnn_kernels=(4, 4),
348 | act="elu",
349 | ):
350 | self._act = get_act(act)
351 | self._cnn_depth = cnn_depth
352 | self._cnn_kernels = cnn_kernels
353 |
354 | @tf.function
355 | def __call__(self, x):
356 | x = x.astype(prec.global_policy().compute_dtype)
357 | for i, kernel in enumerate(self._cnn_kernels):
358 | depth = 2 ** i * self._cnn_depth
359 | x = self.get(f"conv{i}", tfkl.Conv2D, depth, kernel, 1)(x)
360 | x = self._act(x)
361 | return x
362 |
363 |
364 | class CNNDecoder(common.Module):
365 | def __init__(
366 | self,
367 | out_dim,
368 | cnn_depth=64,
369 | cnn_kernels=(4, 5),
370 | act="elu",
371 | ):
372 | self._out_dim = out_dim
373 | self._act = get_act(act)
374 | self._cnn_depth = cnn_depth
375 | self._cnn_kernels = cnn_kernels
376 |
377 | @tf.function
378 | def __call__(self, x):
379 | x = x.astype(prec.global_policy().compute_dtype)
380 |
381 | x = self.get("convin", tfkl.Dense, 2 * 2 * 2 * self._cnn_depth)(x)
382 | x = tf.reshape(x, [-1, 1, 1, 8 * self._cnn_depth])
383 |
384 | for i, kernel in enumerate(self._cnn_kernels):
385 | depth = 2 ** (len(self._cnn_kernels) - i - 1) * self._cnn_depth
386 | x = self.get(f"conv{i}", tfkl.Conv2DTranspose, depth, kernel, 1)(x)
387 | x = self._act(x)
388 | x = self.get("convout", tfkl.Dense, self._out_dim)(x)
389 | return x
390 |
391 |
392 | def get_act(name):
393 | if name == "none":
394 | return tf.identity
395 | if name == "mish":
396 | return lambda x: x * tf.math.tanh(tf.nn.softplus(x))
397 | elif hasattr(tf.nn, name):
398 | return getattr(tf.nn, name)
399 | elif hasattr(tf, name):
400 | return getattr(tf, name)
401 | else:
402 | raise NotImplementedError(name)
403 |
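# The GRUCell above deviates from the textbook GRU: the reset gate multiplies
# the candidate pre-activation rather than the recurrent state, and the update
# gate carries a -1 bias. A NumPy sketch of one step, omitting the LayerNorm
# (W and b stand in for the Dense layer's kernel and bias; not used elsewhere).
def _gru_step_sketch(inputs, state, W, b, update_bias=-1):
    sigmoid = lambda z: 1 / (1 + np.exp(-z))
    parts = np.concatenate([inputs, state], -1) @ W + b
    reset, cand, update = np.split(parts, 3, -1)
    cand = np.tanh(sigmoid(reset) * cand)      # reset gates the candidate
    update = sigmoid(update + update_bias)     # bias favors keeping the state
    return update * cand + (1 - update) * state
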
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import functools
3 | import logging
4 | import os
5 | import pathlib
6 | import re
7 | import sys
8 | import warnings
9 | import torch
10 | import itertools
11 | from omegaconf import OmegaConf
12 |
13 | try:
14 | import rich.traceback
15 |
16 | rich.traceback.install()
17 | except ImportError:
18 | pass
19 |
20 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
21 | logging.getLogger().setLevel("ERROR")
22 | warnings.filterwarnings("ignore", ".*box bound precision lowered.*")
23 |
24 | sys.path.append(str(pathlib.Path(__file__).parent))
25 | sys.path.append(str(pathlib.Path(__file__).parent.parent))
26 |
27 | import numpy as np
28 | from keras import backend as K
29 | import ruamel.yaml as yaml
30 |
31 | import common
32 |
33 | def process_curriculum_str(lang_curriculum_str):
34 | cfg_lang_prompts = lang_curriculum_str.split("|")
35 | lang_curriculum = []
36 | for i in range(len(cfg_lang_prompts)):
37 | lang_curriculum.append(cfg_lang_prompts[i].split(","))
38 |
39 |     for i in range(len(lang_curriculum)):
40 |         # Expand file-path entries into their lines. Build a new list rather
41 |         # than popping while iterating, which skips and reorders entries.
42 |         lang_curriculum[i] = [line for entry in lang_curriculum[i]
43 |                               for line in (pathlib.Path(entry).read_text().splitlines()
44 |                                            if os.path.exists(entry) else [entry])]
45 |
46 | lang_instructions = list(itertools.chain(*lang_curriculum))
47 | lang_instructions = list(set(lang_instructions))
48 | return lang_curriculum, lang_instructions
49 |
50 | def get_lang_info(multi_task_vidlang, task, lang_prompt, synonym_folder, objects):
51 | """
52 | Returns
53 | final_lang_instructions: list of all possible processed language instructions
54 | lang_to_num: dict mapping language instruction (from above) to number
55 | lang_curriculum: list of non-processed instructions for each stage of curriculum
56 | synonym_dict: dict mapping object to object synonyms
57 | """
58 | if multi_task_vidlang:
59 | lang_instructions = task.split(',')
60 |         return lang_instructions, None, None, None  # callers unpack a 4-tuple
61 | else:
62 | lang_curriculum, lang_instructions = process_curriculum_str(lang_prompt)
63 |
64 | _, all_objs = process_curriculum_str(objects)
65 | noun_variations = []
66 | synonym_dict = {}
67 | if synonym_folder is None:
68 | for obj in all_objs:
69 |             synonym_dict[obj] = [obj]  # keep values list-valued, matching the else branch
70 | noun_variations = all_objs
71 | else:
72 | for obj in all_objs:
73 | with open(os.path.join(synonym_folder, f"synonym_{obj}.txt"), 'r') as f:
74 | synonyms = f.read().splitlines()
75 | synonym_dict[obj] = synonyms
76 | noun_variations.extend(synonyms)
77 | noun_variations = list(set(noun_variations))
78 |
79 | final_lang_instructions = []
80 | for lang_instr in lang_instructions:
81 | if "[NOUN]" in lang_instr:
82 | for noun in noun_variations:
83 | final_lang_instructions.append(lang_instr.replace("[NOUN]", noun))
84 | else:
85 | final_lang_instructions.append(lang_instr)
86 |
87 | lang_nums = range(len(final_lang_instructions))
88 | lang_to_num = dict(zip(final_lang_instructions, lang_nums))
89 | return final_lang_instructions, lang_to_num, lang_curriculum, synonym_dict
90 |
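# Illustrative input/output for process_curriculum_str above, assuming the
# entries are plain strings rather than files that exist on disk:
#
#   curriculum, instructions = process_curriculum_str(
#       "reach the [NOUN],touch the [NOUN]|Pick up the [NOUN]")
#   curriculum   == [["reach the [NOUN]", "touch the [NOUN]"],
#                    ["Pick up the [NOUN]"]]
#   instructions == the same three strings, deduplicated in arbitrary order
#
# get_lang_info then replaces each "[NOUN]" with every synonym read from the
# synonym_{obj}.txt files under synonym_folder (or with the object names
# themselves when no folder is given).
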
91 | def main():
92 |
93 | configs = yaml.safe_load(
94 | (pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text()
95 | )
96 | parsed, remaining = common.Flags(configs=["defaults"]).parse(known_only=True)
97 | config = common.Config(configs["defaults"])
98 | for name in parsed.configs:
99 | config = config.update(configs[name])
100 | config = common.Flags(config).parse(remaining)
101 |
102 | logdir = pathlib.Path(config.logdir).expanduser()
103 | logdir.mkdir(parents=True, exist_ok=True)
104 | config.save(logdir / "config.yaml")
105 |
106 | print(config, "\n")
107 | print("Logdir", logdir)
108 |
109 | loaddir = pathlib.Path(config.loaddir).expanduser()
110 | print("Loaddir", loaddir)
111 |
112 | import tensorflow as tf
113 | tf.config.experimental_run_functions_eagerly(not config.jit)
114 | message = "No GPU found. To actually train on CPU remove this assert."
115 | assert tf.config.experimental.list_physical_devices("GPU"), message
116 | print(tf.config.experimental.list_physical_devices("GPU"))
117 | for gpu in tf.config.experimental.list_physical_devices("GPU"):
118 | tf.config.experimental.set_memory_growth(gpu, True)
119 | assert config.precision in (16, 32), config.precision
120 | if config.precision == 16:
121 | from tensorflow.keras.mixed_precision import experimental as prec
122 |
123 | prec.set_policy(prec.Policy("mixed_float16"))
124 |
125 | train_replay = common.Replay(logdir / "train_episodes", **config.replay)
126 | eval_replay = common.Replay(
127 | logdir / "eval_episodes",
128 | **dict(
129 | capacity=config.replay.capacity // 10,
130 | minlen=config.replay.minlen,
131 | maxlen=config.replay.maxlen,
132 | ),
133 | )
134 | step = common.Counter(train_replay.stats["total_steps"])
135 | outputs = [
136 | common.TerminalOutput(),
137 | common.JSONLOutput(logdir),
138 | common.TensorBoardOutput(logdir),
139 | ]
140 | logger = common.Logger(step, outputs, multiplier=config.action_repeat)
141 | metrics = collections.defaultdict(list)
142 |
143 | should_train = common.Every(config.train_every)
144 | should_train_mae = common.Every(config.train_mae_every)
145 | should_log = common.Every(config.log_every)
146 |
147 | lang_instructions, lang_to_num, lang_curriculum, synonym_dict = get_lang_info(config.multi_task_vidlang, config.task, config.curriculum.lang_prompt, config.curriculum.synonym_folder, config.curriculum.objects)
148 |
149 | if (loaddir / f"variables_{config.ts}.pkl").exists():
150 | with open(loaddir / 'config.yaml', 'r') as f:
151 | pt_cfg = OmegaConf.create(yaml.safe_load(f))
152 | pt_lang_instructions, _, _, _ = get_lang_info(pt_cfg.multi_task_vidlang, pt_cfg.task, pt_cfg.curriculum.lang_prompt, pt_cfg.curriculum.synonym_folder, pt_cfg.curriculum.objects)
153 | if (pt_cfg.plan2explore or pt_cfg.rnd):
154 | if (pt_cfg.use_r3m_reward or pt_cfg.use_internvideo_reward or pt_cfg.use_clip_reward):
155 | config = config.update({'mae.state_dim': config.mae.state_dim + 768})
156 | config = config.update({'pretrain_mode': 'lang_emb'})
157 | else:
158 | config = config.update({'pretrain_mode': 'no_lang'})
159 | elif pt_cfg.use_lang_embeddings:
160 | config = config.update({'mae.state_dim': config.mae.state_dim + 768})
161 | config = config.update({'pretrain_mode': 'lang_emb'})
162 | else:
163 | config = config.update({'mae.state_dim': config.mae.state_dim + len(pt_lang_instructions)})
164 | config = config.update({'pretrain_mode': 'one_hot'})
165 | config = config.update({'num_langs': len(pt_lang_instructions)})
166 | config = config.update({'train_mode': 'finetune'})
167 | elif (config.plan2explore or config.rnd) and not (config.use_r3m_reward or config.use_internvideo_reward or config.use_clip_reward):
168 | config = config.update({'train_mode': 'pretrain'})
169 | else:
170 | config = config.update({'pretrain_mode': None})
171 | if config.use_lang_embeddings:
172 | config = config.update({'mae.state_dim': config.mae.state_dim + 768})
173 | else:
174 | config = config.update({'mae.state_dim': config.mae.state_dim + len(lang_instructions)})
175 | config = config.update({'num_langs': len(lang_instructions)})
176 | config = config.update({'train_mode': 'pretrain'})
177 | config.save(logdir / "config_updated.yaml")
178 |
179 | finetune_lang_encoding = None
180 | if config.train_mode == "finetune" and config.pretrain_mode == "lang_emb":
181 | import r3mreward as r3mreward
182 | from r3m import load_r3m
183 | finetune_instruction = config.task.replace("_", " ")
184 | model = load_r3m('resnet50').module.eval().to(config.vidlang_model_device)
185 | model = r3mreward.R3MReward(model, [finetune_instruction], config.standardize_rewards,
186 | config.queue_size, config.update_stats_steps, config.num_top_images, config.use_lang_embeddings)
187 | finetune_lang_encoding = model.get_lang_encoding([finetune_instruction]).cpu().numpy()[0]
188 | if config.tune_instruction:
189 | if os.path.exists(config.instructions_file):
190 | print("Loading instructions to tune with...")
191 | with open(config.instructions_file, 'r') as f:
192 | candidate_instructions = f.read().splitlines()
193 | candidate_instructions = candidate_instructions[:config.num_tune_instructions-10]
194 |         print(f"Loaded {len(candidate_instructions)} instructions from {config.instructions_file}")
195 | else:
196 | print("Generating instructions to tune with...")
197 | pt_lang_instructions, _, lang_curriculum, _ = get_lang_info(pt_cfg.multi_task_vidlang, pt_cfg.task, pt_cfg.curriculum.lang_prompt, pt_cfg.curriculum.synonym_folder, pt_cfg.curriculum.objects)
198 |             candidate_instructions = list(np.random.choice(pt_lang_instructions, config.num_tune_instructions-10, replace=False))  # list, not ndarray, so .extend() below works
199 | with open(f"prompts/tune_instructions_{config.num_tune_instructions}.txt", 'w') as f:
200 | for instr in candidate_instructions:
201 | f.write(instr + '\n')
202 |             print(f"Generated {len(candidate_instructions)} instructions and saved to prompts/tune_instructions_{config.num_tune_instructions}.txt")
203 | with open(f"prompts/{config.task}.txt", 'r') as f:
204 | task_candidate_instructions = f.read().splitlines()
205 | candidate_instructions.extend(task_candidate_instructions)
206 | candidate_instructions_encodings = []
207 | for instr in candidate_instructions:
208 | candidate_instructions_encodings.append(model.get_lang_encoding([instr]).cpu().numpy()[0])
209 | candidate_instructions_scores = [0] * len(candidate_instructions)
210 | global score_idx
211 | score_idx = 0
212 | del model
213 | import gc; gc.collect()
214 | lang_instructions = [finetune_instruction]
215 |
216 | if not config.use_r3m_reward and config.train_mode == "pretrain" and config.use_lang_embeddings:
217 | import r3mreward as r3mreward
218 | from r3m import load_r3m
219 | sentences = [t.replace("_", " ") for t in lang_instructions]
220 | model = load_r3m('resnet50').module.eval().to(config.vidlang_model_device)
221 | model = r3mreward.R3MReward(model, sentences, config.standardize_rewards,
222 | config.queue_size, config.update_stats_steps, config.num_top_images, config.use_lang_embeddings)
223 | lang_encodings = model.get_lang_encoding(sentences).cpu().numpy()
224 | lang_to_encoding = dict(zip(lang_instructions, lang_encodings))
225 | del model
226 | import gc; gc.collect()
227 |
228 |
229 | if config.use_r3m_reward:
230 | import r3mreward as r3mreward
231 | from r3m import load_r3m
232 | sentences = [t.replace("_", " ") for t in lang_instructions]
233 |
234 | model = load_r3m('resnet50').module.eval().to(config.vidlang_model_device)
235 | model = r3mreward.R3MReward(model, sentences, config.standardize_rewards,
236 | config.queue_size, config.update_stats_steps, config.num_top_images, config.use_lang_embeddings)
237 | if config.use_lang_embeddings:
238 | lang_encodings = model.get_lang_encoding(sentences).cpu().numpy()
239 | lang_to_encoding = dict(zip(lang_instructions, lang_encodings))
240 |
241 | def get_r3m_reward(data, step=0):
242 | with torch.no_grad():
243 | if config.camera_keys == "image_front|image_wrist" or config.camera_keys == "image_overhead|image_wrist":
244 | init_image = np.split(data['init_image'], 2, axis=-2)[0]
245 | image = np.split(data['image'], 2, axis=-2)[0]
246 | else:
247 | init_image = data['init_image']
248 | image = data['image']
249 | init_image = torch.from_numpy(init_image).to(config.vidlang_model_device)
250 | image = torch.from_numpy(image).to(config.vidlang_model_device)
251 | lang_num = torch.from_numpy(data['lang_num']).to(config.vidlang_model_device).unsqueeze(-1)
252 | if config.use_lang_embeddings:
253 | lang_embedding = torch.from_numpy(data['lang_embedding']).to(config.vidlang_model_device)
254 | else:
255 | lang_embedding = None
256 | init_image = init_image.permute((0, 3, 1, 2))
257 | image = image.permute((0, 3, 1, 2))
258 | reward, _, _ = model.get_reward(init_image, image, lang_num, lang_embedding, step)
259 | return reward.squeeze(-1).cpu().numpy()
260 |
261 | train_replay.reward_relabel(get_r3m_reward)
262 | elif config.use_internvideo_reward:
263 | sys.path.append('InternVideo/Pretrain/Multi-Modalities-Pretraining')
264 | import InternVideo
265 | from InternVideo import video_transform
266 | from torchvision import transforms
267 | print('Loading InternVideo model from path: {}...'.format(config.internvideo_load_dir))
268 |     model = InternVideo.load_model(config.internvideo_load_dir).to(config.vidlang_model_device)
269 | upsample = torch.nn.Upsample(size=(224,224), mode='bilinear', align_corners=False)
270 | input_mean = [0.48145466, 0.4578275, 0.40821073]
271 | input_std = [0.26862954, 0.26130258, 0.27577711]
272 | trans = transforms.Compose([
273 | video_transform.ClipToTensor(channel_nb=3),
274 | video_transform.Normalize(mean=input_mean, std=input_std)
275 | ])
276 | def get_internvideo_reward(data, step=0):
277 | with torch.no_grad():
278 | if config.camera_keys == "image_front|image_wrist" or config.camera_keys == "image_overhead|image_wrist":
279 | image = np.split(data['image'], 2, axis=-2)[0]
280 | else:
281 | image = data['image']
282 | videos = trans(image).to(config.vidlang_model_device)
283 | videos = upsample(videos)
284 | text_cand = [lang_instructions[lang_num] for lang_num in data["lang_num"]]
285 |             text = InternVideo.tokenize(text_cand).to(config.vidlang_model_device)
286 | text_features = model.encode_text(text)
287 |
288 | reward = []
289 | for i in range(videos.shape[1]):
290 | if i < 8:
291 | video = videos[:, :i+1, :, :]
292 | video = torch.cat([video, video[:, -1:, :, :].repeat(1, 8-(i+1), 1, 1)], dim=1).unsqueeze(0)
293 | else:
294 | indices = np.ceil(np.linspace(0, i, 8)).astype(int)
295 | video = videos[:, indices, :, :].unsqueeze(0)
296 | video_features = model.encode_video(video)
297 | video_features = torch.nn.functional.normalize(video_features, dim=1)
298 | text_features = torch.nn.functional.normalize(text_features, dim=1)
299 |                 # cosine similarity between the running clip and its instruction
300 | reward.append((video_features @ text_features[i]).cpu().numpy())
301 |
302 | return np.array(reward).squeeze(-1)
303 |
304 | train_replay.reward_relabel(get_internvideo_reward)
305 | elif config.use_clip_reward:
306 | import clip
307 | sys.path.append('InternVideo/Pretrain/Multi-Modalities-Pretraining')
308 | import InternVideo
309 | from InternVideo import video_transform
310 | from torchvision import transforms
311 | model, _ = clip.load('ViT-B/32', config.vidlang_model_device)
312 | upsample = torch.nn.Upsample(size=(224,224), mode='bilinear', align_corners=False)
313 | input_mean = [0.48145466, 0.4578275, 0.40821073]
314 | input_std = [0.26862954, 0.26130258, 0.27577711]
315 | trans = transforms.Compose([
316 | video_transform.ClipToTensor(channel_nb=3),
317 | video_transform.Normalize(mean=input_mean, std=input_std)
318 | ])
319 | def get_clip_reward(data, step=0):
320 | with torch.no_grad():
321 | if config.camera_keys == "image_front|image_wrist" or config.camera_keys == "image_overhead|image_wrist":
322 | image = np.split(data['image'], 2, axis=-2)[0]
323 | init_image = np.split(data['init_image'], 2, axis=-2)[0]
324 | else:
325 | image = data['image']
326 | init_image = data['init_image']
327 | init_videos = trans(init_image).to(config.vidlang_model_device)
328 | init_videos = upsample(init_videos)
329 | videos = trans(image).to(config.vidlang_model_device)
330 | videos = upsample(videos)
331 | text_inputs = torch.cat([clip.tokenize(lang_instructions[lang_num]) for lang_num in data["lang_num"]]).to(config.vidlang_model_device)
332 | init_image_inputs = init_videos.permute((1, 0, 2, 3))
333 | image_inputs = videos.permute((1, 0, 2, 3))
334 | init_image_features = model.encode_image(init_image_inputs)
335 | image_features = model.encode_image(image_inputs)
336 | init_text_features = model.encode_text(text_inputs)
337 | text_features = model.encode_text(text_inputs)  # NOTE: encodes the same tokens as init_text_features, so delta_text_features below is identically zero unless a distinct baseline prompt is substituted
338 | init_image_features /= init_image_features.norm(dim=-1, keepdim=True)
339 | image_features /= image_features.norm(dim=-1, keepdim=True)
340 | init_text_features /= init_text_features.norm(dim=-1, keepdim=True)
341 | text_features /= text_features.norm(dim=-1, keepdim=True)
342 | delta_image_features = image_features - init_image_features
343 | delta_text_features = text_features - init_text_features
344 | reward = (delta_image_features * delta_text_features).sum(dim=-1).cpu().numpy()
345 | return reward
346 | train_replay.reward_relabel(get_clip_reward)
347 | else:
348 | print("Training from task reward...")
349 |
350 | if not config.use_lang_embeddings:
351 | lang_to_encoding = None
352 | def make_env(mode, actions_min_max=None):
353 | camera_keys = common.get_camera_keys(config.camera_keys)
354 | task = config.task.split(",")
355 |
356 | if config.franka_kitchen:
357 | env = common.KitchenEnv(task)
358 | elif config.multi_task_vidlang:
359 | env = common.MultiTaskVidLangRLBench(task,
360 | camera_keys,
361 | config.render_size,
362 | shaped_rewards=config.shaped_rewards,
363 | lang_to_num=lang_to_num,
364 | lang_to_encoding=lang_to_encoding,
365 | use_lang_embeddings=config.use_lang_embeddings,
366 | boundary_reward_penalty=config.boundary_reward_penalty,
367 | randomize=config.randomize,
368 | )
369 | elif config.use_r3m_reward or config.use_internvideo_reward or config.use_clip_reward:
370 | env = common.VidLangRLBench(task[0],
371 | lang_instructions,
372 | camera_keys,
373 | config.render_size,
374 | shaped_rewards=config.shaped_rewards,
375 | lang_to_num=lang_to_num,
376 | lang_to_encoding=lang_to_encoding,
377 | use_lang_embeddings=config.use_lang_embeddings,
378 | boundary_reward_penalty=config.boundary_reward_penalty,
379 | curriculum=config.curriculum,
380 | lang_curriculum=lang_curriculum,
381 | synonym_dict=synonym_dict,
382 | randomize=config.randomize
383 | )
384 | else:
385 | env = common.RLBench(
386 | lang_instructions,
387 | task[0],
388 | camera_keys,
389 | config.render_size,
390 | shaped_rewards=config.shaped_rewards,
391 | use_lang_embeddings=config.use_lang_embeddings,
392 | randomize=config.randomize,
393 | finetune_lang_encoding=finetune_lang_encoding
394 | )
395 | if actions_min_max:
396 | env.register_min_max(actions_min_max)
397 |
398 | env = common.TimeLimit(env, config.time_limit)
399 | return env
400 |
401 | def per_episode(ep, mode, lang_instructions=None):  # lang_instructions is accepted from the eval hook but unused here
402 | length = len(ep["reward"]) - 1
403 | score = float(ep["reward"].astype(np.float64).sum())
404 | success = float(np.sum(ep["success"]) >= 1.0)
405 | print(
406 | f"{mode.title()} episode has {float(success)} success, {length} steps and return {score:.1f}."
407 | )
408 | logger.scalar(f"{mode}_success", float(success))
409 | logger.scalar(f"{mode}_return", score)
410 | logger.scalar(f"{mode}_length", length)
411 | for key, value in ep.items():
412 | if re.match(config.log_keys_sum, key):
413 | logger.scalar(f"sum_{mode}_{key}", ep[key].sum())
414 | if re.match(config.log_keys_mean, key):
415 | logger.scalar(f"mean_{mode}_{key}", ep[key].mean())
416 | if re.match(config.log_keys_max, key):
417 | logger.scalar(f"max_{mode}_{key}", ep[key].max(0).mean())
418 | replay = dict(train=train_replay, eval=eval_replay)[mode]
419 | logger.add(replay.stats, prefix=mode)
420 | logger.write()
421 |
422 | print("Create envs.")
423 | num_eval_envs = min(config.envs, config.eval_eps)
424 | train_envs = [make_env("train") for _ in range(config.envs)]
425 |
426 |
427 | actions_min_max = None
428 |
429 | act_space = train_envs[0].act_space
430 | obs_space = train_envs[0].obs_space
431 |
432 | import agent
433 |
434 | agnt = agent.Agent(config, obs_space, act_space, step)
435 | eval_policy = lambda *args: agnt.policy(*args, mode="eval")
436 |
437 | if config.tune_instruction:
438 | print("Tuning instruction...")
439 | def tune(ep, candidate_instructions_encodings=candidate_instructions_encodings, candidate_instructions_scores=candidate_instructions_scores):
440 | global score_idx
441 | candidate_instructions_scores[score_idx] = ep["reward"].sum()
442 | score_idx += 1
443 | if score_idx < len(candidate_instructions_encodings): train_envs[0].change_lang_encoding(candidate_instructions_encodings[score_idx])  # each episode should score its own candidate, so advance the env's instruction here
444 | tune_driver = common.Driver(train_envs)
445 | tune_driver.on_episode(tune)
446 | train_envs[0].change_lang_encoding(candidate_instructions_encodings[0])
447 | tune_driver(eval_policy, episodes=config.num_tune_instructions)
448 | max_score_idx = np.argmax(np.array(candidate_instructions_scores))
449 | print("Best instruction: {} with reward {}".format(candidate_instructions[max_score_idx], candidate_instructions_scores[max_score_idx]))
450 | train_envs[0].change_lang_encoding(candidate_instructions_encodings[max_score_idx])
451 | finetune_lang_encoding = candidate_instructions_encodings[max_score_idx]
452 |
453 | make_async_env = lambda mode: common.Async(
454 | functools.partial(make_env, mode, actions_min_max), config.envs_parallel
455 | )
456 | eval_envs = [make_async_env("eval") for _ in range(num_eval_envs)]
457 |
458 | print("Creating train and eval drivers.")
459 | train_driver = common.Driver(train_envs)
460 | train_driver.on_episode(lambda ep: per_episode(ep, mode="train"))
461 | train_driver.on_step(lambda tran, worker: step.increment())
462 | train_driver.on_episode(train_replay.add_episode)
463 | eval_driver = common.Driver(eval_envs)
464 | eval_driver.on_episode(lambda ep: per_episode(ep, mode="eval", lang_instructions=lang_instructions))
465 | eval_driver.on_episode(eval_replay.add_episode)
466 |
467 | prefill = max(0, config.prefill - train_replay.stats["total_steps"])
468 | if prefill:
469 | print(f"Prefill dataset ({prefill} steps).")
470 | random_agent = common.RandomAgent(act_space)
471 | train_driver(random_agent, steps=prefill, episodes=1)
472 | eval_driver(random_agent, episodes=1)
473 | train_driver.reset()
474 | eval_driver.reset()
475 |
476 | print("Create datasets.")
477 | train_dataset = iter(train_replay.dataset(**config.dataset))
478 | mae_train_dataset = iter(train_replay.dataset(**config.mae_dataset))
479 | report_dataset = iter(train_replay.dataset(**config.dataset))
480 |
481 | if not config.use_imagenet_mae:
482 | train_mae = agnt.train_mae
483 | train_agent = common.CarryOverState(agnt.train)
484 |
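# Run one throwaway update per train function so that all network variables
# are created before the checkpoint restore below; loaded weights then
# overwrite this initialization.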
485 | if not config.use_imagenet_mae:
486 | train_mae(next(mae_train_dataset))
487 | train_agent(next(train_dataset))
488 |
489 | if (loaddir / f"variables_{config.ts}.pkl").exists():
490 | print("Loading agent.")
491 | try:
492 | agnt.load(loaddir / f"variables_{config.ts}.pkl")
493 | except Exception as e:
494 | raise RuntimeError(f"Error loading agent: {e}") from e
495 | else:
496 | assert config.loaddir == '', "a loaddir was specified but no checkpoint was found there"
497 |
498 | print("Pretrain agent.")
499 | for _ in range(config.mae_pretrain):
500 | data = next(mae_train_dataset)
501 | if config.use_zero_rewards:
502 | data['reward'] = data['reward'] * tf.cast(data['is_terminal'], tf.float32)  # zero out every non-terminal reward
503 | train_mae(data)
504 | for _ in range(config.pretrain):
505 | data = next(train_dataset)
506 | if config.use_zero_rewards:
507 | data['reward'] = data['reward'] * tf.cast(data['is_terminal'], tf.float32)
508 | train_agent(data)
509 |
510 | train_policy = lambda *args: agnt.policy(*args, mode="train")
511 |
512 | def train_step(tran, worker):
513 | if not config.use_imagenet_mae:
514 | if should_train_mae(step):
515 | for _ in range(config.train_mae_steps):
516 | data = next(mae_train_dataset)
517 | if config.use_zero_rewards:
518 | data['reward'] = data['reward'] * tf.cast(data['is_terminal'], tf.float32)
519 | mets = train_mae(data)
520 | [metrics[key].append(value) for key, value in mets.items()]
521 | if should_train(step):
522 | for _ in range(config.train_steps):
523 | data = next(train_dataset)
524 | if config.use_zero_rewards:
525 | data['reward'] = data['reward'] * tf.cast(data['is_terminal'], tf.float32)
526 | mets = train_agent(data)
527 | [metrics[key].append(value) for key, value in mets.items()]
528 | if should_log(step):
529 | for name, values in metrics.items():
530 | logger.scalar(name, np.array(values, np.float64).mean())
531 | metrics[name].clear()
532 | logger.add(
533 | agnt.report(next(report_dataset)),
534 | prefix="train",
535 | )
536 | logger.write(fps=True)
537 |
538 | train_driver.on_step(train_step)
539 |
540 | config.save(logdir / "config_updated.yaml")
541 | while step < config.steps:
542 | logger.write()
543 | print("Start evaluation.")
544 | eval_driver(eval_policy, episodes=config.eval_eps)
545 | print("Start training.")
546 | train_driver(train_policy, steps=config.eval_every)
547 | agnt.save(logdir / f"variables_{step.value}.pkl")
548 | for env in train_envs + eval_envs:
549 | try:
550 | env.close()
551 | except Exception:
552 | pass
553 | agnt.save(logdir / "variables_final.pkl")
554 |
555 |
556 | if __name__ == "__main__":
557 | main()
558 |
--------------------------------------------------------------------------------
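As a minimal, self-contained sketch of the directional CLIP relabeling above
(random features stand in for the CLIP encoders; shapes are illustrative
assumptions, not part of the codebase):

import torch

def directional_clip_reward(image_feat, init_image_feat, text_feat, init_text_feat):
    # Normalize every embedding onto the unit sphere, as get_clip_reward does.
    norm = lambda x: x / x.norm(dim=-1, keepdim=True)
    image_feat, init_image_feat = norm(image_feat), norm(init_image_feat)
    text_feat, init_text_feat = norm(text_feat), norm(init_text_feat)
    # Reward is the inner product of the image-space and text-space deltas.
    return ((image_feat - init_image_feat) * (text_feat - init_text_feat)).sum(-1)

# Toy usage: random 512-d features for a length-10 episode.
T, D = 10, 512
r = directional_clip_reward(torch.randn(T, D), torch.randn(T, D),
                            torch.randn(T, D), torch.randn(T, D))
print(r.shape)  # torch.Size([10])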
/common/envs.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import os
3 | import copy
4 | import random
5 | import sys
6 | import threading
7 | import traceback
8 |
9 | from pyrep.const import TextureMappingMode
10 | from pyrep.const import RenderMode
11 |
12 | import cloudpickle
13 | from functools import partial
14 | import gym
15 | import numpy as np
16 | from rlbench.utils import name_to_task_class
17 |
18 | from rlbench import RandomizeEvery
19 | from rlbench import VisualRandomizationConfig
20 |
21 | import time
22 |
23 | try:
24 | from pyrep.errors import ConfigurationPathError, IKError
25 | from rlbench.backend.exceptions import InvalidActionError
26 | except ImportError:  # allow importing this module without a simulator installed
27 | pass
28 |
29 |
30 | class RLBench:
31 | def __init__(
32 | self,
33 | langs,
34 | name,
35 | camera_keys,
36 | size=(64, 64),
37 | actions_min_max=None,
38 | shaped_rewards=False,
39 | use_lang_embeddings=False,
40 | boundary_reward_penalty=False,
41 | randomize=False,
42 | finetune_lang_encoding=None,
43 | ):
44 | from rlbench.action_modes.action_mode import MoveArmThenGripper
45 | from rlbench.action_modes.arm_action_modes import (
46 | EndEffectorPoseViaPlanning,
47 | )
48 | from rlbench.action_modes.gripper_action_modes import (
49 | Discrete,
50 | )
51 | from rlbench.environment import Environment
52 | from rlbench.observation_config import ObservationConfig
53 | from rlbench.tasks import (
54 | PhoneOnBase,
55 | PickAndLift,
56 | PickUpCup,
57 | PutRubbishInBin,
58 | TakeLidOffSaucepan,
59 | TakeUmbrellaOutOfUmbrellaStand,
60 | MultiTaskMicrofridgesauce,
61 | # MultiTaskBusplanesauce,  # needed by the reach_for_bus / reach_for_plane branch below; uncomment if the task class is available
62 | PickShapenetObjects
63 | )
64 |
65 | # Configure which observations RLBench returns.
66 | obs_config = ObservationConfig()
67 |
68 | ## Camera setups
69 | obs_config.front_camera.set_all(False)
70 | obs_config.wrist_camera.set_all(False)
71 | obs_config.left_shoulder_camera.set_all(False)
72 | obs_config.right_shoulder_camera.set_all(False)
73 | obs_config.overhead_camera.set_all(False)
74 |
75 | if "image_front" in camera_keys:
76 | obs_config.front_camera.rgb = True
77 | obs_config.front_camera.image_size = size
78 | obs_config.front_camera.render_mode = RenderMode.OPENGL
79 |
80 | if "image_wrist" in camera_keys:
81 | obs_config.wrist_camera.rgb = True
82 | obs_config.wrist_camera.image_size = size
83 | obs_config.wrist_camera.render_mode = RenderMode.OPENGL
84 |
85 | if "image_overhead" in camera_keys:
86 | obs_config.overhead_camera.rgb = True
87 | obs_config.overhead_camera.image_size = size
88 | obs_config.overhead_camera.render_mode = RenderMode.OPENGL
89 |
90 | obs_config.joint_forces = False
91 | obs_config.joint_positions = True
92 | obs_config.joint_velocities = True
93 | obs_config.task_low_dim_state = True
94 | obs_config.gripper_touch_forces = False
95 | obs_config.gripper_pose = True
96 | obs_config.gripper_open = True
97 | obs_config.gripper_matrix = False
98 | obs_config.gripper_joint_positions = True
99 |
100 | if randomize:
101 | rand_config = [
102 | VisualRandomizationConfig(image_directory='common/assets/textures/table', whitelist = ['diningTable_visible']),
103 | VisualRandomizationConfig(image_directory='common/assets/textures/wall', whitelist = ['Wall1', 'Wall2', 'Wall3', 'Wall4']),
104 | VisualRandomizationConfig(image_directory='common/assets/textures/floor', whitelist = ['Floor'])
105 | ]
106 | tex_kwargs = [
107 | {'mapping_mode': TextureMappingMode.PLANE, 'repeat_along_u': False, 'repeat_along_v': False, 'uv_scaling': [1.6, 1.1]},
108 | {'mapping_mode': TextureMappingMode.PLANE, 'repeat_along_u': False, 'repeat_along_v': False, 'uv_scaling': [5.0, 3.0]},
109 | {'mapping_mode': TextureMappingMode.PLANE, 'repeat_along_u': False, 'repeat_along_v': False, 'uv_scaling': [5.0, 5.0]}
110 | ]
111 | randomized_every = RandomizeEvery.EPISODE
112 | else:
113 | rand_config = None
114 | randomized_every = None
115 | tex_kwargs = None
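# With randomize=True, RLBench re-textures the table, walls, and floor from
# the image directories above at the start of every episode
# (RandomizeEvery.EPISODE); the uv_scaling values stretch each texture to
# roughly match the surface it is mapped onto.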
116 |
117 | env = Environment(
118 | action_mode=MoveArmThenGripper(
119 | arm_action_mode=EndEffectorPoseViaPlanning(False),
120 | gripper_action_mode=Discrete(),
121 | ),
122 | obs_config=obs_config,
123 | headless=True,
124 | shaped_rewards=shaped_rewards,
125 | randomize_every=randomized_every,
126 | visual_randomization_config=rand_config,
127 | tex_kwargs=tex_kwargs
128 | )
129 | env.launch()
130 |
131 | if name == "phone_on_base":
132 | task = PhoneOnBase
133 | elif name == "pick_and_lift":
134 | task = PickAndLift
135 | elif name == "pick_up_cup":
136 | task = PickUpCup
137 | elif name == "put_rubbish_in_bin":
138 | task = PutRubbishInBin
139 | elif name == "take_lid_off_saucepan":
140 | task = TakeLidOffSaucepan
141 | elif name == "take_umbrella_out_of_umbrella_stand":
142 | task = TakeUmbrellaOutOfUmbrellaStand
143 | elif name == "multi_task_microfridgesauce":
144 | task = MultiTaskMicrofridgesauce
145 | elif name == "pick_shapenet_objects":
146 | task = PickShapenetObjects
147 | elif name in ["reach_for_bus", "reach_for_plane"]:
148 | task = MultiTaskBusplanesauce
149 | else:
150 | task = name_to_task_class(name)
151 | self._env = env
152 | self._task = env.get_task(task)
153 | self.task_name = name
154 |
155 | if "pick_shapenet_objects" in name:
156 | try:
157 | n_obj = int(name.split("_")[-1])
158 | self._task._task.set_num_objects(n_obj)
159 | except ValueError:  # no object-count suffix in the task name; default to one object
160 | self._task._task.set_num_objects(1)
161 |
162 | _, obs = self._task.reset()
163 |
164 | task_low_dim = obs.task_low_dim_state.shape[0]
165 | self._state_dim = obs.get_low_dim_data().shape[0] - 14 - task_low_dim  # drop the 7 joint positions and 7 joint velocities plus the task-specific state
166 | self._prev_obs, self._prev_reward = None, None
167 | self._ep_success = None
168 |
169 | self._size = size
170 | self._shaped_rewards = shaped_rewards
171 | self._camera_keys = camera_keys
172 | self._use_lang_embeddings = use_lang_embeddings
173 | self.finetune_lang_encoding = finetune_lang_encoding
174 | self._boundary_reward_penalty = boundary_reward_penalty
175 | self.langs = langs
176 |
177 | if actions_min_max:
178 | self.register_min_max(actions_min_max)
179 | else:
180 | self.low = np.array([-0.03, -0.03, -0.03])
181 | self.high = np.array([0.03, 0.03, 0.03])
182 |
183 |
184 | self._name = name
185 |
186 | @property
187 | def container_pos(self):
188 | if "pick_shapenet_objects" in self._name:
189 | return self._task._task.large_container.get_position()
190 | return None
191 |
192 | @property
193 | def obs_space(self):
194 | spaces = {
195 | "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
196 | "is_first": gym.spaces.Box(0, 1, (), dtype=bool),
197 | "is_last": gym.spaces.Box(0, 1, (), dtype=bool),
198 | "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
199 | "success": gym.spaces.Box(0, 1, (), dtype=bool),
200 | "state": gym.spaces.Box(
201 | -np.inf, np.inf, (self._state_dim,), dtype=np.float32
202 | ),
203 | "image": gym.spaces.Box(
204 | 0,
205 | 255,
206 | (self._size[0], self._size[1] * len(self._camera_keys), 3),
207 | dtype=np.uint8,
208 | ),
209 | "init_state": gym.spaces.Box(
210 | -np.inf, np.inf, (self._state_dim,), dtype=np.float32
211 | ),
212 | "init_image": gym.spaces.Box(
213 | 0,
214 | 255,
215 | (self._size[0], self._size[1] * len(self._camera_keys), 3),
216 | dtype=np.uint8,
217 | ),
218 | "lang_num": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.uint8),
219 | }
220 | if self._use_lang_embeddings:
221 | spaces["lang_embedding"] = gym.spaces.Box(
222 | -np.inf, np.inf, (768,), dtype=np.float32)
223 | return spaces
224 |
225 | def register_min_max(self, actions_min_max):
226 | self.low, self.high = actions_min_max
227 |
228 | @property
229 | def act_space(self):
230 | assert self.low is not None
231 | if self.low.shape[0] == 3:  # first query: append the gripper dimension to the 3-D position bounds (mutates self.low/high once)
232 | self.low = np.hstack([self.low, [0.0]])
233 | self.high = np.hstack([self.high, [1.0]])
234 | action = gym.spaces.Box(
235 | low=self.low, high=self.high, shape=(self.low.shape[0],), dtype=np.float32
236 | )
237 | return {"action": action}
238 |
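# The agent acts in [-1, 1]^4: an (x, y, z) end-effector delta plus a gripper
# command. unnormalize() rescales these into [low, high] and expands them to
# the 8-D RLBench action [dx, dy, dz, qx, qy, qz, qw, gripper], fixing the
# identity quaternion so orientation is never commanded.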
239 | def unnormalize(self, a):
240 | # Un-normalize gripper pose normalized to [-1, 1]
241 | assert self.low is not None
242 | pose = a[:3]
243 | pose = (pose + 1) / 2 * (self.high[:3] - self.low[:3]) + self.low[:3]
244 |
245 | # Clamp the z delta so the gripper tip never rises above its initial height
246 | curr_pose = self._task._task.robot.arm.get_tip().get_pose()[:3]
247 | curr_z = curr_pose[2]
248 | init_z = self._init_pose[2]
249 | delta_z = pose[2]
250 |
251 | if curr_z + delta_z >= init_z:
252 | pose[2] = 0.0
253 |
254 | # Un-normalize gripper action normalized to [-1, 1]
255 | gripper = a[3:4]
256 | gripper = (gripper + 1) / 2 * (self.high[3:4] - self.low[3:4]) + self.low[3:4]
257 |
258 | target_pose = pose
259 |
260 | # Identity quaternion
261 | quat = np.array([0.0, 0.0, 0.0, 1.0])
262 |
263 | action = np.hstack([target_pose, quat, gripper])
264 | assert action.shape[0] == 8
265 | return action
266 |
267 | def step(self, action):
268 | assert np.isfinite(action["action"]).all(), action["action"]
269 | try:
270 | original_action = self.unnormalize(action["action"])
271 | _obs, _reward, _ = self._task.step(original_action)
272 | terminal = False
273 | success, _ = self._task._task.success()
274 | if success:
275 | self._ep_success = True
276 | self._prev_obs, self._prev_reward = _obs, _reward
277 | if not self._shaped_rewards:
278 | reward = float(self._ep_success)
279 | else:
280 | reward = _reward
281 | except ConfigurationPathError:
282 | print("ConfigurationPathError")
283 | _obs = self._prev_obs
284 | terminal = False
285 | success = False
286 | if not self._shaped_rewards:
287 | reward = float(self._ep_success)
288 | else:
289 | reward = self._prev_reward
290 | except (IKError, InvalidActionError) as e:
291 | # print(e)
292 | _obs = self._prev_obs
293 | success = False
294 | if self._boundary_reward_penalty:
295 | terminal = True
296 | reward = -0.05
297 | else:
298 | terminal = False
299 | if not self._shaped_rewards:
300 | reward = float(self._ep_success)
301 | else:
302 | reward = self._prev_reward
303 |
304 | _obs.joint_velocities = None
305 | _obs.joint_positions = None
306 | _obs.task_low_dim_state = None
307 |
308 | obs = {
309 | "reward": reward,
310 | "is_first": False,
311 | "is_last": terminal,
312 | "is_terminal": terminal,
313 | "success": success,
314 | "state": _obs.get_low_dim_data(),
315 | 'lang_num': 0
316 | }
317 | images = []
318 | for key in self._camera_keys:
319 | if key == "image_front":
320 | images.append(_obs.front_rgb)
321 | if key == "image_wrist":
322 | images.append(_obs.wrist_rgb)
323 | if key == "image_overhead":
324 | images.append(_obs.overhead_rgb)
325 | obs["image"] = np.concatenate(images, axis=-2)
326 | if self._use_lang_embeddings and self.finetune_lang_encoding is not None:
327 | obs['lang_embedding'] = self.finetune_lang_encoding
328 | self._time_step += 1
329 | return obs
330 |
331 | def reset(self):
332 | self.lang = random.choice(self.langs)
333 | self._task._task.change_reward(self.lang)
334 | _, _obs = self._task.reset()
335 | print(f"Reset in env {self.task_name}.")
336 | self._prev_obs = _obs
337 | self._init_pose = copy.deepcopy(
338 | self._task._task.robot.arm.get_tip().get_pose()[:3]
339 | )
340 | self._time_step = 0
341 | self._ep_success = False
342 |
343 | _obs.joint_velocities = None
344 | _obs.joint_positions = None
345 | _obs.task_low_dim_state = None
346 |
347 | obs = {
348 | "reward": 0.0,
349 | "is_first": True,
350 | "is_last": False,
351 | "is_terminal": False,
352 | "success": False,
353 | "state": _obs.get_low_dim_data(),
354 | 'lang_num': 0
355 | }
356 | images = []
357 | for key in self._camera_keys:
358 | if key == "image_front":
359 | images.append(_obs.front_rgb)
360 | if key == "image_wrist":
361 | images.append(_obs.wrist_rgb)
362 | if key == "image_overhead":
363 | images.append(_obs.overhead_rgb)
364 | obs["image"] = np.concatenate(images, axis=-2)
365 | if self._use_lang_embeddings and self.finetune_lang_encoding is not None:
366 | obs['lang_embedding'] = self.finetune_lang_encoding
367 | return obs
368 |
369 | def change_lang_encoding(self, lang_encoding):
370 | self.finetune_lang_encoding = lang_encoding
371 |
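# VidLangRLBench specializes the base env for video-language rewards: it
# samples a language instruction each episode (optionally on a curriculum),
# substitutes an object synonym for the [NOUN] placeholder, and attaches the
# episode's first image/state plus an instruction id to every timestep so
# stored episodes can be reward-relabeled later.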
372 | class VidLangRLBench(RLBench):
373 | def __init__(self, name, langs, camera_keys, size=(64, 64), actions_min_max=None,
374 | shaped_rewards=False, lang_to_num=None, lang_to_encoding=None,
375 | use_lang_embeddings=False, boundary_reward_penalty=False,
376 | curriculum=None, lang_curriculum=None, synonym_dict=None,
377 | randomize=False):
378 | self.lang_to_num = lang_to_num
379 | self.lang_to_encoding = lang_to_encoding
380 | self.langs = langs
381 | self.init_langs = copy.deepcopy(langs)
382 | self.episode_count = 0  # incremented once per reset(); drives the curriculum stage changes below
383 | if curriculum is not None:
384 | self.use_curriculum = curriculum.use
385 | self.curriculum = curriculum
386 | n_episodes = [int(x) for x in curriculum.num_episodes.split("|")]
387 | # exclusive prefix sum: curriculum_benchmarks[i] counts the episodes before stage i begins
388 | self.curriculum_benchmarks = [0 for _ in range(len(n_episodes))]
389 | for i in range(len(n_episodes)):
390 | for j in range(i):
391 | # the final stage's episode count is effectively unused: the env keeps
392 | # running until config.steps is reached, whatever the last entry says
393 | self.curriculum_benchmarks[i] += n_episodes[j]
394 | self.lang_prompts = lang_curriculum
395 | self.synonym_dict = synonym_dict
396 |
397 | self.objects = curriculum.objects.split("|")
398 | for i in range(len(self.objects)):
399 | self.objects[i] = self.objects[i].split(",")
400 | self.num_objects = curriculum.num_objects.split("|")
401 | self.num_objects = [int(x) for x in self.num_objects]
402 | self.num_unique_per_class = curriculum.num_unique_per_class.split("|")
403 | self.num_unique_per_class = [int(x) for x in self.num_unique_per_class]
404 |
405 | super(VidLangRLBench, self).__init__(langs, name, camera_keys, size, actions_min_max, shaped_rewards, use_lang_embeddings, boundary_reward_penalty, randomize)
406 |
407 | def reset_curriculum(self):
408 | # TODO: this currently only works for the ShapeNet env.
409 | # find index of self.episode_count in the list self.curriculum_benchmarks, if it exists
410 | if self.episode_count in self.curriculum_benchmarks:
411 | index = self.curriculum_benchmarks.index(self.episode_count)
412 | self.init_langs, self.langs = self.lang_prompts[index].copy(), self.lang_prompts[index].copy()
413 | self._task._task.reset_samplers(self.objects[index], self.objects[index], self.num_unique_per_class[index], self.num_unique_per_class[index])
414 | self._task._task.set_num_objects(self.num_objects[index])
415 | print(f"[curriculum] lang prompts: {self.langs}")
416 | print(f"[curriculum] objects: {self.num_objects[index]} of {self.objects[index]}")
417 |
418 | def reset(self):
419 | if self.use_curriculum:
420 | self.reset_curriculum()
421 | if self.init_langs:
422 | self.lang = random.choice(self.init_langs)
423 | self.init_langs.remove(self.lang)
424 | else:
425 | self.lang = random.choice(self.langs)
426 |
427 | time_step = super(VidLangRLBench, self).reset()  # NOTE: RLBench.reset() resamples self.lang from self.langs, overriding the choice made above
428 | if "[NOUN]" in self.lang:
429 | curr_obj = self._task._task.bin_objects_meta[0]
430 | if isinstance(self.synonym_dict[curr_obj], list):
431 | synonym = random.choice(self.synonym_dict[curr_obj])
432 | else:
433 | synonym = self.synonym_dict[curr_obj]
434 | self.lang = self.lang.replace("[NOUN]", synonym)
435 | self._task._task.change_reward(self.lang)
436 |
437 | lang_num = self.lang_to_num[self.lang]
438 | print(f"Collecting for {self.lang} language instruction.")
439 | print(f"Reward is for {self._task._task.reward_lang}.")
440 | vidlang_time_step = time_step.copy()
441 | vidlang_time_step['init_image'] = vidlang_time_step['image']
442 | vidlang_time_step['init_state'] = vidlang_time_step['state']
443 | vidlang_time_step['lang_num'] = lang_num
444 | if self._use_lang_embeddings:
445 | vidlang_time_step['lang_embedding'] = self.lang_to_encoding[self.lang]
446 | self._init_vidlang_time_step = vidlang_time_step
447 | self.episode_count += 1
448 | return vidlang_time_step
449 |
450 | def step(self, action):
451 | time_step = super(VidLangRLBench, self).step(action)
452 | lang_num = self.lang_to_num[self.lang]
453 | vidlang_time_step = time_step.copy()
454 | vidlang_time_step['init_image'] = self._init_vidlang_time_step['init_image']
455 | vidlang_time_step['init_state'] = self._init_vidlang_time_step['init_state']
456 | vidlang_time_step['lang_num'] = lang_num
457 | if self._use_lang_embeddings:
458 | vidlang_time_step['lang_embedding'] = self.lang_to_encoding[self.lang]
459 | return vidlang_time_step
460 |
461 | class MultiTaskRLBench(RLBench):
462 | def __init__(self, task_names, camera_keys, size=(64, 64), actions_min_max=None, shaped_rewards=False, lang_to_num=None, lang_to_encoding=None, use_lang_embeddings=False, boundary_reward_penalty=False, randomize=False):
463 | self.task_names = task_names
464 | self.init_tasks = self.task_names.copy()
465 | self.task_name = random.choice(self.init_tasks)
466 |
467 | super(MultiTaskRLBench, self).__init__(self.task_names, self.task_name, camera_keys, size, actions_min_max, shaped_rewards, use_lang_embeddings, boundary_reward_penalty, randomize)  # RLBench expects (langs, name, ...); pass the task list as the instruction pool so reset() samples a valid key
468 | self.name_to_class = {}
469 | for t in self.task_names:
470 | self.name_to_class[t] = name_to_task_class(t)
471 |
472 | def reset(self):
473 | if self.init_tasks:
474 | self.task_name = random.choice(self.init_tasks)
475 | self.init_tasks.remove(self.task_name)
476 | else:
477 | self.task_name = random.choice(self.task_names)
478 | self._task = self._env.get_task(self.name_to_class[self.task_name])
479 | return super(MultiTaskRLBench, self).reset()
480 |
481 | class MultiTaskVidLangRLBench(MultiTaskRLBench):
482 | def __init__(self, task_names, camera_keys, size=(64, 64), actions_min_max=None, shaped_rewards=False, lang_to_num=None, lang_to_encoding=None, use_lang_embeddings=False, boundary_reward_penalty=False, randomize=False):
483 | self.lang_to_num = lang_to_num
484 | self.lang_to_encoding = lang_to_encoding
485 | super(MultiTaskVidLangRLBench, self).__init__(task_names, camera_keys, size, actions_min_max, shaped_rewards, lang_to_num=lang_to_num, lang_to_encoding=lang_to_encoding, use_lang_embeddings=use_lang_embeddings, boundary_reward_penalty=boundary_reward_penalty, randomize=randomize)
486 |
487 | def reset(self):
488 | time_step = super(MultiTaskVidLangRLBench, self).reset()
489 | lang_num = self.lang_to_num[self.task_name]
490 | vidlang_time_step = time_step.copy()
491 | vidlang_time_step['init_image'] = vidlang_time_step['image']
492 | vidlang_time_step['init_state'] = vidlang_time_step['state']
493 | vidlang_time_step['lang_num'] = lang_num
494 | self._init_vidlang_time_step = vidlang_time_step
495 | if self._use_lang_embeddings:
496 | vidlang_time_step['lang_embedding'] = self.lang_to_encoding[self.task_name]  # key by the sampled task, matching lang_num above
497 | return vidlang_time_step
498 |
499 | def step(self, action):
500 | time_step = super(MultiTaskVidLangRLBench, self).step(action)
501 | lang_num = self.lang_to_num[self.task_name]
502 | vidlang_time_step = time_step.copy()
503 | vidlang_time_step['init_image'] = self._init_vidlang_time_step['init_image']
504 | vidlang_time_step['init_state'] = self._init_vidlang_time_step['init_state']
505 | vidlang_time_step['lang_num'] = lang_num
506 | if self._use_lang_embeddings:
507 | vidlang_time_step['lang_embedding'] = self.lang_to_encoding[self.task_name]
508 | return vidlang_time_step
509 |
510 | class TimeLimit:
511 | def __init__(self, env, duration):
512 | self._env = env
513 | self._duration = duration
514 | self._step = None
515 |
516 | def __getattr__(self, name):
517 | if name.startswith("__"):
518 | raise AttributeError(name)
519 | try:
520 | return getattr(self._env, name)
521 | except AttributeError:
522 | raise ValueError(name)
523 |
524 | def step(self, action):
525 | assert self._step is not None, "Must reset environment."
526 | obs = self._env.step(action)
527 | self._step += 1
528 | if self._duration and self._step >= self._duration:
529 | obs["is_last"] = True
530 | self._step = None
531 | return obs
532 |
533 | def reset(self):
534 | self._step = 0
535 | return self._env.reset()
536 |
537 |
538 | class ResizeImage:
539 | def __init__(self, env, size=(64, 64)):
540 | self._env = env
541 | self._size = size
542 | self._keys = [
543 | k
544 | for k, v in env.obs_space.items()
545 | if len(v.shape) > 1 and v.shape[:2] != size
546 | ]
547 | print(f'Resizing keys {",".join(self._keys)} to {self._size}.')
548 | if self._keys:
549 | from PIL import Image
550 |
551 | self._Image = Image
552 |
553 | def __getattr__(self, name):
554 | if name.startswith("__"):
555 | raise AttributeError(name)
556 | try:
557 | return getattr(self._env, name)
558 | except AttributeError:
559 | raise ValueError(name)
560 |
561 | @property
562 | def obs_space(self):
563 | spaces = self._env.obs_space
564 | for key in self._keys:
565 | shape = self._size + spaces[key].shape[2:]
566 | spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8)
567 | return spaces
568 |
569 | def step(self, action):
570 | obs = self._env.step(action)
571 | for key in self._keys:
572 | obs[key] = self._resize(obs[key])
573 | return obs
574 |
575 | def reset(self):
576 | obs = self._env.reset()
577 | for key in self._keys:
578 | obs[key] = self._resize(obs[key])
579 | return obs
580 |
581 | def _resize(self, image):
582 | image = self._Image.fromarray(image)
583 | image = image.resize(self._size, self._Image.NEAREST)
584 | image = np.array(image)
585 | return image
586 |
587 |
588 | class RenderImage:
589 | def __init__(self, env, key="image"):
590 | self._env = env
591 | self._key = key
592 | self._shape = self._env.render().shape
593 |
594 | def __getattr__(self, name):
595 | if name.startswith("__"):
596 | raise AttributeError(name)
597 | try:
598 | return getattr(self._env, name)
599 | except AttributeError:
600 | raise ValueError(name)
601 |
602 | @property
603 | def obs_space(self):
604 | spaces = self._env.obs_space
605 | spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8)
606 | return spaces
607 |
608 | def step(self, action):
609 | obs = self._env.step(action)
610 | obs[self._key] = self._env.render("rgb_array")
611 | return obs
612 |
613 | def reset(self):
614 | obs = self._env.reset()
615 | obs[self._key] = self._env.render("rgb_array")
616 | return obs
617 |
618 |
619 | class Async:
620 |
621 | # Message types for communication via the pipe.
622 | _ACCESS = 1
623 | _CALL = 2
624 | _RESULT = 3
625 | _CLOSE = 4
626 | _EXCEPTION = 5
627 |
628 | def __init__(self, constructor, strategy="thread"):
629 | self._pickled_ctor = cloudpickle.dumps(constructor)
630 | if strategy == "process":
631 | import multiprocessing as mp
632 |
633 | context = mp.get_context("spawn")
634 | elif strategy == "thread":
635 | import multiprocessing.dummy as context
636 | else:
637 | raise NotImplementedError(strategy)
638 | self._strategy = strategy
639 | self._conn, conn = context.Pipe()
640 | self._process = context.Process(target=self._worker, args=(conn,))
641 | atexit.register(self.close)
642 | self._process.start()
643 | self._receive() # Ready.
644 | self._obs_space = None
645 | self._act_space = None
646 |
647 | def access(self, name):
648 | self._conn.send((self._ACCESS, name))
649 | return self._receive
650 |
651 | def call(self, name, *args, **kwargs):
652 | payload = name, args, kwargs
653 | self._conn.send((self._CALL, payload))
654 | return self._receive
655 |
656 | def close(self):
657 | try:
658 | self._conn.send((self._CLOSE, None))
659 | self._conn.close()
660 | except IOError:
661 | pass # The connection was already closed.
662 | self._process.join(5)
663 |
664 | @property
665 | def obs_space(self):
666 | if not self._obs_space:
667 | self._obs_space = self.access("obs_space")()
668 | return self._obs_space
669 |
670 | @property
671 | def act_space(self):
672 | if not self._act_space:
673 | self._act_space = self.access("act_space")()
674 | return self._act_space
675 |
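# step() and reset() are non-blocking by default: they return a promise (the
# bound _receive method), and calling that promise blocks until the worker
# replies. Pass blocking=True to resolve immediately.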
676 | def step(self, action, blocking=False):
677 | promise = self.call("step", action)
678 | if blocking:
679 | return promise()
680 | else:
681 | return promise
682 |
683 | def reset(self, blocking=False):
684 | promise = self.call("reset")
685 | if blocking:
686 | return promise()
687 | else:
688 | return promise
689 |
690 | def _receive(self):
691 | try:
692 | message, payload = self._conn.recv()
693 | except (OSError, EOFError):
694 | raise RuntimeError("Lost connection to environment worker.")
695 | # Re-raise exceptions in the main process.
696 | if message == self._EXCEPTION:
697 | stacktrace = payload
698 | raise Exception(stacktrace)
699 | if message == self._RESULT:
700 | return payload
701 | raise KeyError("Received message of unexpected type {}".format(message))
702 |
703 | def _worker(self, conn):
704 | try:
705 | ctor = cloudpickle.loads(self._pickled_ctor)
706 | env = ctor()
707 | conn.send((self._RESULT, None)) # Ready.
708 | while True:
709 | try:
710 | # Only block for short times to have keyboard exceptions be raised.
711 | if not conn.poll(0.1):
712 | continue
713 | message, payload = conn.recv()
714 | except (EOFError, KeyboardInterrupt):
715 | break
716 | if message == self._ACCESS:
717 | name = payload
718 | result = getattr(env, name)
719 | conn.send((self._RESULT, result))
720 | continue
721 | if message == self._CALL:
722 | name, args, kwargs = payload
723 | result = getattr(env, name)(*args, **kwargs)
724 | conn.send((self._RESULT, result))
725 | continue
726 | if message == self._CLOSE:
727 | break
728 | raise KeyError("Received message of unknown type {}".format(message))
729 | except Exception:
730 | stacktrace = "".join(traceback.format_exception(*sys.exc_info()))
731 | print("Error in environment process: {}".format(stacktrace))
732 | conn.send((self._EXCEPTION, stacktrace))
733 | finally:
734 | try:
735 | conn.close()
736 | except IOError:
737 | pass # The connection was already closed.
738 |
--------------------------------------------------------------------------------
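A hypothetical usage sketch for the Async and TimeLimit wrappers above
(assumes this module is importable as common.envs; EchoEnv is an invented
stand-in environment, not part of the codebase):

import numpy as np
import gym
from common.envs import Async, TimeLimit

class EchoEnv:
    # Minimal environment satisfying the wrapper interface.
    @property
    def obs_space(self):
        return {"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32)}
    @property
    def act_space(self):
        return {"action": gym.spaces.Box(-1, 1, (3,), dtype=np.float32)}
    def reset(self):
        return {"reward": 0.0, "is_first": True, "is_last": False}
    def step(self, action):
        return {"reward": 1.0, "is_first": False, "is_last": False}

env = Async(lambda: TimeLimit(EchoEnv(), 10), strategy="thread")
obs = env.reset(blocking=True)                            # blocking call
promise = env.step({"action": np.zeros(3, np.float32)})   # returns a promise
print(promise()["reward"])                                # resolve it: 1.0
env.close()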