├── .gitignore
├── prompts
│   ├── single_instruction.txt
│   ├── similar_noun
│   │   ├── synonym_bowl.txt
│   │   ├── synonym_pot.txt
│   │   ├── synonym_laptop.txt
│   │   ├── synonym_knife.txt
│   │   ├── synonym_earphone.txt
│   │   ├── synonym_cap.txt
│   │   ├── synonym_jar.txt
│   │   ├── synonym_faucet.txt
│   │   ├── synonym_telephone.txt
│   │   ├── synonym_mug.txt
│   │   └── synonym_bag.txt
│   ├── phone_on_base.txt
│   ├── mug_40.txt
│   ├── general_tasks_40.txt
│   ├── rand_verb_40.txt
│   ├── hamlet.txt
│   ├── similar_verb_40.txt
│   ├── rand_noun
│   │   ├── synonym_bag.txt
│   │   ├── synonym_cap.txt
│   │   ├── synonym_jar.txt
│   │   ├── synonym_mug.txt
│   │   ├── synonym_pot.txt
│   │   ├── synonym_bowl.txt
│   │   ├── synonym_earphone.txt
│   │   ├── synonym_faucet.txt
│   │   ├── synonym_knife.txt
│   │   ├── synonym_laptop.txt
│   │   └── synonym_telephone.txt
│   ├── human_mug_statement_40.txt
│   ├── jar.txt
│   ├── robot_mug_statement_40.txt
│   └── mug.txt
├── method.png
├── common
│   ├── synonym_remote.txt
│   ├── assets
│   │   └── textures
│   │       ├── floor
│   │       │   ├── carpet.png
│   │       │   ├── forest.png
│   │       │   └── rubber_mats.png
│   │       ├── table
│   │       │   ├── puzzle.png
│   │       │   ├── stove.png
│   │       │   ├── stove2.png
│   │       │   └── stove3.png
│   │       └── wall
│   │           ├── cabinets.png
│   │           ├── fridge.png
│   │           ├── posters.png
│   │           ├── wardrobe.png
│   │           ├── white_wall.png
│   │           ├── kitchen_hanger.png
│   │           ├── kitchen_hanger2.png
│   │           └── kitchen_hanger3.png
│   ├── __init__.py
│   ├── counter.py
│   ├── when.py
│   ├── rnd.py
│   ├── driver.py
│   ├── mae_utils.py
│   ├── flags.py
│   ├── base_envs.py
│   ├── rlbench_utils.py
│   ├── expl.py
│   ├── dists.py
│   ├── logger.py
│   ├── tfutils.py
│   ├── config.py
│   ├── kitchen.py
│   ├── other.py
│   ├── replay.py
│   ├── nets.py
│   └── envs.py
├── LICENSE
├── r3mreward.py
├── README.md
├── env.yml
├── configs.yaml
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | --------------------------------------------------------------------------------
/prompts/single_instruction.txt: -------------------------------------------------------------------------------- 1 | Pick up the [NOUN] --------------------------------------------------------------------------------
/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/method.png --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_bowl.txt: -------------------------------------------------------------------------------- 1 | bowl 2 | dish 3 | soup bowl 4 | wooden bowl 5 | plastic bowl --------------------------------------------------------------------------------
/common/synonym_remote.txt: -------------------------------------------------------------------------------- 1 | remote control 2 | remote 3 | clicker 4 | controller 5 | TV remote 6 | video remote --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_pot.txt: -------------------------------------------------------------------------------- 1 | pot 2 | flowerpot 3 | vase 4 | container 5 | plant pot 6 | flower holder --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_laptop.txt: -------------------------------------------------------------------------------- 1 | laptop 2 | notebook 3 | ultrabook 4 | netbook 5 | tablet 6 | chromebook 7 | computer --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_knife.txt: -------------------------------------------------------------------------------- 1 | knife 2 | blade 3 | cutter 4 | chopper 5 |
cleaver 6 | pocket knife 7 | Swiss army knife -------------------------------------------------------------------------------- /common/assets/textures/floor/carpet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/floor/carpet.png -------------------------------------------------------------------------------- /common/assets/textures/floor/forest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/floor/forest.png -------------------------------------------------------------------------------- /common/assets/textures/table/puzzle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/table/puzzle.png -------------------------------------------------------------------------------- /common/assets/textures/table/stove.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/table/stove.png -------------------------------------------------------------------------------- /common/assets/textures/table/stove2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/table/stove2.png -------------------------------------------------------------------------------- /common/assets/textures/table/stove3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/table/stove3.png -------------------------------------------------------------------------------- /common/assets/textures/wall/cabinets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/cabinets.png -------------------------------------------------------------------------------- /common/assets/textures/wall/fridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/fridge.png -------------------------------------------------------------------------------- /common/assets/textures/wall/posters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/posters.png -------------------------------------------------------------------------------- /common/assets/textures/wall/wardrobe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/wardrobe.png -------------------------------------------------------------------------------- /common/assets/textures/wall/white_wall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/white_wall.png -------------------------------------------------------------------------------- /common/assets/textures/floor/rubber_mats.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/floor/rubber_mats.png --------------------------------------------------------------------------------
/common/assets/textures/wall/kitchen_hanger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/kitchen_hanger.png --------------------------------------------------------------------------------
/common/assets/textures/wall/kitchen_hanger2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/kitchen_hanger2.png --------------------------------------------------------------------------------
/common/assets/textures/wall/kitchen_hanger3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ademiadeniji/lamp/HEAD/common/assets/textures/wall/kitchen_hanger3.png --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_earphone.txt: -------------------------------------------------------------------------------- 1 | earphone 2 | headphone 3 | earbud 4 | in-ear monitor 5 | earpiece 6 | earphone set 7 | earphone accessory --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_cap.txt: -------------------------------------------------------------------------------- 1 | cap 2 | hat 3 | snapback 4 | baseball cap 5 | visor 6 | bucket hat 7 | fedora 8 | cowboy hat 9 | knit cap 10 | nightcap --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_jar.txt: -------------------------------------------------------------------------------- 1 | jar 2 | vase 3 | container 4 | bottle 5 | canister 6 | pot 7 | urn 8 | jug 9 | flask 10 | pitcher 11 | bucket 12 | pail 13 | basket --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_faucet.txt: -------------------------------------------------------------------------------- 1 | faucet 2 | tap 3 | valve 4 | spout 5 | handle 6 | lever 7 | knob 8 | spigot 9 | water dispenser 10 | sink faucet 11 | bathtub faucet 12 | drinking fountain 13 | water faucet 14 | mixing faucet --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_telephone.txt: -------------------------------------------------------------------------------- 1 | telephone 2 | phone 3 | mobile phone 4 | cell phone 5 | landline 6 | cordless phone 7 | smart phone 8 | flip phone 9 | touchscreen phone 10 | wireless phone 11 | home phone 12 | office phone --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_mug.txt: -------------------------------------------------------------------------------- 1 | mug 2 | cup 3 | tumbler 4 | teacup 5 | coffee cup 6 | espresso cup 7 | beer mug 8 | wine glass 9 | goblet 10 | flute 11 | martini glass 12 | punch cup 13 | thermal mug 14 | plastic cup 15 | styrofoam cup 16 | paper cup --------------------------------------------------------------------------------
/prompts/similar_noun/synonym_bag.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag --------------------------------------------------------------------------------
/prompts/phone_on_base.txt: -------------------------------------------------------------------------------- 1 | phone on base 2 | pick up the phone and place it on the base 3 | the robot gripper picks up the phone and places it on the base 4 | reach toward the phone and move it to the base 5 | the robot arm grasps the phone and sets it down 6 | put the phone on the base 7 | the robot is placing the phone on the box 8 | seize the phone and drop it on the box 9 | robot arm grasped the phone and dropped it on the square base 10 | the arm is picking up the phone and placing it on the base 11 | --------------------------------------------------------------------------------
/common/__init__.py: -------------------------------------------------------------------------------- 1 | # General tools. 2 | from .config import * 3 | from .counter import * 4 | from .flags import * 5 | from .logger import * 6 | from .when import * 7 | 8 | # RL tools. 9 | from .other import * 10 | from .driver import * 11 | from .rlbench_utils import * 12 | from .envs import * 13 | # from .kitchen import * 14 | from .replay import * 15 | 16 | # TensorFlow tools. 17 | from .tfutils import * 18 | from .dists import * 19 | 20 | from .nets import * 21 | from .mae_utils import * 22 | from .mae import * 23 | from .expl import * 24 | from .rnd import * --------------------------------------------------------------------------------
/common/counter.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | 4 | @functools.total_ordering 5 | class Counter: 6 | def __init__(self, initial=0): 7 | self.value = initial 8 | 9 | def __int__(self): 10 | return int(self.value) 11 | 12 | def __eq__(self, other): 13 | return int(self) == other 14 | 15 | def __ne__(self, other): 16 | return int(self) != other 17 | 18 | def __lt__(self, other): 19 | return int(self) < other 20 | 21 | def __add__(self, other): 22 | return int(self) + other 23 | 24 | def increment(self, amount=1): 25 | self.value += amount 26 | --------------------------------------------------------------------------------
/prompts/mug_40.txt: -------------------------------------------------------------------------------- 1 | Grasp the mug 2 | Extend your hand for the mug 3 | Retrieve the mug 4 | Take hold of the mug 5 | Stretch out for the mug 6 | Clutch the mug 7 | Seize the mug 8 | Lay your hands on the mug 9 | Pick up the mug 10 | Get a grip on the mug 11 | Reach out for the cup 12 | Grip the cup 13 | Hold the mug 14 | Stretch for the mug 15 | Snatch the mug 16 | Embrace the mug 17 | Catch the mug 18 | Grapple for the mug 19 | Take the mug 20 | Secure the mug 21 | Cling to the mug 22 | Obtain the mug 23 | Acquire the mug 24 | Clasp the mug 25 | Nudge the mug 26 | Bring the mug closer 27 | Draw the mug nearer 28 | Gather the mug 29 | Gain possession of the mug 30 | Lay hold of the mug 31 | Retrieve the cup 32 | Clutch the cup 33 | Hold onto the cup 34 | Reach out for the beverage holder 35 | Extend your arm for the coffee mug 36 | Grab the mug 37 | Obtain the coffee vessel 38 | Snag the mug 39 | Get a hold of the mug 40 | Hook the mug
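A minimal usage sketch for the Counter class from common/counter.py above; this is an illustration, not a file from the repo, and it assumes the repo root is on PYTHONPATH.

```python
# Usage sketch for common/counter.py (assumed, not from the repo).
# Note: importing any common.* submodule also executes common/__init__.py,
# which pulls in the TensorFlow-dependent modules.
from common.counter import Counter

step = Counter(initial=0)
step.increment()   # value -> 1
step.increment(4)  # value -> 5

# functools.total_ordering derives <=, >, >= from the __eq__ and __lt__
# defined above, and every comparison coerces the counter to int first.
assert step == 5 and step < 10 and step >= 5
assert step + 1 == 6   # __add__ returns a plain int, not a Counter
assert int(step) == 5  # handy for step-based schedules and logging
```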
-------------------------------------------------------------------------------- /common/when.py: -------------------------------------------------------------------------------- 1 | class Every: 2 | def __init__(self, every): 3 | self._every = every 4 | self._last = None 5 | 6 | def __call__(self, step): 7 | step = int(step) 8 | if not self._every: 9 | return False 10 | if self._last is None: 11 | self._last = step 12 | return True 13 | if step >= self._last + self._every: 14 | self._last += self._every 15 | return True 16 | return False 17 | 18 | 19 | class Once: 20 | def __init__(self): 21 | self._once = True 22 | 23 | def __call__(self): 24 | if self._once: 25 | self._once = False 26 | return True 27 | return False 28 | 29 | 30 | class Until: 31 | def __init__(self, until): 32 | self._until = until 33 | 34 | def __call__(self, step): 35 | step = int(step) 36 | if not self._until: 37 | return True 38 | return step < self._until 39 | -------------------------------------------------------------------------------- /prompts/general_tasks_40.txt: -------------------------------------------------------------------------------- 1 | Close the window blinds 2 | Set the kitchen timer 3 | Sweep the kitchen floor 4 | Turn off the stove 5 | Dust the bookshelf 6 | Empty the kitchen trash can 7 | Arrange the spices alphabetically 8 | Turn on the kitchen fan 9 | Polish the kitchen sink faucet 10 | Put away clean dishes 11 | Take out the recycling 12 | Fluff the couch pillows 13 | Straighten the curtains 14 | Check the expiration date on the milk 15 | Replace the hand soap 16 | Sweep the front porch 17 | Take a sip of water 18 | Put away your shoes 19 | Change the tablecloth 20 | Feed the cat 21 | Take a deep breath 22 | Wipe down the kitchen counters 23 | Replace the batteries in the smoke detector 24 | Water the indoor plants 25 | Wash your hands 26 | Straighten the picture frames 27 | Refill the salt shaker 28 | Unload the dishwasher 29 | Turn on the kitchen lights 30 | Organize the fridge contents 31 | Take a quick snack break 32 | Tidy up the entryway 33 | Replace the dish towel 34 | Dust the lampshade 35 | Sweep the outdoor patio 36 | Put away the cereal box 37 | Check the weather forecast 38 | Test the oven temperature 39 | Wipe the mirror clean 40 | Take a moment to stretch -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Ademi Adeniji, Amber Xie, Carmelo Sferrazza, Younggyo Seo, Stephen James, Pieter Abbeel 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /prompts/rand_verb_40.txt: -------------------------------------------------------------------------------- 1 | The [NOUN] is seized 2 | The [NOUN] is clutched 3 | The [NOUN] is gripped 4 | The [NOUN] is firmly grasped 5 | The [NOUN] is tightly held 6 | The [NOUN] is firmly caught 7 | The [NOUN] is securely clasped 8 | The [NOUN] is rotated 9 | The [NOUN] has been flipped 10 | The [NOUN] has been knotted 11 | The [NOUN] has been folded 12 | The [NOUN] has been rinsed 13 | The [NOUN] has been filled 14 | The [NOUN] is shaken 15 | The [NOUN] has been scooped 16 | The [NOUN] is poured 17 | The [NOUN] has been scrubbed 18 | The [NOUN] is tilted 19 | The [NOUN] has been heated 20 | Reach for the [NOUN] 21 | Grasp at the [NOUN] 22 | Stretch out to touch the [NOUN] 23 | Move your arm towards the [NOUN] 24 | Use the gripper to rinse the [NOUN] 25 | Position the end effector to fold the [NOUN] 26 | Reach out the robotic arm to wipe the [NOUN] 27 | Utilize the gripper to seize the [NOUN] 28 | Guide the robotic arm to obtain the [NOUN] 29 | Maneuver the end effector to lift up the [NOUN] 30 | Extend your hand towards the [NOUN] 31 | Reach out your hand to acquire the [NOUN] 32 | Guide your arm to rotate the [NOUN] 33 | Maneuver your hand to shake up the [NOUN] 34 | Flip the [NOUN] 35 | Tap the [NOUN] 36 | Fold the [NOUN] 37 | Rotate the [NOUN] 38 | Brush the [NOUN] 39 | Twist the [NOUN] 40 | Wipe the [NOUN] -------------------------------------------------------------------------------- /prompts/hamlet.txt: -------------------------------------------------------------------------------- 1 | Holla, Barnardo. 2 | BARNARDO Say, what, is Horatio there? 3 | HORATIO A piece of him. 4 | BARNARDO 5 | Welcome, Horatio.—Welcome, good Marcellus. 6 | HORATIO 7 | What, has this thing appeared again tonight? 8 | BARNARDO I have seen nothing. 9 | MARCELLUS 10 | Horatio says ’tis but our fantasy 11 | And will not let belief take hold of him 12 | Touching this dreaded sight twice seen of us. 13 | Therefore I have entreated him along 14 | With us to watch the minutes of this night, 15 | That, if again this apparition come, 16 | He may approve our eyes and speak to it. 17 | Tush, tush, ’twill not appear. 18 | BARNARDO Sit down awhile, 19 | How now, Horatio, you tremble and look pale. 20 | Is not this something more than fantasy? 21 | What think you on ’t? 22 | At least the whisper goes so: our last king, 23 | Whose image even but now appeared to us, 24 | Was, as you know, by Fortinbras of Norway, 25 | Thereto pricked on by a most emulate pride, 26 | Dared to the combat; in which our valiant Hamlet 27 | (For so this side of our known world esteemed him) 28 | Did slay this Fortinbras, who by a sealed compact, 29 | Well ratified by law and heraldry, 30 | Did forfeit, with his life, all those his lands 31 | Which he stood seized of, to the conqueror. 32 | Against the which a moiety competent 33 | Was gagèd by our king, which had returned 34 | To the inheritance of Fortinbras 35 | Had he been vanquisher, as, by the same comart 36 | And carriage of the article designed, 37 | His fell to Hamlet. 
Now, sir, young Fortinbras, 38 | Of unimprovèd mettle hot and full, 39 | Hath in the skirts of Norway here and there 40 | Sharked up a list of lawless resolutes -------------------------------------------------------------------------------- /prompts/similar_verb_40.txt: -------------------------------------------------------------------------------- 1 | Pick up the [NOUN] 2 | Lift the [NOUN] with your hands 3 | Hold the [NOUN] in your grasp 4 | Take hold of the [NOUN] and raise it 5 | Grasp the [NOUN] firmly and lift it up 6 | Raise the [NOUN] by picking it up 7 | Retrieve the [NOUN] and hold it up 8 | Lift the [NOUN] by gripping it 9 | Seize the [NOUN] and raise it off the surface 10 | Hold onto the [NOUN] and lift it up 11 | The [NOUN] is lifted up 12 | The [NOUN] is picked up off the ground 13 | The [NOUN] is raised up by hand 14 | The [NOUN] is grasped and lifted up 15 | The [NOUN] is taken up by hand 16 | The [NOUN] is retrieved and lifted up 17 | The [NOUN] is lifted off its surface 18 | The [NOUN] is elevated by being picked up 19 | The [NOUN] is hoisted up by hand 20 | The [NOUN] is scooped up and lifted 21 | The [NOUN] is lifted by the hand 22 | The [NOUN] is grasped and picked up 23 | The [NOUN] is raised by the palm 24 | The [NOUN] is taken up by the fingers 25 | The [NOUN] is held and lifted up 26 | The [NOUN] is lifted off the surface by the arm 27 | The [NOUN] is picked up and held by the wrist 28 | The [NOUN] is scooped up by the palm and lifted 29 | The [NOUN] is elevated by the hand 30 | The [NOUN] is taken up by the fingers of the hand 31 | The [NOUN] is grasped and raised 32 | The [NOUN] is lifted by the gripper 33 | The end effector picks up the [NOUN] 34 | The arm lifts the [NOUN] 35 | The [NOUN] is held aloft by the robotic hand 36 | The robotic gripper secures the [NOUN] 37 | The [NOUN] is lifted off the surface by the robotic arm 38 | The robotic manipulator seizes and elevates the [NOUN] 39 | The robotic end effector clasps and hoists the [NOUN] 40 | The [NOUN] is taken up by the robotic gripper -------------------------------------------------------------------------------- /prompts/rand_noun/synonym_bag.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 
| goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_cap.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_jar.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_mug.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_pot.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_bowl.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_earphone.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_faucet.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_knife.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_laptop.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/rand_noun/synonym_telephone.txt: -------------------------------------------------------------------------------- 1 | bag 2 | handbag 3 | purse 4 | clutch 5 | tote 6 | backpack 7 | knapsack 8 | satchel 9 | shoulder bag 10 | duffel bag 11 | messenger bag 12 | grip 13 | briefcase 14 | pouch 15 | fanny pack 16 | drawstring bag 17 | beach bag 18 | grocery bag 19 | shopping bag 20 | gift bag 21 | lunch bag 22 | laptop bag 23 | travel bag 24 | bowl 25 | dish 26 | soup bowl 27 | wooden bowl 28 | plastic bowl 29 | cap 30 | hat 31 | snapback 32 | baseball cap 33 | visor 34 | bucket hat 35 | fedora 36 | cowboy hat 37 | knit cap 38 | nightcap 39 | earphone 40 | headphone 41 | earbud 42 | in-ear monitor 43 | earpiece 44 | earphone set 45 | earphone accessory 46 | faucet 47 | tap 48 | valve 49 | spout 50 | handle 51 | lever 52 | knob 53 | spigot 54 | water dispenser 55 | sink faucet 56 | bathtub faucet 57 | drinking fountain 58 | water faucet 59 | mixing faucet 60 | jar 61 | vase 62 | container 63 | bottle 64 | canister 65 | pot 66 | urn 67 | jug 68 | flask 69 | pitcher 70 | bucket 71 | pail 72 | basket 73 | knife 74 | blade 75 | cutter 76 | chopper 77 | cleaver 78 | pocket knife 79 | Swiss army knife 80 | laptop 81 | notebook 82 | ultrabook 83 | netbook 84 | tablet 85 | chromebook 86 | computer 87 | mug 88 | cup 89 | tumbler 90 | teacup 91 | coffee cup 92 | espresso cup 93 | beer mug 94 | wine glass 95 | goblet 96 | flute 97 | martini glass 98 | punch cup 99 | thermal mug 100 | plastic cup 101 | styrofoam cup 102 | paper cup 103 | pot 104 | flowerpot 105 | vase 106 | container 107 | plant pot 108 | flower holder 109 | telephone 110 | phone 111 | mobile phone 112 | cell phone 113 | landline 114 | cordless phone 115 | smart phone 116 | flip phone 117 | touchscreen phone 118 | wireless phone 119 | home phone 120 | office phone --------------------------------------------------------------------------------
/prompts/human_mug_statement_40.txt: -------------------------------------------------------------------------------- 1 | The fingers are stretching towards the mug 2 | The palm is extending towards the mug 3 | The hand is grasping for the mug 4 | The arm is trying to retrieve the mug 5 | The person is moving towards the mug with their hand 6 | The human is inching towards the mug 7 | The fingers are approaching the mug with their grip 8 | The palm is leaning towards the mug 9 | The hand is aiming for the mug with its fingers 10 | The arm is going after the mug 11 | The fingers are making a move for the mug 12 | The palm is making a reach for the mug 13 | The hand is making an effort to reach the mug 14 | The arm is making a play for the mug 15 | The fingers are stretching out to grab the mug 16 | The palm is angling for the mug 17 | The hand is snatching for the mug 18 | The arm is closing in on the mug with its hand 19 | The fingers are clenching towards the mug 20 | The hand is coveting the mug 21 | The palm is yearning for the mug 22 | The arm is desiring to hold the mug 23 | The fingers are eagerly trying to take the mug 24 | The hand is avidly reaching for the mug 25 | The palm is enthusiastically extending towards the mug 26 | The arm is zealously attempting to reach the mug 27 | The fingers are ambitiously grasping for the mug 28 | The hand is fiercely reaching for the mug 29 | The palm is vigorously attempting to take the mug 30 | The arm is robustly reaching
for the mug 31 | The fingers are ardently going after the mug 32 | The hand is intently reaching for the mug 33 | The palm is doggedly grasping for the mug 34 | The arm is persistently attempting to take the mug 35 | The fingers are tenaciously reaching for the mug 36 | The hand is determined to grasp the mug 37 | The palm is unwaveringly reaching for the mug 38 | The arm is steadfastly trying to take the mug 39 | The fingers are hungrily grasping for the mug 40 | The hand is greedily reaching for the mug -------------------------------------------------------------------------------- /prompts/jar.txt: -------------------------------------------------------------------------------- 1 | Grasp the jar 2 | Extend your arm for the jar 3 | Stretch to grab the jar 4 | Take hold of the jar 5 | Get the jar within your reach 6 | Snatch the jar 7 | Reach out for the jar 8 | Reach forward for the jar 9 | Reach up for the jar 10 | Go for the jar 11 | Aim for the jar 12 | Clutch the jar 13 | Lay your hand on the jar 14 | Nab the jar 15 | Catch the jar 16 | Hook the jar 17 | Secure the jar 18 | Seize the jar 19 | Grip the jar 20 | Snag the jar 21 | Acquire the jar 22 | Take possession of the jar 23 | Hold the jar 24 | Clasp the jar 25 | Embrace the jar 26 | Enclose the jar 27 | Envelop the jar 28 | Gather the jar 29 | Collect the jar 30 | Retrieve the jar 31 | Access the jar 32 | Attain the jar 33 | Capture the jar 34 | Draw in the jar 35 | Bring the jar closer 36 | Take the jar 37 | Take a hold of the jar 38 | Take the jar in your hand 39 | Bring the jar to yourself 40 | Take possession of the jar 41 | Lay hands on the jar 42 | Clench the jar 43 | Grab the jar 44 | Obtain the jar 45 | Take up the jar 46 | Pluck the jar 47 | Snatch up the jar 48 | Scoop up the jar 49 | Gather up the jar 50 | Close in on the jar 51 | Ensnare the jar 52 | Embrace the jar tightly 53 | Hold onto the jar 54 | Embody the jar 55 | Hold onto the jar tightly 56 | Take a firm grip of the jar 57 | Get your hands on the jar 58 | Embrace the jar with both hands 59 | Grasp onto the jar tightly 60 | Hold the jar with a firm grip 61 | Securely hold the jar 62 | Keep the jar within reach 63 | Hold the jar close 64 | Draw the jar closer 65 | Haul in the jar 66 | Take hold of the jar with purpose 67 | Take the jar with confidence 68 | Reach out and secure the jar 69 | Get a firm grasp on the jar 70 | Pick up the jar 71 | Take the jar into your hand 72 | Grip the jar firmly 73 | Tightly clasp the jar 74 | Hold the jar securely 75 | Bring the jar towards you 76 | Take control of the jar 77 | Hold onto the jar with determination 78 | Retrieve the jar with care 79 | Grasp the jar with assurance 80 | Keep the jar within arm's reach -------------------------------------------------------------------------------- /prompts/robot_mug_statement_40.txt: -------------------------------------------------------------------------------- 1 | The automaton is extending its arm towards the cup 2 | The machine is grasping for the mug 3 | The android is stretching towards the mug 4 | The robotic arm is attempting to take the mug 5 | The gripper is trying to retrieve the mug 6 | The mechanical arm is moving towards the mug with its hand 7 | The cyborg is inching towards the mug 8 | The robotic fingers are approaching the mug with its grip 9 | The machine is leaning towards the mug 10 | The robot is aiming for the cup with its gripper 11 | The automaton is going after the mug 12 | The cyborg is making a move for the mug 13 | The robotic arm is making a reach 
for the mug 14 | The machine is making an effort to reach the mug 15 | The android is making a play for the mug 16 | The robotic fingers are stretching out to grab the mug 17 | The gripper is angling for the mug 18 | The robot is snatching for the mug 19 | The machine is closing in on the mug with its hand 20 | The robotic fingers are clenching towards the mug 21 | The automaton is coveting the mug 22 | The android is yearning for the mug 23 | The machine is desiring to hold the mug 24 | The robot is yearning to possess the mug 25 | The cyborg is eagerly trying to take the mug 26 | The robotic arm is avidly reaching for the mug 27 | The machine is enthusiastically extending its arm for the mug 28 | The android is zealously attempting to reach the mug 29 | The robot is ambitiously grasping for the mug 30 | The automaton is fiercely reaching for the mug 31 | The cyborg is vigorously attempting to take the mug 32 | The robotic arm is robustly reaching for the mug 33 | The machine is ardently going after the mug 34 | The android is intently reaching for the mug 35 | The gripper is doggedly grasping for the mug 36 | The robot is persistently attempting to take the mug 37 | The mechanical arm is tenaciously reaching for the mug 38 | The machine is determined to grasp the mug 39 | The robotic fingers are unwaveringly reaching for the mug 40 | The android is steadfastly trying to take the mug -------------------------------------------------------------------------------- /common/rnd.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class Predictor(tf.keras.Model): 4 | def __init__(self, input_shape, hidden_dim, output_dim): 5 | super(Predictor, self).__init__() 6 | self.conv = tf.keras.Sequential([ 7 | tf.keras.layers.Conv2D(32, kernel_size=8, strides=4, activation='relu'), 8 | tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, activation='relu'), 9 | tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, activation='relu'), 10 | tf.keras.layers.Flatten() 11 | ]) 12 | conv_output_size = self._get_conv_output(input_shape) 13 | self.fc = tf.keras.Sequential([ 14 | tf.keras.layers.Dense(hidden_dim, activation='relu'), 15 | tf.keras.layers.Dense(output_dim) 16 | ]) 17 | 18 | def call(self, x): 19 | x = self.conv(x) 20 | x = self.fc(x) 21 | return x 22 | 23 | def _get_conv_output(self, shape): 24 | x = tf.zeros((1, *shape)) 25 | x = self.conv(x) 26 | return int(tf.reduce_prod(x.shape[1:])) 27 | 28 | class RandomNet(tf.keras.Model): 29 | def __init__(self, input_shape, output_dim): 30 | super(RandomNet, self).__init__() 31 | self.conv = tf.keras.Sequential([ 32 | tf.keras.layers.Conv2D(32, kernel_size=8, strides=4, activation='relu'), 33 | tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, activation='relu'), 34 | tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, activation='relu'), 35 | tf.keras.layers.Flatten() 36 | ]) 37 | conv_output_size = self._get_conv_output(input_shape) 38 | self.fc = tf.keras.Sequential([ 39 | tf.keras.layers.Dense(output_dim) 40 | ]) 41 | 42 | def call(self, x): 43 | x = self.conv(x) 44 | x = self.fc(x) 45 | return x 46 | 47 | def _get_conv_output(self, shape): 48 | x = tf.zeros((1, *shape)) 49 | x = self.conv(x) 50 | return int(tf.reduce_prod(x.shape[1:])) 51 | -------------------------------------------------------------------------------- /r3mreward.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import torch 3 | import 
numpy as np 4 | 5 | class R3MReward(object): 6 | def __init__(self, embed_model, sentences, standardize_rewards, queue_size, update_stats_steps, num_top_images, use_lang_embeddings): 7 | self.embed_model = embed_model 8 | self.reward_model = self.embed_model.get_reward 9 | self.sentences = sentences 10 | self.standardize_rewards = standardize_rewards 11 | self.use_sharding_r3m = False 12 | self.log_top_images = True 13 | self.r3m_reward_bonus = 10.0 14 | self.update_stats_steps = update_stats_steps 15 | self.use_lang_embeddings = use_lang_embeddings 16 | if self.standardize_rewards: 17 | self.stats = {} 18 | for t in self.sentences: 19 | self.stats[t] = deque(maxlen=queue_size) 20 | if self.log_top_images: 21 | self.top_images = {} 22 | for t in self.sentences: 23 | self.top_images[t] = {"images": [], "rewards": []} 24 | self.num_top_images = num_top_images 25 | 26 | def get_lang_encoding(self, lang): 27 | return self.embed_model.lang_enc(lang) 28 | 29 | def get_reward(self, init, curr, lang, lang_emb, step=None): 30 | if isinstance(lang, int): 31 | lang_strings = [self.sentences[lang]] 32 | init_image = torch.unsqueeze(init, 0) 33 | curr_image = torch.unsqueeze(curr, 0) 34 | else: 35 | lang_strings = [self.sentences[i] for i in lang[:, 0]] 36 | init_image = init 37 | curr_image = curr 38 | 39 | 40 | init = self.embed_model(init_image[:, -3:, :, :]) 41 | curr = self.embed_model(curr_image[:, -3:, :, :]) 42 | if self.use_lang_embeddings: 43 | reward = self.embed_model.get_reward_le(init, curr, lang_emb)[0].unsqueeze(-1) 44 | else: 45 | reward = self.reward_model(init, curr, lang_strings)[0].unsqueeze(-1) 46 | 47 | return reward, None, None -------------------------------------------------------------------------------- /prompts/mug.txt: -------------------------------------------------------------------------------- 1 | Grasp the mug 2 | Extend your hand for the mug 3 | Retrieve the mug 4 | Take hold of the mug 5 | Stretch out for the mug 6 | Clutch the mug 7 | Seize the mug 8 | Lay your hands on the mug 9 | Pick up the mug 10 | Get a grip on the mug 11 | Reach out for the cup 12 | Grip the cup 13 | Hold the mug 14 | Stretch for the mug 15 | Snatch the mug 16 | Embrace the mug 17 | Catch the mug 18 | Grapple for the mug 19 | Take the mug 20 | Secure the mug 21 | Cling to the mug 22 | Obtain the mug 23 | Acquire the mug 24 | Clasp the mug 25 | Nudge the mug 26 | Bring the mug closer 27 | Draw the mug nearer 28 | Gather the mug 29 | Gain possession of the mug 30 | Lay hold of the mug 31 | Retrieve the cup 32 | Clutch the cup 33 | Hold onto the cup 34 | Reach out for the beverage holder 35 | Extend your arm for the coffee mug 36 | Grab the mug 37 | Obtain the coffee vessel 38 | Snag the mug 39 | Get a hold of the mug 40 | Hook the mug 41 | Make a grab for the mug 42 | Stretch your fingers towards the mug 43 | Wrap your hand around the mug 44 | Clasp the warm mug 45 | Hold the handle of the mug 46 | Reach out and touch the mug 47 | Take the mug in your hands 48 | Grasp onto the mug's handle 49 | Retrieve the mug with ease 50 | Snatch the mug quickly 51 | Catch the mug before it falls 52 | Pull the mug towards you 53 | Reach out and grab the mug 54 | Secure the mug with your hand 55 | Get a grip on the coffee mug 56 | Hold onto the mug tightly 57 | Cling onto the mug's handle 58 | Acquire the mug without hesitation 59 | Lay your palm on the mug 60 | Nudge the mug slightly closer 61 | Draw the mug nearer to you 62 | Bring the mug within reach 63 | Obtain the mug in one swift motion 64 | Take hold 
of the mug firmly 65 | Grapple for the mug's handle 66 | Hook your fingers around the mug 67 | Embrace the mug lovingly 68 | Snuggle your hand around the mug 69 | Hold the mug with both hands 70 | Take the mug in a gentle grip 71 | Retrieve the cup of coffee 72 | Clutch the coffee mug tightly 73 | Hold onto the coffee cup 74 | Extend your arm and take the mug 75 | Grab onto the coffee mug's handle 76 | Obtain the warm mug with care 77 | Snag the mug from the table 78 | Hook the mug and lift it up 79 | Make a move for the mug's handle 80 | Reach for mug --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # LAMP 2 | 3 | [**LA**nguage **M**odulated **P**retraining](https://arxiv.org/abs/2308.12270) (LAMP💡) is a method for pretraining a general RL agent for accelerated downstream learning by augmenting unsupervised RL rewards with extrinsic rewards parameterized by a Video-Language Model (VLM). 4 | 5 | LAMP method figure 6 | 7 | ## Installation 8 | To create a conda environment called `lamp`: 9 | ```bash 10 | conda env create -f env.yml 11 | conda activate lamp 12 | ``` 13 | 14 | Then, follow the RLBench installation instructions from [this fork](https://github.com/ademiadeniji/RLBench_lamp) that implements shaped rewards and the domain-randomized pretraining environment. 15 | 16 | Finally, install [our fork](https://github.com/ademiadeniji/r3m_lamp) of the R3M module that enables computing video-language alignment scores. 17 | 18 | ```bash 19 | git clone https://github.com/ademiadeniji/r3m_lamp 20 | pip install -e r3m_lamp 21 | ``` 22 | 23 | ## Training 24 | To pretrain your LAMP agent, run: 25 | 26 | ```bash 27 | TF_CPP_MIN_LOG_LEVEL=0 CUDA_VISIBLE_DEVICES=0 TF_XLA_FLAGS=--tf_xla_auto_jit=2 vglrun -d :0.0 python train.py --logdir /YOUR/LOGDIR/HERE --task pick_shapenet_objects --seed 1 --use_r3m_reward True --device cuda:0 --vidlang_model_device cuda:0 --use_lang_embeddings True --configs front_wrist vlsp --curriculum.objects 'bag,bowl,cap,earphone,faucet,jar,knife,laptop,mug,pot,telephone' --curriculum.num_unique_per_class '-1' --curriculum.num_objects '3' --curriculum.lang_prompt 'prompts/similar_verb_40.txt' --curriculum.synonym_folder prompts/similar_noun --curriculum.num_episodes '20000' --randomize True --expl_intr_scale 0.9 --expl_extr_scale 0.1 --plan2explore True 28 | ``` 29 | To finetune your pretrained LAMP agent on the `take_lid_off_saucepan` task, run: 30 | 31 | ```bash 32 | TF_CPP_MIN_LOG_LEVEL=0 CUDA_VISIBLE_DEVICES=0 TF_XLA_FLAGS=--tf_xla_auto_jit=2 vglrun -d :0.0 python train.py --logdir /YOUR/LOGDIR/HERE --task take_lid_off_saucepan --seed 0 --device cuda:0 --vidlang_model_device cuda:0 --use_lang_embeddings True --configs front_wrist vlsp --curriculum.use False --critic_linear_probe True --loaddir [LOADDIR] --ts [NUM_STEPS_PRETRAINED] --plan2explore True --expl_intr_scale 0 --expl_extr_scale 1 --shaped_rewards True 33 | ``` 34 | 35 | ## Citations 36 | If you use this code for your research, please cite our paper: 37 | ```bibtex 38 | @misc{adeniji2023language, 39 | title={Language Reward Modulation for Pretraining Reinforcement Learning}, 40 | author={Ademi Adeniji and Amber Xie and Carmelo Sferrazza and Younggyo Seo and Stephen James and Pieter Abbeel}, 41 | year={2023}, 42 | eprint={2308.12270}, 43 | archivePrefix={arXiv}, 44 | primaryClass={cs.LG} 45 | } 46 | ``` --------------------------------------------------------------------------------
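The pretraining command above points `--curriculum.lang_prompt` at a `[NOUN]`-template file and `--curriculum.synonym_folder` at the synonym lists shown earlier in this dump. The following is a hypothetical sketch of how such files can be combined into language-reward prompts; `sample_instruction` is not a function from this repo, and the actual loading logic lives in the training code.

```python
import random
from pathlib import Path

def sample_instruction(template_file, synonym_dir, obj):
    """Hypothetical helper: fill one [NOUN] template with a sampled synonym."""
    templates = [l for l in Path(template_file).read_text().splitlines() if l.strip()]
    nouns_path = Path(synonym_dir) / f"synonym_{obj}.txt"
    nouns = [l for l in nouns_path.read_text().splitlines() if l.strip()]
    return random.choice(templates).replace("[NOUN]", random.choice(nouns))

# e.g. sample_instruction("prompts/similar_verb_40.txt", "prompts/similar_noun", "mug")
# might return "The cup is grasped and lifted up".
```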
/common/driver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Driver: 5 | def __init__(self, envs, **kwargs): 6 | self._envs = envs 7 | self._kwargs = kwargs 8 | self._on_steps = [] 9 | self._on_resets = [] 10 | self._on_episodes = [] 11 | self._act_spaces = [env.act_space for env in envs] 12 | self.reset() 13 | 14 | def on_step(self, callback): 15 | self._on_steps.append(callback) 16 | 17 | def on_reset(self, callback): 18 | self._on_resets.append(callback) 19 | 20 | def on_episode(self, callback): 21 | self._on_episodes.append(callback) 22 | 23 | def reset(self): 24 | self._obs = [None] * len(self._envs) 25 | self._eps = [None] * len(self._envs) 26 | self._state = None 27 | 28 | def __call__(self, policy, steps=0, episodes=0): 29 | step, episode = 0, 0 30 | while step < steps or episode < episodes: 31 | obs = { 32 | i: self._envs[i].reset() 33 | for i, ob in enumerate(self._obs) 34 | if ob is None or ob["is_last"] 35 | } 36 | for i, ob in obs.items(): 37 | self._obs[i] = ob() if callable(ob) else ob 38 | ob = self._obs[i] 39 | act = {k: np.zeros(v.shape) for k, v in self._act_spaces[i].items()} 40 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()} 41 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_resets] 42 | self._eps[i] = [tran] 43 | obs = {k: np.stack([o[k] for o in self._obs]) for k in self._obs[0]} 44 | if len(self._kwargs.keys()) == 0: 45 | actions, self._state = policy(obs, self._state) 46 | else: 47 | actions, self._state = policy(obs, self._state, **self._kwargs) 48 | actions = [ 49 | {k: np.array(actions[k][i]) for k in actions} 50 | for i in range(len(self._envs)) 51 | ] 52 | assert len(actions) == len(self._envs) 53 | obs = [e.step(a) for e, a in zip(self._envs, actions)] 54 | obs = [ob() if callable(ob) else ob for ob in obs] 55 | for i, (act, ob) in enumerate(zip(actions, obs)): 56 | tran = {k: self._convert(v) for k, v in {**ob, **act}.items()} 57 | [fn(tran, worker=i, **self._kwargs) for fn in self._on_steps] 58 | self._eps[i].append(tran) 59 | step += 1 60 | if ob["is_last"]: 61 | ep = self._eps[i] 62 | ep = {k: self._convert([t[k] for t in ep]) for k in ep[0]} 63 | [fn(ep, **self._kwargs) for fn in self._on_episodes] 64 | episode += 1 65 | self._obs = obs 66 | 67 | def _convert(self, value): 68 | value = np.array(value) 69 | if np.issubdtype(value.dtype, np.floating): 70 | return value.astype(np.float32) 71 | elif np.issubdtype(value.dtype, np.signedinteger): 72 | return value.astype(np.int32) 73 | elif np.issubdtype(value.dtype, np.uint8): 74 | return value.astype(np.uint8) 75 | return value 76 | -------------------------------------------------------------------------------- /common/mae_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras.layers as tfkl 3 | from tensorflow.keras import mixed_precision as prec 4 | import numpy as np 5 | 6 | 7 | class Token(tf.keras.layers.Layer): 8 | def __init__( 9 | self, 10 | name, 11 | embed_dim, 12 | **kwargs, 13 | ): 14 | super().__init__(**kwargs) 15 | self._name = name 16 | self.embed_dim = embed_dim 17 | self.mask_token = None 18 | 19 | def build(self, input_shape): 20 | self.mask_token = self.add_weight( 21 | name=f"{self._name}_token", 22 | shape=(1, 1, self.embed_dim), 23 | initializer=tf.random_normal_initializer(stddev=0.02), 24 | trainable=True, 25 | ) 26 | 27 | def call(self, x): 28 | return self.mask_token 29 | 30 | 31 | def 
get_2d_sincos_pos_embed(embed_dim, grid_h_size, grid_w_size): 32 | """ 33 | grid_h_size, grid_w_size: int height and width of the grid 34 | return: 35 | pos_embed: [grid_h_size*grid_w_size, embed_dim] 36 | """ 37 | H, W = grid_h_size, grid_w_size 38 | 39 | grid_h = np.arange(grid_h_size, dtype=np.float32) 40 | grid_w = np.arange(grid_w_size, dtype=np.float32) 41 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 42 | grid = np.stack(grid, axis=0) 43 | 44 | grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) 45 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 46 | return pos_embed 47 | 48 | 49 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 50 | assert embed_dim % 2 == 0 51 | 52 | # use half of dimensions to encode grid_h 53 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 54 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 55 | 56 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 57 | return emb 58 | 59 | 60 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 61 | """ 62 | embed_dim: output dimension for each position 63 | pos: a list of positions to be encoded: size (M,) 64 | out: (M, D) 65 | """ 66 | assert embed_dim % 2 == 0 67 | omega = np.arange(embed_dim // 2, dtype=np.float32) 68 | omega /= embed_dim / 2.0 69 | omega = 1.0 / 10000 ** omega # (D/2,) 70 | 71 | pos = pos.reshape(-1) # (M,) 72 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 73 | 74 | emb_sin = np.sin(out) # (M, D/2) 75 | emb_cos = np.cos(out) # (M, D/2) 76 | 77 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 78 | return emb 79 | 80 | 81 | def get_1d_sincos_pos_embed(embed_dim, grid_size): 82 | """ 83 | grid_size: int length of the 1D (e.g. temporal) grid 84 | return: 85 | pos_embed: [grid_size, embed_dim] 86 | """ 87 | assert embed_dim % 2 == 0 88 | grid_t = np.arange(grid_size, dtype=np.float32) 89 | grid = np.meshgrid(grid_t) # 1D grid of positions 90 | grid = np.stack(grid, axis=0) 91 | 92 | grid = grid[0] 93 | 94 | omega = np.arange(embed_dim / 2, dtype=np.float32) 95 | omega /= embed_dim / 2 96 | omega = 1.0 / 10000 ** omega 97 | pos = grid 98 | out = np.einsum("m,d->md", pos, omega) 99 | 100 | emb_sin = np.sin(out) 101 | emb_cos = np.cos(out) 102 | 103 | pos_emb = np.concatenate([emb_sin, emb_cos], axis=1) 104 | return pos_emb 105 | -------------------------------------------------------------------------------- /common/flags.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | 5 | class Flags: 6 | def __init__(self, *args, **kwargs): 7 | from .config import Config 8 | 9 | self._config = Config(*args, **kwargs) 10 | 11 | def parse(self, argv=None, known_only=False, help_exists=None): 12 | if help_exists is None: 13 | help_exists = not known_only 14 | if argv is None: 15 | argv = sys.argv[1:] 16 | if "--help" in argv: 17 | print("\nHelp:") 18 | lines = str(self._config).split("\n")[2:] 19 | print("\n".join("--" + re.sub(r"[:,\[\]]", "", x) for x in lines)) 20 | help_exists and sys.exit() 21 | parsed = {} 22 | remaining = [] 23 | key = None 24 | vals = None 25 | for arg in argv: 26 | if arg.startswith("--"): 27 | if key: 28 | self._submit_entry(key, vals, parsed, remaining) 29 | if "=" in arg: 30 | key, val = arg.split("=", 1) 31 | vals = [val] 32 | else: 33 | key, vals = arg, [] 34 | else: 35 | if key: 36 |
vals.append(arg) 37 | else: 38 | remaining.append(arg) 39 | self._submit_entry(key, vals, parsed, remaining) 40 | parsed = self._config.update(parsed) 41 | if known_only: 42 | return parsed, remaining 43 | else: 44 | for flag in remaining: 45 | if flag.startswith("--"): 46 | raise ValueError(f"Flag '{flag}' did not match any config keys.") 47 | assert not remaining, remaining 48 | return parsed 49 | 50 | def _submit_entry(self, key, vals, parsed, remaining): 51 | if not key and not vals: 52 | return 53 | if not key: 54 | vals = ", ".join(f"'{x}'" for x in vals) 55 | raise ValueError(f"Values {vals} were not preceded by any flag.") 56 | name = key[len("--") :] 57 | if "=" in name: 58 | remaining.extend([key] + vals) 59 | return 60 | if self._config.IS_PATTERN.match(name): 61 | pattern = re.compile(name) 62 | keys = {k for k in self._config.flat if pattern.match(k)} 63 | elif name in self._config: 64 | keys = [name] 65 | else: 66 | keys = [] 67 | if not keys: 68 | remaining.extend([key] + vals) 69 | return 70 | if not vals: 71 | raise ValueError(f"Flag '{key}' was not followed by any values.") 72 | for key in keys: 73 | parsed[key] = self._parse_flag_value(self._config[key], vals, key) 74 | 75 | def _parse_flag_value(self, default, value, key): 76 | value = value if isinstance(value, (tuple, list)) else (value,) 77 | if isinstance(default, (tuple, list)): 78 | if len(value) == 1 and "," in value[0]: 79 | value = value[0].split(",") 80 | return tuple(self._parse_flag_value(default[0], [x], key) for x in value) 81 | assert len(value) == 1, value 82 | value = str(value[0]) 83 | if default is None: 84 | return value 85 | if isinstance(default, bool): 86 | try: 87 | return bool(["False", "True"].index(value)) 88 | except ValueError: 89 | message = f"Expected bool but got '{value}' for key '{key}'." 90 | raise TypeError(message) 91 | if isinstance(default, int): 92 | value = float(value) # Allow scientific notation for integers. 93 | if float(int(value)) != value: 94 | message = f"Expected int but got float '{value}' for key '{key}'." 95 | raise TypeError(message) 96 | return int(value) 97 | return type(default)(value) 98 | -------------------------------------------------------------------------------- /common/base_envs.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import os 3 | import re 4 | 5 | import gym 6 | import numpy as np 7 | import pickle 8 | from d4rl.kitchen.adept_envs.simulation.renderer import DMRenderer, MjPyRenderer 9 | from d4rl.kitchen.adept_envs.simulation.sim_robot import RenderMode 10 | 11 | 12 | def get_device_id(): 13 | return int(os.environ.get('GL_DEVICE_ID', 0)) 14 | 15 | 16 | class DmBenchEnv(): 17 | 18 | def __init__(self, name, action_repeat=1, size=(64, 64), camera=None): 19 | domain, task = name.split('_', 1) 20 | if domain == 'cup': # Only domain with multiple words.
21 | domain = 'ball_in_cup' 22 | if isinstance(domain, str): 23 | from dm_control import suite 24 | self._env = suite.load(domain, task) 25 | else: 26 | assert task is None 27 | self._env = domain() 28 | self._action_repeat = action_repeat 29 | self._size = size 30 | if camera is None: 31 | camera = dict(quadruped=2).get(domain, 0) 32 | self._camera = camera 33 | 34 | @property 35 | def observation_space(self): 36 | spaces = {} 37 | for key, value in self._env.observation_spec().items(): 38 | spaces[key] = gym.spaces.Box( 39 | -np.inf, np.inf, value.shape, dtype=np.float32) 40 | spaces['image'] = gym.spaces.Box( 41 | 0, 255, self._size + (3,), dtype=np.uint8) 42 | return gym.spaces.Dict(spaces) 43 | 44 | def _update_obs(self, obs): 45 | return obs 46 | 47 | @property 48 | def action_space(self): 49 | spec = self._env.action_spec() 50 | return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) 51 | 52 | def step(self, action): 53 | assert np.isfinite(action).all(), action 54 | reward = 0 55 | for _ in range(self._action_repeat): 56 | time_step = self._env.step(action) 57 | reward += time_step.reward or 0 58 | if time_step.last(): 59 | break 60 | obs = dict(time_step.observation) 61 | obs['image'] = self.render() 62 | done = time_step.last() 63 | info = {'discount': np.array(time_step.discount, np.float32)} 64 | 65 | return self._update_obs(obs), reward, done, info 66 | 67 | def reset(self): 68 | time_step = self._env.reset() 69 | obs = dict(time_step.observation) 70 | obs['image'] = self.render() 71 | return self._update_obs(obs) 72 | 73 | def render(self, *args, **kwargs): 74 | if kwargs.get('mode', 'rgb_array') != 'rgb_array': 75 | raise ValueError("Only render mode 'rgb_array' is supported.") 76 | return self._env.physics.render(*self._size, camera_id=self._camera) 77 | 78 | 79 | class BenchEnv(): 80 | LOCK = threading.Lock() 81 | 82 | def __init__(self, action_repeat, width=64): 83 | self._action_repeat = action_repeat 84 | self._width = width 85 | self._size = (self._width, self._width) 86 | 87 | @property 88 | def observation_space(self): 89 | shape = self._size + (3,) 90 | # space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) 91 | spaces = {"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 92 | } 93 | return spaces 94 | # return gym.spaces.Dict({'image': space}) 95 | 96 | @property 97 | def action_space(self): 98 | return self._env.action_space 99 | 100 | def close(self): 101 | return self._env.close() 102 | 103 | def reset(self): 104 | with self.LOCK: 105 | state = self._env.reset() 106 | return self._get_obs(state) 107 | 108 | def step(self, action): 109 | total_reward = 0.0 110 | for step in range(self._action_repeat): 111 | state, reward, done, info = self._env.step(action) 112 | total_reward += reward 113 | if done: 114 | break 115 | obs = self._get_obs(state) 116 | return obs, total_reward, done, info 117 | 118 | def render(self, mode): 119 | return self._env.render(mode, self._width, self._width) 120 | 121 | def render_offscreen(self): 122 | img = self.renderer.render_offscreen( 123 | self._width, self._width, mode=RenderMode.RGB, camera_id=-1) 124 | return np.flipud(np.fliplr(img)) 125 | 126 | def _get_obs(self, state): 127 | return {'image': self.render_offscreen(), 'state': state} -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: lamp 2 | channels: 3 | - nvidia 4 | - conda-forge 5 | - defaults 6 | 
dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - ca-certificates=2022.12.7=ha878542_0 10 | - certifi=2022.12.7=pyhd8ed1ab_0 11 | - cudatoolkit=11.1.74=h6bb024c_0 12 | - cudnn=8.2.1.32=h86fa8c9_0 13 | - ld_impl_linux-64=2.38=h1181459_1 14 | - libblas=3.9.0=15_linux64_openblas 15 | - libcblas=3.9.0=15_linux64_openblas 16 | - libffi=3.4.2=h6a678d5_6 17 | - libgcc-ng=11.2.0=h1234567_1 18 | - libgfortran-ng=12.2.0=h69a702a_19 19 | - libgfortran5=12.2.0=h337968e_19 20 | - libgomp=11.2.0=h1234567_1 21 | - liblapack=3.9.0=15_linux64_openblas 22 | - libopenblas=0.3.20=pthreads_h78a6416_0 23 | - libstdcxx-ng=11.2.0=h1234567_1 24 | - ncurses=6.4=h6a678d5_0 25 | - numpy 26 | - openssl=1.1.1t=h7f8727e_0 27 | - pip==22.3.1 28 | - python=3.8.16=h7a1cb2a_2 29 | - python_abi=3.8=2_cp38 30 | - readline=8.2=h5eee18b_0 31 | - sqlite=3.40.1=h5082296_0 32 | - tk=8.6.12=h1ccaba5_0 33 | - wheel=0.37.1=pyhd3eb1b0_0 34 | - xz=5.2.10=h5eee18b_1 35 | - zlib=1.2.13=h5eee18b_0 36 | - pip: 37 | - pip==22.3.1 38 | - absl-py==0.15.0 39 | - albumentations==1.2.1 40 | - antlr4-python3-runtime==4.8 41 | - astunparse==1.6.3 42 | - atari-py==0.2.9 43 | - beautifulsoup4==4.11.2 44 | - cached-property==1.5.2 45 | - cachetools==5.2.0 46 | - cffi==1.14.2 47 | - charset-normalizer==3.0.1 48 | - clang==5.0 49 | - click==8.1.3 50 | - cloudpickle==2.1.0 51 | - cycler==0.11.0 52 | - cython==0.29.32 53 | - decorator==5.1.1 54 | - dm-control==1.0.10 55 | - dm-env==1.5 56 | - dm-tree==0.1.7 57 | - fasteners==0.17.3 58 | - filelock==3.8.0 59 | - flatbuffers==1.12 60 | - fonttools==4.38.0 61 | - ftfy==6.1.1 62 | - gast==0.4.0 63 | - gdown==4.4.0 64 | - glfw==2.5.4 65 | - google-auth==2.11.0 66 | - google-auth-oauthlib==0.4.6 67 | - google-pasta==0.2.0 68 | - grpcio==1.48.1 69 | - html-testrunner==1.2.1 70 | - hydra-core==1.1.1 71 | - idna==3.4 72 | - imageio==2.21.2 73 | - importlib-metadata==4.12.0 74 | - importlib-resources==5.10.2 75 | - jinja2==3.1.2 76 | - joblib==1.1.0 77 | - keras==2.6.0 78 | - keras-preprocessing==1.1.2 79 | - llvmlite==0.39.1 80 | - lxml==4.9.1 81 | - markdown==3.4.1 82 | - markupsafe==2.1.1 83 | - matplotlib 84 | - natsort==8.2.0 85 | - networkx==2.6.3 86 | - numpy-quaternion==2022.4.3 87 | - nvidia-cublas-cu11==11.10.3.66 88 | - nvidia-cuda-nvrtc-cu11==11.7.99 89 | - nvidia-cuda-runtime-cu11==11.7.99 90 | - nvidia-cudnn-cu11==8.5.0.96 91 | - oauthlib==3.2.0 92 | - omegaconf==2.1.1 93 | - opencv-python==4.1.2.30 94 | - opencv-python-headless==4.6.0.66 95 | - opensimplex==0.3 96 | - opt-einsum==3.3.0 97 | - packaging==21.3 98 | - pandas==1.4.0 99 | - patchelf==0.15.0.0 100 | - pillow==9.4.0 101 | - protobuf==3.19.4 102 | - pyasn1==0.4.8 103 | - pyasn1-modules==0.2.8 104 | - pybullet==3.2.5 105 | - pycparser==2.21 106 | - pyopengl==3.1.6 107 | - pyparsing==3.0.9 108 | - pyquaternion==0.9.9 109 | - python-dateutil==2.8.2 110 | - pytz==2022.7.1 111 | - pywavelets==1.3.0 112 | - pyyaml==6.0 113 | - regex==2022.9.13 114 | - requests==2.28.2 115 | - requests-oauthlib==1.3.1 116 | - rsa==4.9 117 | - ruamel-yaml==0.17.21 118 | - ruamel-yaml-clib==0.2.6 119 | - scikit-image 120 | - scikit-learn 121 | - scipy==1.7.3 122 | - six 123 | - soupsieve==2.3.2.post1 124 | - tensorboard==2.9.0 125 | - tensorboard-data-server==0.6.1 126 | - tensorboard-plugin-wit==1.8.1 127 | - tensorflow==2.6.1 128 | - tensorflow-addons==0.18.0 129 | - tensorflow-estimator==2.6.0 130 | - tensorflow-gpu==2.6.0 131 | - tensorflow-hub==0.12.0 132 | - tensorflow-probability==0.14.1 133 | - tensorflow-text==2.6.0 134 | - 
termcolor==1.1.0 135 | - tfimm==0.2.6 136 | - threadpoolctl==3.1.0 137 | - tifffile==2021.11.2 138 | - timm==0.6.13 139 | - tokenizers==0.13.1 140 | - torch==1.13.1 141 | - torchvision==0.14.1 142 | - tqdm==4.64.1 143 | - transformers==4.23.1 144 | - transforms3d==0.4.1 145 | - typeguard==2.13.3 146 | - urllib3==1.26.14 147 | - wcwidth==0.2.6 148 | - werkzeug==2.2.2 149 | - wrapt==1.12.1 150 | - zipp==3.8.1 151 | - psutil==5.7.0 152 | - setuptools==65.5.0 153 | - gym==0.21.0 154 | - git+https://github.com/openai/CLIP.git 155 | -------------------------------------------------------------------------------- /common/rlbench_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import copy 4 | import numpy as np 5 | 6 | 7 | def collect_demo( 8 | env, 9 | replay, 10 | num_demos, 11 | camera_keys, 12 | shaped_rewards=False, 13 | task_name_to_num=None, 14 | ): 15 | transitions = [] 16 | for _ in range(num_demos): 17 | success = False 18 | while not success: 19 | try: 20 | demo = env._task.get_demos(1, live_demos=True)[0] 21 | success = True 22 | except: 23 | pass 24 | transitions.extend(extract_from_demo(demo, shaped_rewards, camera_keys, env.task_name, task_name_to_num)) 25 | 26 | # Restrict translation space by min_max 27 | actions = [] 28 | 29 | for obs in transitions: 30 | if obs["is_first"]: 31 | continue 32 | action = obs["action"] 33 | actions.append(action) 34 | 35 | low, high = np.min(actions, 0)[:3], np.max(actions, 0)[:3] 36 | low -= 0.2 * np.fabs(low) 37 | high += 0.2 * np.fabs(high) 38 | 39 | for obs in transitions: 40 | if obs["is_first"]: 41 | # for first action, let's just label with zero action 42 | obs["action"] = np.zeros(3 + 1) 43 | else: 44 | action = obs["action"] 45 | updated_action = [] 46 | 47 | pose = action[:3] 48 | norm_pose = 2 * ((pose - low) / (high - low)) - 1 49 | updated_action.append(norm_pose) 50 | 51 | gripper = action[3:4] 52 | norm_gripper = gripper * 2 - 1.0 53 | updated_action.append(norm_gripper) 54 | 55 | obs["action"] = np.hstack(updated_action) 56 | 57 | replay.add_step(obs) 58 | 59 | print(f"Position min/max: {low}/{high}") 60 | actions_min_max = low, high 61 | 62 | return actions_min_max 63 | 64 | 65 | def get_action(prev_obs, obs): 66 | prev_pose = prev_obs.gripper_pose[:3] 67 | cur_pose = obs.gripper_pose[:3] 68 | pose = cur_pose - prev_pose 69 | gripper_action = float(obs.gripper_open) 70 | prev_action = np.hstack([pose, gripper_action]) 71 | return prev_action 72 | 73 | 74 | def extract_from_demo(demo, shaped_rewards, camera_keys, task_name=None, task_name_to_num=None): 75 | transitions = [] 76 | if task_name_to_num is not None: 77 | init_image, init_state = None, None 78 | for k, obs in enumerate(demo): 79 | if k == 0: 80 | prev_action = None 81 | else: 82 | prev_obs = demo[k - 1] 83 | prev_action = get_action(prev_obs, obs) 84 | 85 | terminal = k == len(demo) - 1 86 | first = k == 0 87 | success = terminal 88 | 89 | if shaped_rewards: 90 | reward = obs.task_low_dim_state[0] 91 | else: 92 | reward = float(success) 93 | 94 | # Not to override obs 95 | _obs = copy.deepcopy(obs) 96 | _obs.joint_velocities = None 97 | _obs.joint_positions = None 98 | _obs.task_low_dim_state = None 99 | 100 | transition = { 101 | "reward": reward, 102 | "is_first": first, 103 | "is_last": False, 104 | "is_terminal": False, 105 | "success": success, 106 | "action": prev_action, 107 | "state": _obs.get_low_dim_data(), 108 | } 109 | 110 | keys = get_camera_keys(camera_keys) 111 | images = [] 112 | 
for key in keys: 113 | if key == "image_front": 114 | images.append(_obs.front_rgb) 115 | if key == "image_wrist": 116 | images.append(_obs.wrist_rgb) 117 | transition["image"] = np.concatenate(images, axis=-2) 118 | if task_name_to_num is not None: 119 | if k == 0: 120 | init_image = transition["image"] 121 | init_state = transition["state"] 122 | transition["init_image"] = init_image 123 | transition["init_state"] = init_state 124 | transition['task_num'] = task_name_to_num[task_name] 125 | transitions.append(transition) 126 | 127 | if len(transitions) % 50 == 0: 128 | time_limit = len(transitions) 129 | else: 130 | time_limit = 50 * (1 + (len(transitions) // 50)) 131 | while len(transitions) < time_limit: 132 | transitions.append(copy.deepcopy(transition)) 133 | transitions[-1]["is_last"] = True 134 | return transitions 135 | 136 | 137 | def get_camera_keys(keys): 138 | camera_keys = keys.split("|") 139 | return camera_keys 140 | -------------------------------------------------------------------------------- /common/expl.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow_probability import distributions as tfd 3 | 4 | import agent 5 | import common 6 | 7 | 8 | class Random(common.Module): 9 | 10 | def __init__(self, config, act_space, wm, tfstep, reward): 11 | self.config = config 12 | self.act_space = act_space 13 | 14 | def actor(self, feat): 15 | shape = feat.shape[:-1] + self.act_space.shape 16 | if self.config.actor.dist == 'onehot': 17 | return common.OneHotDist(tf.zeros(shape)) 18 | else: 19 | dist = tfd.Uniform(-tf.ones(shape), tf.ones(shape)) 20 | return tfd.Independent(dist, 1) 21 | 22 | def train(self, start, context, data): 23 | return None, {} 24 | 25 | 26 | class Plan2Explore(common.Module): 27 | 28 | def __init__(self, config, act_space, wm, tfstep, reward): 29 | self.config = config 30 | self.reward = reward 31 | self.wm = wm 32 | self.ac = agent.ActorCritic(config, act_space, tfstep) 33 | self.actor = self.ac.actor 34 | stoch_size = config.rssm.stoch 35 | if config.rssm.discrete: 36 | stoch_size *= config.rssm.discrete 37 | size = { 38 | 'stoch': stoch_size, 39 | 'deter': config.rssm.deter, 40 | 'feat': config.rssm.stoch + config.rssm.deter, 41 | }[self.config.disag_target] 42 | self._networks = [ 43 | common.MLP(size, **config.expl_head) 44 | for _ in range(config.disag_models)] 45 | self.opt = common.Optimizer('expl', **config.expl_opt) 46 | self.extr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm) 47 | self.intr_rewnorm = common.StreamNorm(**self.config.expl_reward_norm) 48 | 49 | def train(self, start, context, data): 50 | metrics = {} 51 | stoch = start['stoch'] 52 | if self.config.rssm.discrete: 53 | stoch = tf.reshape( 54 | stoch, stoch.shape[:-2] + (stoch.shape[-2] * stoch.shape[-1])) 55 | target = { 56 | 'embed': context['embed'], 57 | 'stoch': stoch, 58 | 'deter': start['deter'], 59 | 'feat': context['feat'], 60 | }[self.config.disag_target] 61 | inputs = context['feat'] 62 | if self.config.disag_action_cond: 63 | action = tf.cast(data['action'], inputs.dtype) 64 | inputs = tf.concat([inputs, action], -1) 65 | metrics.update(self._train_ensemble(inputs, target)) 66 | metrics.update(self.ac.train( 67 | self.wm, start, data['is_terminal'], self._intr_reward)) 68 | return None, metrics 69 | 70 | def _intr_reward(self, seq): 71 | inputs = seq['feat'] 72 | if self.config.disag_action_cond: 73 | action = tf.cast(seq['action'], inputs.dtype) 74 | inputs =
tf.concat([inputs, action], -1) 75 | preds = [head(inputs).mode() for head in self._networks] 76 | disag = tf.tensor(preds).std(0).mean(-1) 77 | if self.config.disag_log: 78 | disag = tf.math.log(disag) 79 | reward = self.config.expl_intr_scale * self.intr_rewnorm(disag)[0] 80 | if self.config.expl_extr_scale: 81 | reward += self.config.expl_extr_scale * self.extr_rewnorm( 82 | self.reward(seq))[0] 83 | return reward 84 | 85 | def _train_ensemble(self, inputs, targets): 86 | if self.config.disag_offset: 87 | targets = targets[:, self.config.disag_offset:] 88 | inputs = inputs[:, :-self.config.disag_offset] 89 | targets = tf.stop_gradient(targets) 90 | inputs = tf.stop_gradient(inputs) 91 | with tf.GradientTape() as tape: 92 | preds = [head(inputs) for head in self._networks] 93 | loss = -sum([pred.log_prob(targets).mean() for pred in preds]) 94 | metrics = self.opt(tape, loss, self._networks) 95 | return metrics 96 | 97 | 98 | class ModelLoss(common.Module): 99 | 100 | def __init__(self, config, act_space, wm, tfstep, reward): 101 | self.config = config 102 | self.reward = reward 103 | self.wm = wm 104 | self.ac = agent.ActorCritic(config, act_space, tfstep) 105 | self.actor = self.ac.actor 106 | self.head = common.MLP([], **self.config.expl_head) 107 | self.opt = common.Optimizer('expl', **self.config.expl_opt) 108 | 109 | def train(self, start, context, data): 110 | metrics = {} 111 | target = tf.cast(context[self.config.expl_model_loss], tf.float32) 112 | with tf.GradientTape() as tape: 113 | loss = -self.head(context['feat']).log_prob(target).mean() 114 | metrics.update(self.opt(tape, loss, self.head)) 115 | metrics.update(self.ac.train( 116 | self.wm, start, data['is_terminal'], self._intr_reward)) 117 | return None, metrics 118 | 119 | def _intr_reward(self, seq): 120 | reward = self.config.expl_intr_scale * self.head(seq['feat']).mode() 121 | if self.config.expl_extr_scale: 122 | reward += self.config.expl_extr_scale * self.reward(seq) 123 | return reward -------------------------------------------------------------------------------- /common/dists.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow_probability as tfp 4 | from tensorflow_probability import distributions as tfd 5 | 6 | import common 7 | 8 | # Patch to ignore seed to avoid synchronization across GPUs. 9 | _orig_random_categorical = tf.random.categorical 10 | 11 | 12 | def random_categorical(*args, **kwargs): 13 | kwargs["seed"] = None 14 | return _orig_random_categorical(*args, **kwargs) 15 | 16 | 17 | tf.random.categorical = random_categorical 18 | 19 | # Patch to ignore seed to avoid synchronization across GPUs. 
20 | _orig_random_normal = tf.random.normal 21 | 22 | 23 | def random_normal(*args, **kwargs): 24 | kwargs["seed"] = None 25 | return _orig_random_normal(*args, **kwargs) 26 | 27 | 28 | tf.random.normal = random_normal 29 | 30 | 31 | class SampleDist: 32 | def __init__(self, dist, samples=100): 33 | self._dist = dist 34 | self._samples = samples 35 | 36 | @property 37 | def name(self): 38 | return "SampleDist" 39 | 40 | def __getattr__(self, name): 41 | return getattr(self._dist, name) 42 | 43 | def mean(self): 44 | samples = self._dist.sample(self._samples) 45 | return samples.mean(0) 46 | 47 | def mode(self): 48 | sample = self._dist.sample(self._samples) 49 | logprob = self._dist.log_prob(sample) 50 | return tf.gather(sample, tf.argmax(logprob))[0] 51 | 52 | def entropy(self): 53 | sample = self._dist.sample(self._samples) 54 | logprob = self.log_prob(sample) 55 | return -logprob.mean(0) 56 | 57 | 58 | class OneHotDist(tfd.OneHotCategorical): 59 | def __init__(self, logits=None, probs=None, dtype=None): 60 | self._sample_dtype = dtype or tf.float32 61 | super().__init__(logits=logits, probs=probs) 62 | 63 | def mode(self): 64 | return tf.cast(super().mode(), self._sample_dtype) 65 | 66 | def sample(self, sample_shape=(), seed=None): 67 | # Straight through biased gradient estimator. 68 | sample = tf.cast(super().sample(sample_shape, seed), self._sample_dtype) 69 | probs = self._pad(super().probs_parameter(), sample.shape) 70 | sample += tf.cast(probs - tf.stop_gradient(probs), self._sample_dtype) 71 | return sample 72 | 73 | def _pad(self, tensor, shape): 74 | tensor = super().probs_parameter() 75 | while len(tensor.shape) < len(shape): 76 | tensor = tensor[None] 77 | return tensor 78 | 79 | 80 | class TruncNormalDist(tfd.TruncatedNormal): 81 | def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): 82 | super().__init__(loc, scale, low, high) 83 | self._clip = clip 84 | self._mult = mult 85 | 86 | def sample(self, *args, **kwargs): 87 | event = super().sample(*args, **kwargs) 88 | if self._clip: 89 | clipped = tf.clip_by_value( 90 | event, self.low + self._clip, self.high - self._clip 91 | ) 92 | event = event - tf.stop_gradient(event) + tf.stop_gradient(clipped) 93 | if self._mult: 94 | event *= self._mult 95 | return event 96 | 97 | 98 | class TanhBijector(tfp.bijectors.Bijector): 99 | def __init__(self, validate_args=False, name="tanh"): 100 | super().__init__( 101 | forward_min_event_ndims=0, validate_args=validate_args, name=name 102 | ) 103 | 104 | def _forward(self, x): 105 | return tf.nn.tanh(x) 106 | 107 | def _inverse(self, y): 108 | dtype = y.dtype 109 | y = tf.cast(y, tf.float32) 110 | y = tf.where( 111 | tf.less_equal(tf.abs(y), 1.0), 112 | tf.clip_by_value(y, -0.99999997, 0.99999997), 113 | y, 114 | ) 115 | y = tf.atanh(y) 116 | y = tf.cast(y, dtype) 117 | return y 118 | 119 | def _forward_log_det_jacobian(self, x): 120 | log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype)) 121 | return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x)) 122 | 123 | 124 | class MSEDist: 125 | def __init__(self, mode, dims, agg="sum"): 126 | self._mode = mode 127 | self._dims = tuple([-x for x in range(1, dims + 1)]) 128 | self._agg = agg 129 | self.batch_shape = mode.shape[: len(mode.shape) - dims] 130 | self.event_shape = mode.shape[len(mode.shape) - dims :] 131 | 132 | def mode(self): 133 | return self._mode 134 | 135 | def mean(self): 136 | return self._mode 137 | 138 | def log_prob(self, value): 139 | assert self._mode.shape == value.shape, (self._mode.shape, value.shape) 140 | distance 
= (self._mode - value) ** 2 141 | if self._agg == "mean": 142 | loss = distance.mean(self._dims) 143 | elif self._agg == "sum": 144 | loss = distance.sum(self._dims) 145 | else: 146 | raise NotImplementedError(self._agg) 147 | return -loss 148 | 149 | 150 | class SymlogDist: 151 | def __init__(self, mode, dims, agg="sum"): 152 | self._mode = mode 153 | self._dims = tuple([-x for x in range(1, dims + 1)]) 154 | self._agg = agg 155 | self.batch_shape = mode.shape[: len(mode.shape) - dims] 156 | self.event_shape = mode.shape[len(mode.shape) - dims :] 157 | 158 | def mode(self): 159 | return symexp(self._mode) 160 | 161 | def mean(self): 162 | return symexp(self._mode) 163 | 164 | def log_prob(self, value): 165 | assert self._mode.shape == value.shape, (self._mode.shape, value.shape) 166 | distance = (self._mode - symlog(value)) ** 2 167 | if self._agg == "mean": 168 | loss = distance.mean(self._dims) 169 | elif self._agg == "sum": 170 | loss = distance.sum(self._dims) 171 | else: 172 | raise NotImplementedError(self._agg) 173 | return -loss 174 | 175 | def symlog(x): 176 | return tf.sign(x) * tf.math.log(1 + tf.abs(x)) 177 | 178 | 179 | def symexp(x): 180 | return tf.sign(x) * (tf.math.exp(tf.abs(x)) - 1) 181 | -------------------------------------------------------------------------------- /common/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pathlib 4 | import time 5 | 6 | import numpy as np 7 | 8 | 9 | class Logger: 10 | def __init__(self, step, outputs, multiplier=1): 11 | self._step = step 12 | self._outputs = outputs 13 | self._multiplier = multiplier 14 | self._last_step = None 15 | self._last_time = None 16 | self._metrics = [] 17 | 18 | def add(self, mapping, prefix=None): 19 | step = int(self._step) * self._multiplier 20 | for name, value in dict(mapping).items(): 21 | name = f"{prefix}_{name}" if prefix else name 22 | value = np.array(value) 23 | if len(value.shape) not in (0, 2, 3, 4): 24 | raise ValueError( 25 | f"Shape {value.shape} for name '{name}' cannot be " 26 | "interpreted as scalar, image, or video." 
27 | ) 28 | self._metrics.append((step, name, value)) 29 | 30 | def scalar(self, name, value): 31 | self.add({name: value}) 32 | 33 | def image(self, name, value): 34 | self.add({name: value}) 35 | 36 | def video(self, name, value): 37 | self.add({name: value}) 38 | 39 | def write(self, fps=False): 40 | fps and self.scalar("fps", self._compute_fps()) 41 | if not self._metrics: 42 | return 43 | for output in self._outputs: 44 | output(self._metrics) 45 | self._metrics.clear() 46 | 47 | def _compute_fps(self): 48 | step = int(self._step) * self._multiplier 49 | if self._last_step is None: 50 | self._last_time = time.time() 51 | self._last_step = step 52 | return 0 53 | steps = step - self._last_step 54 | duration = time.time() - self._last_time 55 | self._last_time += duration 56 | self._last_step = step 57 | return steps / duration 58 | 59 | 60 | class TerminalOutput: 61 | def __call__(self, summaries): 62 | step = max(s for s, _, _, in summaries) 63 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0} 64 | formatted = {k: self._format_value(v) for k, v in scalars.items()} 65 | print(f"[{step}]", " / ".join(f"{k} {v}" for k, v in formatted.items())) 66 | 67 | def _format_value(self, value): 68 | if value == 0: 69 | return "0" 70 | elif 0.01 < abs(value) < 10000: 71 | value = f"{value:.2f}" 72 | value = value.rstrip("0") 73 | value = value.rstrip("0") 74 | value = value.rstrip(".") 75 | return value 76 | else: 77 | value = f"{value:.1e}" 78 | value = value.replace(".0e", "e") 79 | value = value.replace("+0", "") 80 | value = value.replace("+", "") 81 | value = value.replace("-0", "-") 82 | return value 83 | 84 | 85 | class JSONLOutput: 86 | def __init__(self, logdir): 87 | self._logdir = pathlib.Path(logdir).expanduser() 88 | 89 | def __call__(self, summaries): 90 | scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0} 91 | step = max(s for s, _, _, in summaries) 92 | with (self._logdir / "metrics.jsonl").open("a") as f: 93 | f.write(json.dumps({"step": step, **scalars}) + "\n") 94 | 95 | 96 | class TensorBoardOutput: 97 | def __init__(self, logdir, fps=20): 98 | # The TensorFlow summary writer supports file protocols like gs://. We use 99 | # os.path over pathlib here to preserve those prefixes. 
100 | self._logdir = os.path.expanduser(logdir) 101 | self._writer = None 102 | self._fps = fps 103 | 104 | def __call__(self, summaries): 105 | import tensorflow as tf 106 | 107 | self._ensure_writer() 108 | self._writer.set_as_default() 109 | for step, name, value in summaries: 110 | if len(value.shape) == 0: 111 | tf.summary.scalar("scalars/" + name, value, step) 112 | elif len(value.shape) == 2: 113 | tf.summary.image(name, value, step) 114 | elif len(value.shape) == 3: 115 | tf.summary.image(name, value, step) 116 | elif len(value.shape) == 4: 117 | self._video_summary(name, value, step) 118 | self._writer.flush() 119 | 120 | def _ensure_writer(self): 121 | if not self._writer: 122 | import tensorflow as tf 123 | 124 | self._writer = tf.summary.create_file_writer(self._logdir, max_queue=1000) 125 | 126 | def _video_summary(self, name, video, step): 127 | import tensorflow as tf 128 | import tensorflow.compat.v1 as tf1 129 | 130 | name = name if isinstance(name, str) else name.decode("utf-8") 131 | if np.issubdtype(video.dtype, np.floating): 132 | video = np.clip(255 * video, 0, 255).astype(np.uint8) 133 | try: 134 | T, H, W, C = video.shape 135 | summary = tf1.Summary() 136 | image = tf1.Summary.Image(height=H, width=W, colorspace=C) 137 | image.encoded_image_string = encode_gif(video, self._fps) 138 | summary.value.add(tag=name, image=image) 139 | tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) 140 | except (IOError, OSError) as e: 141 | print("GIF summaries require ffmpeg in $PATH.", e) 142 | tf.summary.image(name, video, step) 143 | 144 | 145 | def encode_gif(frames, fps): 146 | from subprocess import Popen, PIPE 147 | 148 | h, w, c = frames[0].shape 149 | pxfmt = {1: "gray", 3: "rgb24"}[c] 150 | cmd = " ".join( 151 | [ 152 | "ffmpeg -y -f rawvideo -vcodec rawvideo", 153 | f"-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex", 154 | "[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse", 155 | f"-r {fps:.02f} -f gif -", 156 | ] 157 | ) 158 | proc = Popen(cmd.split(" "), stdin=PIPE, stdout=PIPE, stderr=PIPE) 159 | for image in frames: 160 | proc.stdin.write(image.tobytes()) 161 | out, err = proc.communicate() 162 | if proc.returncode: 163 | raise IOError("\n".join([" ".join(cmd), err.decode("utf8")])) 164 | del proc 165 | return out 166 | -------------------------------------------------------------------------------- /configs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | # Train Script 4 | logdir: /dev/null 5 | loaddir: '' 6 | ts: '' 7 | seed: 0 8 | task: 'take_lid_off_saucepan' 9 | envs: 1 10 | envs_parallel: process 11 | render_size: [128, 128] 12 | time_limit: 150 13 | steps: 1002000 14 | log_every: 1e3 15 | eval_every: 1e4 16 | eval_eps: 10 17 | prefill: 0 18 | pretrain: 1000 19 | mae_pretrain: 10000 20 | train_every: 2 21 | train_mae_every: 2 22 | train_steps: 1 23 | train_mae_steps: 1 24 | replay: {capacity: 2e6, minlen: 1, maxlen: 50, prioritize_ends: True} 25 | dataset: {batch: 50, length: 50} 26 | mae_dataset: {batch: 32, length: 32} 27 | log_keys_video: ['image'] 28 | log_keys_sum: '^$' 29 | log_keys_mean: '^$' 30 | log_keys_max: '^$' 31 | precision: 16 32 | jit: True 33 | action_repeat: 1 34 | device: 'cuda:0' 35 | vidlang_model_device: 'cuda:1' 36 | actor_linear_probe: False 37 | critic_linear_probe: False 38 | scripted_corner: 'top_left' 39 | randomize: False 40 | tune_instruction: False 41 | num_tune_instructions: 10 42 | instructions_file: '' 43 | 44 | # Env 45 
| eval_noise: 0.0 46 | expl_noise: 0.0 47 | franka_kitchen: False 48 | 49 | # Agent 50 | clip_rewards: identity 51 | 52 | # Foundation Model 53 | use_r3m_reward: False 54 | use_internvideo_reward: False 55 | use_clip_reward: False 56 | multi_vidlang: False 57 | multi_task_vidlang: False 58 | internvideo_load_dir: "InternVideo/Pretrain/Multi-Modalities-Pretraining/models/InternVideo-MM-B-16.ckpt" 59 | standardize_rewards: False 60 | queue_size: 100000 61 | update_stats_steps: 2000 62 | num_top_images: 2 63 | use_zero_rewards: False 64 | boundary_reward_penalty: False 65 | use_lang_embeddings: False 66 | 67 | # Demo 68 | num_demos: 100 69 | shaped_rewards: True 70 | 71 | # MAE 72 | camera_keys: 'image_front|image_wrist' 73 | mask_ratio: 0.95 74 | mae: {img_h_size: 128, img_w_size: 128, patch_size: 16, embed_dim: 256, depth: 8, num_heads: 4, decoder_embed_dim: 256, decoder_depth: 6, decoder_num_heads: 4, reward_pred: True, early_conv: True, state_pred: True, in_chans: 3, ncams: 0, state_dim: 10, view_masking: True, control_input: 'front_wrist'} 75 | wm_flat_vit: {img_h_size: 8, img_w_size: 8, patch_size: 1, embed_dim: 128, depth: 2, num_heads: 4, decoder_embed_dim: 128, decoder_depth: 2, decoder_num_heads: 4, in_chans: 256, state_pred: False} 76 | image_t_size: 4 77 | mae_chunk: 1 78 | mae_avg: False 79 | 80 | # World Model 81 | grad_heads: [reward, discount] 82 | pred_discount: True 83 | rssm: {action_free: False, hidden: 1024, deter: 1024, stoch: 32, discrete: 32, act: elu, norm: none, std_act: sigmoid2, min_std: 0.1} 84 | reward_head: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: symlog} 85 | discount_head: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: binary} 86 | loss_scales: {feature: 1.0, kl: 1.0, reward: 1.0, discount: 1.0, proprio: 1.0, mae_reward: 1.0} 87 | wmkl: {scale: 1.0} 88 | wmkl_minloss: 0.1 89 | wmkl_balance: 0.8 90 | model_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, wd_pattern: 'kernel', warmup: 0} 91 | mae_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, warmup: 2500} 92 | 93 | # Actor Critic 94 | actor: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: trunc_normal, min_std: 0.1} 95 | critic: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: mse} 96 | actor_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, wd_pattern: 'kernel', warmup: 0} 97 | critic_opt: {opt: adam, lr: 1e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, wd_pattern: 'kernel', warmup: 0} 98 | discount: 0.99 99 | discount_lambda: 0.95 100 | imag_horizon: 15 101 | actor_grad: dynamics 102 | actor_grad_mix: 0.1 103 | aent: {scale: 1e-4} 104 | slow_target: True 105 | slow_target_update: 100 106 | slow_target_fraction: 1 107 | slow_baseline: True 108 | reward_norm: {momentum: 0.99, scale: 1.0, eps: 1e-8} 109 | curriculum: {use: True, num_episodes: '100|100|100', neg_lang_prompt: 'reach away from bowl', lang_prompt: 'reach_for_bowl|reach_for_mug|reach_for_jar', objects: 'bowl|mug|jar', synonym_folder: null, num_objects: "3|3|3", num_unique_per_class: "-1|-1|-1"} 110 | 111 | # Plan2Explore 112 | plan2explore: False 113 | expl_intr_scale: 0.5 114 | expl_extr_scale: 0.5 115 | expl_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100, wd: 1e-6} 116 | expl_head: {layers: [512, 512, 512, 512], act: elu, norm: none, dist: mse} 117 | expl_reward_norm: {momentum: 1.0, scale: 1.0, eps: 1e-8} 118 | disag_target: stoch 119 | disag_log: False 120 | disag_models: 10 121 | disag_offset: 1 122 | disag_action_cond: True 123 | expl_model_loss: kl 124 | 125 | # 
Rnd 126 | rnd: False 127 | rnd_embedding_dim: 512 128 | rnd_hidden_dim: 256 129 | rnd_opt: {opt: adam, lr: 3e-4, eps: 1e-5, clip: 100.0, wd: 1e-6, warmup: 2500} 130 | 131 | front: 132 | camera_keys: image_front 133 | mae.control_input: front 134 | 135 | wrist: 136 | camera_keys: image_wrist 137 | mae.control_input: wrist 138 | 139 | overhead: 140 | camera_keys: image_overhead 141 | mae.control_input: overhead 142 | 143 | front_wrist: 144 | camera_keys: image_front|image_wrist 145 | mae.control_input: front_wrist 146 | 147 | overhead_wrist: 148 | camera_keys: image_overhead|image_wrist 149 | mae.control_input: overhead_wrist 150 | 151 | front_wrist_to_front: 152 | camera_keys: image_front|image_wrist 153 | mae.control_input: front 154 | 155 | front_wrist_to_wrist: 156 | camera_keys: image_front|image_wrist 157 | mae.control_input: wrist 158 | 159 | vlsp: 160 | image_t_size: 1 161 | use_imagenet_mae: False 162 | mae.view_masking: False 163 | mae.depth: 3 164 | mae.decoder_depth: 2 165 | prefill: 200 166 | num_demos: 0 167 | mae_pretrain: 0 168 | pretrain: 0 169 | 170 | ptmae: 171 | mae.img_w_size: 224 172 | mae.img_h_size: 224 173 | wm_flat_vit.img_h_size: 7 174 | wm_flat_vit.img_w_size: 7 175 | mae.state_pred: False 176 | wm_flat_vit.in_chans: 768 177 | wm_flat_vit.embed_dim: 128 178 | mae_avg: True 179 | 180 | debug: 181 | eval_eps: 1 182 | dataset.batch: 8 183 | dataset.length: 10 184 | mae_dataset.batch: 4 185 | mae_dataset.length: 8 186 | mae.depth: 1 187 | mae.decoder_depth: 1 188 | pretrain: 1 189 | mae_pretrain: 1 190 | num_demos: 1 191 | rssm.hidden: 64 192 | rssm.deter: 64 193 | rssm.stoch: 4 194 | rssm.discrete: 4 195 | imag_horizon: 3 196 | jit: False 197 | log_every: 100 198 | -------------------------------------------------------------------------------- /common/tfutils.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import pickle 3 | import re 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.keras import mixed_precision as prec 8 | 9 | try: 10 | from tensorflow.python.distribute import values 11 | except Exception: 12 | from google3.third_party.tensorflow.python.distribute import values 13 | 14 | tf.tensor = tf.convert_to_tensor 15 | for base in (tf.Tensor, tf.Variable, values.PerReplica): 16 | base.mean = tf.math.reduce_mean 17 | base.std = tf.math.reduce_std 18 | base.var = tf.math.reduce_variance 19 | base.sum = tf.math.reduce_sum 20 | base.any = tf.math.reduce_any 21 | base.all = tf.math.reduce_all 22 | base.min = tf.math.reduce_min 23 | base.max = tf.math.reduce_max 24 | base.abs = tf.math.abs 25 | base.logsumexp = tf.math.reduce_logsumexp 26 | base.transpose = tf.transpose 27 | base.reshape = tf.reshape 28 | base.astype = tf.cast 29 | 30 | 31 | # values.PerReplica.dtype = property(lambda self: self.values[0].dtype) 32 | 33 | # tf.TensorHandle.__repr__ = lambda x: '' 34 | # tf.TensorHandle.__str__ = lambda x: '' 35 | # np.set_printoptions(threshold=5, edgeitems=0) 36 | 37 | 38 | class Module(tf.Module): 39 | def save(self, filename, verbose=True): 40 | values = tf.nest.map_structure(lambda x: x.numpy(), self.variables) 41 | amount = len(tf.nest.flatten(values)) 42 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values))) 43 | if verbose: 44 | print(f"Save checkpoint with {amount} tensors and {count} parameters.") 45 | with pathlib.Path(filename).open("wb") as f: 46 | pickle.dump(values, f) 47 | 48 | def load(self, filename, verbose=True): 49 | with 
pathlib.Path(filename).open("rb") as f: 50 | values = pickle.load(f) 51 | amount = len(tf.nest.flatten(values)) 52 | count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values))) 53 | if verbose: 54 | print(f"Load checkpoint with {amount} tensors and {count} parameters.") 55 | for i in range(len(values)): 56 | # print(f"{i} {self.variables[i].name} self.variables: {self.variables[i].shape}, {values[i].shape}. {values[i].shape == self.variables[i].shape}") 57 | if values[i].shape != self.variables[i].shape: 58 | print(f"{i} {self.variables[i].name} self.variables: {self.variables[i].shape}, {values[i].shape}") 59 | 60 | tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values) 61 | 62 | def get(self, name, ctor, *args, **kwargs): 63 | # Create or get layer by name to avoid mentioning it in the constructor. 64 | if not hasattr(self, "_modules"): 65 | self._modules = {} 66 | if name not in self._modules: 67 | self._modules[name] = ctor(*args, **kwargs) 68 | return self._modules[name] 69 | 70 | 71 | class Optimizer(tf.Module): 72 | def __init__( 73 | self, 74 | name, 75 | lr, 76 | eps=1e-4, 77 | clip=None, 78 | wd=None, 79 | opt="adam", 80 | warmup=0, 81 | wd_pattern=r".*", 82 | ): 83 | assert 0 <= wd < 1 84 | assert not clip or 1 <= clip 85 | self._name = name 86 | self._clip = clip 87 | self._wd = wd 88 | self._wd_pattern = wd_pattern 89 | self._updates = tf.Variable(0, trainable=False, dtype=tf.int64) 90 | self._lr = lr 91 | if warmup: 92 | self._lr = lambda: lr * tf.clip_by_value( 93 | self._updates.astype(tf.float32) / warmup, 0.0, 1.0 94 | ) 95 | self._opt = { 96 | "adam": lambda: tf.optimizers.Adam(self._lr, epsilon=eps), 97 | "nadam": lambda: tf.optimizers.Nadam(self._lr, epsilon=eps), 98 | "adamax": lambda: tf.optimizers.Adamax(self._lr, epsilon=eps), 99 | "sgd": lambda: tf.optimizers.SGD(self._lr), 100 | "momentum": lambda: tf.optimizers.SGD(self._lr, 0.9), 101 | }[opt]() 102 | self._mixed = prec.global_policy().compute_dtype == tf.float16 103 | if self._mixed: 104 | self._opt = prec.LossScaleOptimizer(self._opt, dynamic=True) 105 | self._once = True 106 | 107 | @property 108 | def variables(self): 109 | return self._opt.variables() 110 | 111 | def __call__(self, tape, loss, modules): 112 | assert loss.dtype is tf.float32, (self._name, loss.dtype) 113 | assert len(loss.shape) == 0, (self._name, loss.shape) 114 | metrics = {} 115 | 116 | # Find variables. 117 | modules = modules if hasattr(modules, "__len__") else (modules,) 118 | varibs = tf.nest.flatten([module.variables for module in modules]) 119 | count = sum(np.prod(x.shape) for x in varibs) 120 | if self._once: 121 | print(f"Found {count} {self._name} parameters.") 122 | self._once = False 123 | 124 | # Check loss. 125 | tf.debugging.check_numerics(loss, self._name + "_loss") 126 | metrics[f"{self._name}_loss"] = loss 127 | 128 | # Compute scaled gradient. 129 | if self._mixed: 130 | with tape: 131 | loss = self._opt.get_scaled_loss(loss) 132 | grads = tape.gradient(loss, varibs) 133 | if self._mixed: 134 | grads = self._opt.get_unscaled_gradients(grads) 135 | if self._mixed: 136 | metrics[f"{self._name}_loss_scale"] = self._opt.loss_scale 137 | 138 | # Distributed sync. 139 | context = tf.distribute.get_replica_context() 140 | if context: 141 | grads = context.all_reduce("mean", grads) 142 | 143 | # Gradient clipping. 
144 | norm = tf.linalg.global_norm(grads) 145 | if not self._mixed: 146 | tf.debugging.check_numerics(norm, self._name + "_norm") 147 | if self._clip: 148 | grads, _ = tf.clip_by_global_norm(grads, self._clip, norm) 149 | metrics[f"{self._name}_grad_norm"] = norm 150 | 151 | # Weight decay. 152 | if self._wd: 153 | self._apply_weight_decay(varibs) 154 | 155 | # Apply gradients. 156 | self._opt.apply_gradients( 157 | zip(grads, varibs), experimental_aggregate_gradients=False 158 | ) 159 | self._updates.assign_add(1) 160 | 161 | return metrics 162 | 163 | def _apply_weight_decay(self, varibs): 164 | nontrivial = self._wd_pattern != r".*" 165 | # if nontrivial: 166 | # print("Applied weight decay to variables:") 167 | for var in varibs: 168 | if re.search(self._wd_pattern, self._name + "/" + var.name): 169 | # if nontrivial: 170 | # print("- " + self._name + "/" + var.name) 171 | var.assign((1 - self._wd) * var) 172 | -------------------------------------------------------------------------------- /common/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | import re 4 | 5 | 6 | class Config(dict): 7 | 8 | SEP = "." 9 | IS_PATTERN = re.compile(r".*[^A-Za-z0-9_.-].*") 10 | 11 | def __init__(self, *args, **kwargs): 12 | mapping = dict(*args, **kwargs) 13 | mapping = self._flatten(mapping) 14 | mapping = self._ensure_keys(mapping) 15 | mapping = self._ensure_values(mapping) 16 | self._flat = mapping 17 | self._nested = self._nest(mapping) 18 | # Need to assign the values to the base class dictionary so that 19 | # conversion to dict does not lose the content. 20 | super().__init__(self._nested) 21 | 22 | @property 23 | def flat(self): 24 | return self._flat.copy() 25 | 26 | def save(self, filename): 27 | filename = pathlib.Path(filename) 28 | if filename.suffix == ".json": 29 | filename.write_text(json.dumps(dict(self))) 30 | elif filename.suffix in (".yml", ".yaml"): 31 | import ruamel.yaml as yaml 32 | 33 | with filename.open("w") as f: 34 | yaml.safe_dump(dict(self), f) 35 | else: 36 | raise NotImplementedError(filename.suffix) 37 | 38 | @classmethod 39 | def load(cls, filename): 40 | filename = pathlib.Path(filename) 41 | if filename.suffix == ".json": 42 | return cls(json.loads(filename.read_text())) 43 | elif filename.suffix in (".yml", ".yaml"): 44 | import ruamel.yaml as yaml 45 | 46 | return cls(yaml.safe_load(filename.read_text())) 47 | else: 48 | raise NotImplementedError(filename.suffix) 49 | 50 | def parse_flags(self, argv=None, known_only=False, help_exists=None): 51 | from . import flags 52 | 53 | return flags.Flags(self).parse(argv, known_only, help_exists) 54 | 55 | def __contains__(self, name): 56 | try: 57 | self[name] 58 | return True 59 | except KeyError: 60 | return False 61 | 62 | def __getattr__(self, name): 63 | if name.startswith("_"): 64 | return super().__getattr__(name) 65 | try: 66 | return self[name] 67 | except KeyError: 68 | raise AttributeError(name) 69 | 70 | def __getitem__(self, name): 71 | result = self._nested 72 | for part in name.split(self.SEP): 73 | result = result[part] 74 | if isinstance(result, dict): 75 | result = type(self)(result) 76 | return result 77 | 78 | def __setattr__(self, key, value): 79 | if key.startswith("_"): 80 | return super().__setattr__(key, value) 81 | message = f"Tried to set key '{key}' on immutable config. Use update()." 
82 | raise AttributeError(message) 83 | 84 | def __setitem__(self, key, value): 85 | if key.startswith("_"): 86 | return super().__setitem__(key, value) 87 | message = f"Tried to set key '{key}' on immutable config. Use update()." 88 | raise AttributeError(message) 89 | 90 | def __reduce__(self): 91 | return (type(self), (dict(self),)) 92 | 93 | def __str__(self): 94 | lines = ["\nConfig:"] 95 | keys, vals, typs = [], [], [] 96 | for key, val in self.flat.items(): 97 | keys.append(key + ":") 98 | vals.append(self._format_value(val)) 99 | typs.append(self._format_type(val)) 100 | max_key = max(len(k) for k in keys) if keys else 0 101 | max_val = max(len(v) for v in vals) if vals else 0 102 | for key, val, typ in zip(keys, vals, typs): 103 | key = key.ljust(max_key) 104 | val = val.ljust(max_val) 105 | lines.append(f"{key} {val} ({typ})") 106 | return "\n".join(lines) 107 | 108 | def update(self, *args, **kwargs): 109 | result = self._flat.copy() 110 | inputs = self._flatten(dict(*args, **kwargs)) 111 | for key, new in inputs.items(): 112 | if self.IS_PATTERN.match(key): 113 | pattern = re.compile(key) 114 | keys = {k for k in result if pattern.match(k)} 115 | else: 116 | keys = [key] 117 | if not keys: 118 | raise KeyError(f"Unknown key or pattern {key}.") 119 | for key in keys: 120 | if key not in result: 121 | result[key] = new 122 | continue 123 | old = result[key] 124 | try: 125 | if isinstance(old, int) and isinstance(new, float): 126 | if float(int(new)) != new: 127 | message = f"Cannot convert fractional float {new} to int." 128 | raise ValueError(message) 129 | if old is None: 130 | result[key] = str(new) 131 | else: 132 | result[key] = type(old)(new) 133 | except (ValueError, TypeError): 134 | raise TypeError( 135 | f"Cannot convert '{new}' to type '{type(old).__name__}' " 136 | + f"of value '{old}' for key '{key}'." 137 | ) 138 | return type(self)(result) 139 | 140 | def _flatten(self, mapping): 141 | result = {} 142 | for key, value in mapping.items(): 143 | if isinstance(value, dict): 144 | for k, v in self._flatten(value).items(): 145 | if self.IS_PATTERN.match(key) or self.IS_PATTERN.match(k): 146 | combined = f"{key}\\{self.SEP}{k}" 147 | else: 148 | combined = f"{key}{self.SEP}{k}" 149 | result[combined] = v 150 | else: 151 | result[key] = value 152 | return result 153 | 154 | def _nest(self, mapping): 155 | result = {} 156 | for key, value in mapping.items(): 157 | parts = key.split(self.SEP) 158 | node = result 159 | for part in parts[:-1]: 160 | if part not in node: 161 | node[part] = {} 162 | node = node[part] 163 | node[parts[-1]] = value 164 | return result 165 | 166 | def _ensure_keys(self, mapping): 167 | for key in mapping: 168 | assert not self.IS_PATTERN.match(key), key 169 | return mapping 170 | 171 | def _ensure_values(self, mapping): 172 | result = json.loads(json.dumps(mapping)) 173 | for key, value in result.items(): 174 | if isinstance(value, list): 175 | value = tuple(value) 176 | if isinstance(value, tuple): 177 | if len(value) == 0: 178 | message = ( 179 | "Empty lists are disallowed because their type is unclear." 180 | ) 181 | raise TypeError(message) 182 | if not isinstance(value[0], (str, float, int, bool)): 183 | message = "Lists can only contain strings, floats, ints, bools" 184 | message += f" but not {type(value[0])}" 185 | raise TypeError(message) 186 | if not all(isinstance(x, type(value[0])) for x in value[1:]): 187 | message = "Elements of a list must all be of the same type." 
188 | raise TypeError(message) 189 | result[key] = value 190 | return result 191 | 192 | def _format_value(self, value): 193 | if isinstance(value, (list, tuple)): 194 | return "[" + ", ".join(self._format_value(x) for x in value) + "]" 195 | return str(value) 196 | 197 | def _format_type(self, value): 198 | if isinstance(value, (list, tuple)): 199 | assert len(value) > 0, value 200 | return self._format_type(value[0]) + "s" 201 | return str(type(value).__name__) 202 | -------------------------------------------------------------------------------- /common/kitchen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import random 4 | import itertools 5 | from itertools import combinations 6 | from common.base_envs import BenchEnv 7 | from d4rl.kitchen.kitchen_envs import KitchenMicrowaveKettleBottomBurnerLightV0 8 | # import d4rl.kitchen.kitchen_envs 9 | 10 | 11 | class KitchenEnv(BenchEnv): 12 | def __init__(self, task, action_repeat=1, use_goal_idx=False, log_per_goal=False, control_mode='end_effector', width=64): 13 | # currently, task is not used 14 | super().__init__(action_repeat, width) 15 | self.use_goal_idx = use_goal_idx 16 | self.log_per_goal = log_per_goal 17 | with self.LOCK: 18 | self._env = KitchenMicrowaveKettleBottomBurnerLightV0() 19 | 20 | self._env.sim_robot.renderer._camera_settings = dict( 21 | distance=1.86, lookat=[-0.3, .5, 2.], azimuth=90, elevation=-60) 22 | 23 | self.rendered_goal = False 24 | self._env.reset() 25 | self.init_qpos = self._env.sim.data.qpos.copy() 26 | self.goal_idx = 0 27 | self.obs_element_goals, self.obs_element_indices, self.goal_configs = get_kitchen_benchmark_goals() 28 | self.goals = list(range(len(self.obs_element_goals))) 29 | 30 | @property 31 | def act_space(self): 32 | return {"action": self.action_space} 33 | 34 | @property 35 | def obs_space(self): 36 | return self.observation_space 37 | 38 | def set_goal_idx(self, idx): 39 | self.goal_idx = idx 40 | 41 | def get_goal_idx(self): 42 | return self.goal_idx 43 | 44 | def get_goals(self): 45 | return self.goals 46 | 47 | def _get_obs(self, state): 48 | image = self._env.render('rgb_array') 49 | obs = {'image': image, 'state': state, 'image_goal': self.render_goal(), 'goal': self.goal} 50 | if self.log_per_goal: 51 | for i, goal_idx in enumerate(self.goals): 52 | # add rewards for all goals 53 | task_rel_success, all_obj_success = self.compute_success(goal_idx) 54 | obs['metric_success_task_relevant/goal_'+str(goal_idx)] = task_rel_success 55 | obs['metric_success_all_objects/goal_'+str(goal_idx)] = all_obj_success 56 | if self.use_goal_idx: 57 | task_rel_success, all_obj_success = self.compute_success(self.goal_idx) 58 | obs['metric_success_task_relevant/goal_'+str(self.goal_idx)] = task_rel_success 59 | obs['metric_success_all_objects/goal_'+str(self.goal_idx)] = all_obj_success 60 | 61 | return obs 62 | 63 | def reset(self): 64 | 65 | with self.LOCK: 66 | state = self._env.reset() 67 | if not self.use_goal_idx: 68 | self.goal_idx = np.random.randint(len(self.goals)) 69 | self.goal = self.goals[self.goal_idx] 70 | self.rendered_goal = False 71 | obs = self._get_obs(state) 72 | obs['state'] = obs['state'].astype(np.float32) 73 | obs["is_last"] = False 74 | obs["is_first"] = True 75 | obs["reward"] = 0 # not sure if this is good 76 | obs["success"] = 1 # hard-code for now? 
77 | obs["is_terminal"] = False 78 | obs["lang_num"] = 0 # change in future for multitask 79 | self._init_vidlang_time_step = obs 80 | obs['init_image'] = obs['image'] 81 | obs['init_state'] = obs['state'] 82 | # print(type(obs['state']), obs['state'].dtype) 83 | # breakpoint() 84 | return obs 85 | 86 | def step(self, action): 87 | total_reward = 0.0 88 | for step in range(self._action_repeat): 89 | state, reward, done, info = self._env.step(action['action']) 90 | reward = self.compute_reward() 91 | total_reward += reward 92 | if done: 93 | break 94 | obs = self._get_obs(state) 95 | for k, v in obs.items(): 96 | if 'metric_' in k: 97 | info[k] = v 98 | obs["is_last"] = done 99 | obs["is_terminal"] = done 100 | obs["is_first"] = False 101 | obs["reward"] = total_reward 102 | obs["success"] = 1 # hard-code for now? 103 | obs["lang_num"] = 0 # change in future for multitask 104 | obs['state'] = obs['state'].astype(np.float32) 105 | obs['init_image'] = self._init_vidlang_time_step['init_image'] 106 | obs['init_state'] = self._init_vidlang_time_step['init_state'] 107 | return obs #, total_reward, done, info 108 | 109 | def compute_reward(self, goal=None): 110 | if goal is None: 111 | goal = self.goal 112 | qpos = self._env.sim.data.qpos.copy() 113 | 114 | if len(self.obs_element_indices[goal]) > 9 : 115 | return -np.linalg.norm(qpos[self.obs_element_indices[goal]][9:] - self.obs_element_goals[goal][9:]) 116 | else: 117 | return -np.linalg.norm(qpos[self.obs_element_indices[goal]] - self.obs_element_goals[goal]) 118 | 119 | def compute_success(self, goal = None): 120 | 121 | if goal is None: 122 | goal = self.goal 123 | qpos = self._env.sim.data.qpos.copy() 124 | 125 | goal_qpos = self.init_qpos.copy() 126 | goal_qpos[self.obs_element_indices[goal]] = self.obs_element_goals[goal] 127 | 128 | per_obj_success = { 129 | 'bottom_burner' : ((qpos[9]<-0.38) and (goal_qpos[9]<-0.38)) or ((qpos[9]>-0.38) and (goal_qpos[9]>-0.38)), 130 | 'top_burner': ((qpos[13]<-0.38) and (goal_qpos[13]<-0.38)) or ((qpos[13]>-0.38) and (goal_qpos[13]>-0.38)), 131 | 'light_switch': ((qpos[17]<-0.25) and (goal_qpos[17]<-0.25)) or ((qpos[17]>-0.25) and (goal_qpos[17]>-0.25)), 132 | 'slide_cabinet' : abs(qpos[19] - goal_qpos[19])<0.1, 133 | 'hinge_cabinet' : abs(qpos[21] - goal_qpos[21])<0.2, 134 | 'microwave' : abs(qpos[22] - goal_qpos[22])<0.2, 135 | 'kettle' : np.linalg.norm(qpos[23:25] - goal_qpos[23:25]) < 0.2 136 | } 137 | task_objects = self.goal_configs[goal] 138 | 139 | task_rel_success = 1 140 | for _obj in task_objects: 141 | task_rel_success *= per_obj_success[_obj] 142 | 143 | all_obj_success = 1 144 | for _obj in per_obj_success: 145 | all_obj_success *= per_obj_success[_obj] 146 | 147 | return int(task_rel_success), int(all_obj_success) 148 | 149 | def render_goal(self): 150 | if self.rendered_goal: 151 | return self.rendered_goal_obj 152 | 153 | # random.sample(list(obs_element_goals), 1)[0] 154 | backup_qpos = self._env.sim.data.qpos.copy() 155 | backup_qvel = self._env.sim.data.qvel.copy() 156 | 157 | qpos = self.init_qpos.copy() 158 | qpos[self.obs_element_indices[self.goal]] = self.obs_element_goals[self.goal] 159 | self._env.set_state(qpos, np.zeros(len(self._env.init_qvel))) 160 | 161 | goal_obs = self._env.render('rgb_array') 162 | 163 | self._env.set_state(backup_qpos, backup_qvel) 164 | 165 | self.rendered_goal = True 166 | self.rendered_goal_obj = goal_obs 167 | return goal_obs 168 | 169 | def get_kitchen_benchmark_goals(): 170 | 171 | object_goal_vals = {'bottom_burner' : [-0.88, -0.01], 172 | 
'light_switch' : [ -0.69, -0.05], 173 | 'slide_cabinet': [0.37], 174 | 'hinge_cabinet': [0., 0.5], 175 | 'microwave' : [-0.5], 176 | 'kettle' : [-0.23, 0.75, 1.62]} 177 | 178 | object_goal_idxs = {'bottom_burner' : [9, 10], 179 | 'light_switch' : [17, 18], 180 | 'slide_cabinet': [19], 181 | 'hinge_cabinet': [20, 21], 182 | 'microwave' : [22], 183 | 'kettle' : [23, 24, 25]} 184 | 185 | base_task_names = [ 'bottom_burner', 'light_switch', 'slide_cabinet', 186 | 'hinge_cabinet', 'microwave', 'kettle' ] 187 | 188 | 189 | goal_configs = [] 190 | #single task 191 | for i in range(6): 192 | goal_configs.append( [base_task_names[i]]) 193 | 194 | #two tasks 195 | for i,j in combinations([1,2,3,5], 2) : 196 | goal_configs.append( [base_task_names[i], base_task_names[j]] ) 197 | 198 | obs_element_goals = [] ; obs_element_indices = [] 199 | for objects in goal_configs: 200 | _goal = np.concatenate([object_goal_vals[obj] for obj in objects]) 201 | _goal_idxs = np.concatenate([object_goal_idxs[obj] for obj in objects]) 202 | 203 | obs_element_goals.append(_goal) 204 | obs_element_indices.append(_goal_idxs) 205 | 206 | return obs_element_goals, obs_element_indices, goal_configs -------------------------------------------------------------------------------- /common/other.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import contextlib 3 | import re 4 | import time 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | from tensorflow_probability import distributions as tfd 9 | 10 | from . import dists 11 | from . import tfutils 12 | 13 | 14 | class ScriptedAgent: 15 | def __init__(self, act_space, eval_env, corner = 'top_right'): 16 | self.act_space = act_space["action"] 17 | offset_xy = 0.15 18 | offset_z = 0.05 19 | if corner == 'top_left': 20 | self.target = [offset_xy,offset_xy,offset_z] 21 | elif corner == 'top_right': 22 | self.target = [offset_xy,-offset_xy,offset_z] 23 | elif corner == 'bottom_left': 24 | self.target = [-offset_xy,offset_xy,offset_z] 25 | elif corner == 'bottom_right': 26 | self.target = [-offset_xy,-offset_xy,offset_z] 27 | self.eval_env = eval_env 28 | print("low: ", self.act_space.low) 29 | print("high: ", self.act_space.high) 30 | 31 | def __call__(self, obs, state=None, mode=None): 32 | print(obs['state']) 33 | action = np.zeros(self.act_space.shape)[None] 34 | action[:,3] = 1. 
# Keep gripper open 35 | 36 | container_pos = self.eval_env.access("container_pos")() 37 | print("container_pos: ", container_pos) 38 | 39 | distance = container_pos + self.target - obs['state'][0][1:4] 40 | print("distance: ", distance) 41 | action[:,:3] = (distance - self.act_space.low[:3])/(self.act_space.high[:3] - self.act_space.low[:3]) * 2 - 1 42 | action = np.clip(action, -np.ones_like(self.act_space.low), np.ones_like(self.act_space.high)) 43 | 44 | output = {"action": action} 45 | 46 | print("action: ", action) 47 | 48 | return output, None 49 | 50 | class RandomAgent: 51 | def __init__(self, act_space, logprob=False): 52 | self.act_space = act_space["action"] 53 | self.logprob = logprob 54 | if hasattr(self.act_space, "n"): 55 | self._dist = dists.OneHotDist(tf.zeros(self.act_space.n)) 56 | else: 57 | dist = tfd.Uniform(self.act_space.low, self.act_space.high) 58 | self._dist = tfd.Independent(dist, 1) 59 | 60 | def __call__(self, obs, state=None, mode=None): 61 | action = self._dist.sample(len(obs["is_first"])) 62 | output = {"action": action} 63 | if self.logprob: 64 | output["logprob"] = self._dist.log_prob(action) 65 | return output, None 66 | 67 | 68 | def static_scan(fn, inputs, start, reverse=False): 69 | last = start 70 | outputs = [[] for _ in tf.nest.flatten(start)] 71 | indices = range(tf.nest.flatten(inputs)[0].shape[0]) 72 | if reverse: 73 | indices = reversed(indices) 74 | for index in indices: 75 | inp = tf.nest.map_structure(lambda x: x[index], inputs) 76 | last = fn(last, inp) 77 | [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] 78 | if reverse: 79 | outputs = [list(reversed(x)) for x in outputs] 80 | outputs = [tf.stack(x, 0) for x in outputs] 81 | return tf.nest.pack_sequence_as(start, outputs) 82 | 83 | 84 | def schedule(string, step): 85 | try: 86 | return float(string) 87 | except ValueError: 88 | step = tf.cast(step, tf.float32) 89 | match = re.match(r"linear\((.+),(.+),(.+)\)", string) 90 | if match: 91 | initial, final, duration = [float(group) for group in match.groups()] 92 | mix = tf.clip_by_value(step / duration, 0, 1) 93 | return (1 - mix) * initial + mix * final 94 | match = re.match(r"warmup\((.+),(.+)\)", string) 95 | if match: 96 | warmup, value = [float(group) for group in match.groups()] 97 | scale = tf.clip_by_value(step / warmup, 0, 1) 98 | return scale * value 99 | match = re.match(r"exp\((.+),(.+),(.+)\)", string) 100 | if match: 101 | initial, final, halflife = [float(group) for group in match.groups()] 102 | return (initial - final) * 0.5 ** (step / halflife) + final 103 | match = re.match(r"horizon\((.+),(.+),(.+)\)", string) 104 | if match: 105 | initial, final, duration = [float(group) for group in match.groups()] 106 | mix = tf.clip_by_value(step / duration, 0, 1) 107 | horizon = (1 - mix) * initial + mix * final 108 | return 1 - 1 / horizon 109 | raise NotImplementedError(string) 110 | 111 | 112 | def lambda_return(reward, value, pcont, bootstrap, lambda_, axis): 113 | # Setting lambda=1 gives a discounted Monte Carlo return. 114 | # Setting lambda=0 gives a fixed 1-step return. 
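# Note: the scan below evaluates the TD(lambda) recursion backwards in time,
#   R_t = r_t + pcont_t * ((1 - lambda_) * V_{t+1} + lambda_ * R_{t+1}),
# seeded with `bootstrap` as R_{T+1}. Worked example: rewards [1, 1], values [0, 0],
# pcont=0.9, lambda_=0.95, bootstrap=0 gives R_1 = 1 and R_0 = 1 + 0.9 * 0.95 * 1 = 1.855.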
115 | assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) 116 | if isinstance(pcont, (int, float)): 117 | pcont = pcont * tf.ones_like(reward) 118 | dims = list(range(reward.shape.ndims)) 119 | dims = [axis] + dims[1:axis] + [0] + dims[axis + 1 :] 120 | if axis != 0: 121 | reward = tf.transpose(reward, dims) 122 | value = tf.transpose(value, dims) 123 | pcont = tf.transpose(pcont, dims) 124 | if bootstrap is None: 125 | bootstrap = tf.zeros_like(value[-1]) 126 | next_values = tf.concat([value[1:], bootstrap[None]], 0) 127 | inputs = reward + pcont * next_values * (1 - lambda_) 128 | returns = static_scan( 129 | lambda agg, cur: cur[0] + cur[1] * lambda_ * agg, 130 | (inputs, pcont), 131 | bootstrap, 132 | reverse=True, 133 | ) 134 | if axis != 0: 135 | returns = tf.transpose(returns, dims) 136 | return returns 137 | 138 | 139 | def action_noise(action, amount, act_space): 140 | if amount == 0: 141 | return action 142 | amount = tf.cast(amount, action.dtype) 143 | if hasattr(act_space, "n"): 144 | probs = amount / action.shape[-1] + (1 - amount) * action 145 | return dists.OneHotDist(probs=probs).sample() 146 | else: 147 | return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) 148 | 149 | 150 | class StreamNorm(tfutils.Module): 151 | def __init__(self, shape=(), momentum=0.99, scale=1.0, eps=1e-8): 152 | # Momentum of 0 normalizes only based on the current batch. 153 | # Momentum of 1 disables normalization. 154 | self._shape = tuple(shape) 155 | self._momentum = momentum 156 | self._scale = scale 157 | self._eps = eps 158 | self.mag = tf.Variable(tf.ones(shape, tf.float64), False) 159 | 160 | def __call__(self, inputs): 161 | metrics = {} 162 | self.update(inputs) 163 | metrics["mean"] = inputs.mean() 164 | metrics["std"] = inputs.std() 165 | outputs = self.transform(inputs) 166 | metrics["normed_mean"] = outputs.mean() 167 | metrics["normed_std"] = outputs.std() 168 | return outputs, metrics 169 | 170 | def reset(self): 171 | self.mag.assign(tf.ones_like(self.mag)) 172 | 173 | def update(self, inputs): 174 | batch = inputs.reshape((-1,) + self._shape) 175 | mag = tf.abs(batch).mean(0).astype(tf.float64) 176 | self.mag.assign(self._momentum * self.mag + (1 - self._momentum) * mag) 177 | 178 | def transform(self, inputs): 179 | values = inputs.reshape((-1,) + self._shape) 180 | values /= self.mag.astype(inputs.dtype)[None] + self._eps 181 | values *= self._scale 182 | return values.reshape(inputs.shape) 183 | 184 | 185 | class Timer: 186 | def __init__(self): 187 | self._indurs = collections.defaultdict(list) 188 | self._outdurs = collections.defaultdict(list) 189 | self._start_times = {} 190 | self._end_times = {} 191 | 192 | @contextlib.contextmanager 193 | def section(self, name): 194 | self.start(name) 195 | yield 196 | self.end(name) 197 | 198 | def wrap(self, function, name): 199 | def wrapped(*args, **kwargs): 200 | with self.section(name): 201 | return function(*args, **kwargs) 202 | 203 | return wrapped 204 | 205 | def start(self, name): 206 | now = time.time() 207 | self._start_times[name] = now 208 | if name in self._end_times: 209 | last = self._end_times[name] 210 | self._outdurs[name].append(now - last) 211 | 212 | def end(self, name): 213 | now = time.time() 214 | self._end_times[name] = now 215 | self._indurs[name].append(now - self._start_times[name]) 216 | 217 | def result(self): 218 | metrics = {} 219 | for key in self._indurs: 220 | indurs = self._indurs[key] 221 | outdurs = self._outdurs[key] 222 | metrics[f"timer_count_{key}"] = 
len(indurs) 223 | metrics[f"timer_inside_{key}"] = np.sum(indurs) 224 | metrics[f"timer_outside_{key}"] = np.sum(outdurs) 225 | indurs.clear() 226 | outdurs.clear() 227 | return metrics 228 | 229 | 230 | class CarryOverState: 231 | def __init__(self, fn): 232 | self._fn = fn 233 | self._state = None 234 | 235 | def __call__(self, *args): 236 | self._state, out = self._fn(*args, self._state) 237 | return out 238 | -------------------------------------------------------------------------------- /common/replay.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import datetime 3 | import io 4 | import pathlib 5 | import uuid 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | import albumentations as A 11 | 12 | 13 | class Replay: 14 | def __init__( 15 | self, 16 | directory, 17 | load_directory=None, 18 | capacity=0, 19 | minlen=1, 20 | maxlen=0, 21 | prioritize_ends=False, 22 | ): 23 | self._directory = pathlib.Path(directory).expanduser() 24 | self._directory.mkdir(parents=True, exist_ok=True) 25 | self._capacity = capacity 26 | self._minlen = minlen 27 | self._maxlen = maxlen 28 | self._prioritize_ends = prioritize_ends 29 | self._random = np.random.RandomState() 30 | 31 | self.load_directory = load_directory 32 | if load_directory is None: 33 | load_directory = self._directory 34 | else: 35 | load_directory = pathlib.Path(load_directory).expanduser() 36 | # filename -> key -> value_sequence 37 | self._complete_eps = load_episodes(load_directory, capacity, minlen) 38 | if len(self._complete_eps) != 0: 39 | self._eps_keys = list(self._complete_eps.keys()) 40 | idxs = np.sort( 41 | np.argsort(np.random.uniform(size=len(self._eps_keys)))[ 42 | : int(np.ceil(len(self._eps_keys) / 2)) 43 | ] 44 | ) 45 | else: 46 | self._eps_keys = [] 47 | self._eps_masks = dict() 48 | # worker -> key -> value_sequence 49 | self._ongoing_eps = collections.defaultdict( 50 | lambda: collections.defaultdict(list) 51 | ) 52 | self._total_episodes, self._total_steps = count_episodes(directory) 53 | self._loaded_episodes = len(self._complete_eps) 54 | self._loaded_steps = sum(eplen(x) for x in self._complete_eps.values()) 55 | self.reward_func = None 56 | 57 | @property 58 | def stats(self): 59 | return { 60 | "total_steps": self._total_steps, 61 | "total_episodes": self._total_episodes, 62 | "loaded_steps": self._loaded_steps, 63 | "loaded_episodes": self._loaded_episodes, 64 | } 65 | 66 | def reward_relabel(self, reward_func): 67 | self.reward_func = reward_func 68 | 69 | def add_step(self, transition, worker=0): 70 | episode = self._ongoing_eps[worker] 71 | for key, value in transition.items(): 72 | episode[key].append(value) 73 | if transition["is_last"]: 74 | self.add_episode(episode) 75 | episode.clear() 76 | 77 | def add_episode(self, episode): 78 | if 'lang_num' not in episode.keys() or len(episode['lang_num']) != len(episode['reward']): 79 | # running into errors when running pure RL, no multi-task learning 80 | # there's likely a better way to fix, but couldn't find it 81 | episode['lang_num'] = [0] * len(episode['reward']) 82 | length = eplen(episode) 83 | if length < self._minlen: 84 | print(f"Skipping short episode of length {length}.") 85 | return 86 | self._total_steps += length 87 | self._loaded_steps += length 88 | self._total_episodes += 1 89 | self._loaded_episodes += 1 90 | episode = {key: convert(value) for key, value in episode.items()} 91 | if self.reward_func is not None: 92 | vidlang_reward = self.reward_func(episode) 93 
| episode["reward"] = vidlang_reward 94 | filename = save_episode(self._directory, episode) 95 | self._complete_eps[str(filename)] = episode 96 | 97 | self._eps_keys.append(str(filename)) 98 | self._enforce_limit() 99 | 100 | def add_demo_episode(self, episode): 101 | success = np.sum(episode["success"]) >= 1.0 102 | if success: 103 | self.add_episode(episode) 104 | 105 | def dataset(self, batch, length): 106 | example = next(iter(self._generate_chunks(length))) 107 | dataset = tf.data.Dataset.from_generator( 108 | lambda: self._generate_chunks(length), 109 | {k: v.dtype for k, v in example.items()}, 110 | {k: v.shape for k, v in example.items()}, 111 | ) 112 | dataset = dataset.batch(batch, drop_remainder=True) 113 | dataset = dataset.prefetch(5) 114 | return dataset 115 | 116 | def _generate_chunks(self, length): 117 | sequence = self._sample_sequence() 118 | while True: 119 | chunk = collections.defaultdict(list) 120 | added = 0 121 | while added < length: 122 | needed = length - added 123 | adding = {k: v[:needed] for k, v in sequence.items()} 124 | sequence = {k: v[needed:] for k, v in sequence.items()} 125 | for key, value in adding.items(): 126 | chunk[key].append(value) 127 | added += len(adding["reward"]) 128 | if len(sequence["reward"]) < 1: 129 | sequence = self._sample_sequence() 130 | chunk = {k: np.concatenate(v) for k, v in chunk.items()} 131 | yield chunk 132 | 133 | def _sample_sequence(self): 134 | eps_keys = self._eps_keys 135 | L = len(eps_keys) 136 | i = np.random.randint(0, L) 137 | episode_key = eps_keys[i] 138 | episode = self._complete_eps[episode_key] 139 | 140 | total = len(episode["reward"]) 141 | length = total 142 | if self._maxlen: 143 | length = min(length, self._maxlen) 144 | # Randomize length to avoid all chunks ending at the same time in case the 145 | # episodes are all of the same length. 146 | length -= np.random.randint(self._minlen) 147 | length = max(self._minlen, length) 148 | upper = total - length + 1 149 | if self._prioritize_ends: 150 | upper += self._minlen 151 | index = min(self._random.randint(upper), total - length) 152 | sequence = { 153 | k: convert(v[index : index + length]) 154 | for k, v in episode.items() 155 | if not k.startswith("log_") 156 | } 157 | sequence["is_first"] = np.zeros(len(sequence["reward"]), bool) 158 | sequence["is_first"][0] = True 159 | if self._maxlen: 160 | assert self._minlen <= len(sequence["reward"]) <= self._maxlen 161 | return sequence 162 | 163 | def _enforce_limit(self): 164 | if not self._capacity: 165 | return 166 | while self._loaded_episodes > 1 and self._loaded_steps > self._capacity: 167 | # Relying on Python preserving the insertion order of dicts. 
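# Note: load_episodes() returns episodes keyed by timestamp-prefixed filenames in
# sorted order and add_episode() appends, so the first dict entry is always the
# oldest episode; eviction is therefore FIFO, and the `> 1` guard above always
# keeps at least one episode loaded.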
168 | oldest, episode = next(iter(self._complete_eps.items())) 169 | self._loaded_steps -= eplen(episode) 170 | self._loaded_episodes -= 1 171 | del self._complete_eps[oldest] 172 | 173 | 174 | def count_episodes(directory): 175 | filenames = list(directory.glob("*.npz")) 176 | num_episodes = len(filenames) 177 | num_steps = sum(int(str(n).split("-")[-1][:-4]) - 1 for n in filenames) 178 | return num_episodes, num_steps 179 | 180 | 181 | def save_episode(directory, episode): 182 | timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S") 183 | identifier = str(uuid.uuid4().hex) 184 | length = eplen(episode) 185 | filename = directory / f"{timestamp}-{identifier}-{length}.npz" 186 | with io.BytesIO() as f1: 187 | np.savez_compressed(f1, **episode) 188 | f1.seek(0) 189 | with filename.open("wb") as f2: 190 | f2.write(f1.read()) 191 | return filename 192 | 193 | 194 | def load_episodes(directory, capacity=None, minlen=1): 195 | # The returned dictionary from filenames to episodes is guaranteed to be in 196 | # temporally sorted order. 197 | filenames = sorted(directory.glob("*.npz")) 198 | if capacity: 199 | num_steps = 0 200 | num_episodes = 0 201 | for filename in reversed(filenames): 202 | length = int(str(filename).split("-")[-1][:-4]) 203 | num_steps += length 204 | num_episodes += 1 205 | if num_steps >= capacity: 206 | break 207 | filenames = filenames[-num_episodes:] 208 | episodes = {} 209 | for filename in filenames: 210 | try: 211 | with filename.open("rb") as f: 212 | episode = np.load(f) 213 | episode = {k: episode[k] for k in episode.keys()} 214 | except Exception as e: 215 | print(f"Could not load episode {str(filename)}: {e}") 216 | continue 217 | episodes[str(filename)] = episode 218 | return episodes 219 | 220 | 221 | def convert(value): 222 | value = np.array(value) 223 | if np.issubdtype(value.dtype, np.floating): 224 | return value.astype(np.float32) 225 | elif np.issubdtype(value.dtype, np.signedinteger): 226 | return value.astype(np.int32) 227 | elif np.issubdtype(value.dtype, np.uint8): 228 | return value.astype(np.uint8) 229 | return value 230 | 231 | 232 | def eplen(episode): 233 | return len(episode["reward"]) - 1 234 | -------------------------------------------------------------------------------- /common/nets.py: -------------------------------------------------------------------------------- 1 | import re 2 | import functools 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.keras import layers as tfkl 7 | from tensorflow.keras import initializers as tfki 8 | from tensorflow_probability import distributions as tfd 9 | from tensorflow.keras.mixed_precision import experimental as prec 10 | 11 | import common 12 | 13 | 14 | class RSSM(common.Module): 15 | def __init__( 16 | self, 17 | action_free=False, 18 | stoch=30, 19 | deter=200, 20 | hidden=200, 21 | discrete=False, 22 | act="elu", 23 | norm="none", 24 | std_act="softplus", 25 | min_std=0.1, 26 | ): 27 | super().__init__() 28 | self._action_free = action_free 29 | self._stoch = stoch 30 | self._deter = deter 31 | self._hidden = hidden 32 | self._discrete = discrete 33 | self._act = get_act(act) 34 | self._norm = norm 35 | self._std_act = std_act 36 | self._min_std = min_std 37 | self._cell = GRUCell(self._deter, norm=True) 38 | self._cast = lambda x: tf.cast(x, prec.global_policy().compute_dtype) 39 | 40 | def initial(self, batch_size): 41 | dtype = prec.global_policy().compute_dtype 42 | if self._discrete: 43 | state = dict( 44 | logit=tf.zeros([batch_size, self._stoch, 
self._discrete], dtype), 45 | stoch=tf.zeros([batch_size, self._stoch, self._discrete], dtype), 46 | deter=self._cell.get_initial_state(None, batch_size, dtype), 47 | ) 48 | else: 49 | state = dict( 50 | mean=tf.zeros([batch_size, self._stoch], dtype), 51 | std=tf.zeros([batch_size, self._stoch], dtype), 52 | stoch=tf.zeros([batch_size, self._stoch], dtype), 53 | deter=self._cell.get_initial_state(None, batch_size, dtype), 54 | ) 55 | return state 56 | 57 | def fill_action_with_zero(self, action): 58 | # action: [B, action] 59 | B, D = action.shape[0], action.shape[1] 60 | if self._action_free: 61 | return self._cast(tf.zeros([B, 50])) 62 | else: 63 | zeros = self._cast(tf.zeros([B, 50 - D])) 64 | return tf.concat([action, zeros], axis=1) 65 | 66 | @tf.function 67 | def observe(self, embed, action, is_first, state=None): 68 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape)))) 69 | if state is None: 70 | state = self.initial(tf.shape(action)[0]) 71 | post, prior = common.static_scan( 72 | lambda prev, inputs: self.obs_step(prev[0], *inputs), 73 | (swap(action), swap(embed), swap(is_first)), 74 | (state, state), 75 | ) 76 | post = {k: swap(v) for k, v in post.items()} 77 | prior = {k: swap(v) for k, v in prior.items()} 78 | return post, prior 79 | 80 | @tf.function 81 | def imagine(self, action, state=None): 82 | swap = lambda x: tf.transpose(x, [1, 0] + list(range(2, len(x.shape)))) 83 | if state is None: 84 | state = self.initial(tf.shape(action)[0]) 85 | assert isinstance(state, dict), state 86 | action = swap(action) 87 | prior = common.static_scan(self.img_step, action, state) 88 | prior = {k: swap(v) for k, v in prior.items()} 89 | return prior 90 | 91 | def get_feat(self, state): 92 | stoch = self._cast(state["stoch"]) 93 | if self._discrete: 94 | shape = stoch.shape[:-2] + [self._stoch * self._discrete] 95 | stoch = tf.reshape(stoch, shape) 96 | return tf.concat([stoch, state[f"deter"]], -1) 97 | 98 | def get_dist(self, state): 99 | if self._discrete: 100 | logit = state["logit"] 101 | logit = tf.cast(logit, tf.float32) 102 | dist = tfd.Independent(common.OneHotDist(logit), 1) 103 | else: 104 | mean, std = state["mean"], state["std"] 105 | mean = tf.cast(mean, tf.float32) 106 | std = tf.cast(std, tf.float32) 107 | dist = tfd.MultivariateNormalDiag(mean, std) 108 | return dist 109 | 110 | @tf.function 111 | def obs_step(self, prev_state, prev_action, embed, is_first, sample=True): 112 | # if is_first.any(): 113 | prev_state, prev_action = tf.nest.map_structure( 114 | lambda x: tf.einsum("b,b...->b...", 1.0 - is_first.astype(x.dtype), x), 115 | (prev_state, prev_action), 116 | ) 117 | prior = self.img_step(prev_state, prev_action, sample) 118 | x = tf.concat([prior[f"deter"], embed], -1) 119 | x = self.get("obs_out", tfkl.Dense, self._hidden)(x) 120 | x = self.get("obs_out_norm", NormLayer, self._norm)(x) 121 | x = self._act(x) 122 | stats = self._suff_stats_layer("obs_dist", x) 123 | dist = self.get_dist(stats) 124 | stoch = dist.sample() if sample else dist.mode() 125 | post = {"stoch": stoch, "deter": prior[f"deter"], **stats} 126 | return post, prior 127 | 128 | @tf.function 129 | def img_step(self, prev_state, prev_action, sample=True): 130 | prev_stoch = self._cast(prev_state["stoch"]) 131 | prev_action = self._cast(prev_action) 132 | if self._discrete: 133 | shape = prev_stoch.shape[:-2] + [self._stoch * self._discrete] 134 | prev_stoch = tf.reshape(prev_stoch, shape) 135 | x = tf.concat([prev_stoch, self.fill_action_with_zero(prev_action)], -1) 136 | x = 
self.get("img_in", tfkl.Dense, self._hidden)(x) 137 | x = self.get("img_in_norm", NormLayer, self._norm)(x) 138 | x, deter = self._cell(x, [prev_state[f"deter"]]) 139 | deter = deter[0] 140 | x = self.get("img_out", tfkl.Dense, self._hidden)(x) 141 | x = self.get("img_out_norm", NormLayer, self._norm)(x) 142 | x = self._act(x) 143 | stats = self._suff_stats_layer(f"img_dist", x) 144 | dist = self.get_dist(stats) 145 | stoch = dist.sample() if sample else dist.mode() 146 | prior = {"stoch": stoch, "deter": deter, **stats} 147 | return prior 148 | 149 | def _suff_stats_layer(self, name, x): 150 | if self._discrete: 151 | x = self.get(name, tfkl.Dense, self._stoch * self._discrete, None)(x) 152 | logit = tf.reshape(x, x.shape[:-1] + [self._stoch, self._discrete]) 153 | return {"logit": logit} 154 | else: 155 | x = self.get(name, tfkl.Dense, 2 * self._stoch, None)(x) 156 | mean, std = tf.split(x, 2, -1) 157 | std = { 158 | "softplus": lambda: tf.nn.softplus(std), 159 | "sigmoid": lambda: tf.nn.sigmoid(std), 160 | "sigmoid2": lambda: 2 * tf.nn.sigmoid(std / 2), 161 | }[self._std_act]() 162 | std = std + self._min_std 163 | return {"mean": mean, "std": std} 164 | 165 | def kl_loss(self, post, prior, balance=0.8): 166 | post_const = tf.nest.map_structure(tf.stop_gradient, post) 167 | prior_const = tf.nest.map_structure(tf.stop_gradient, prior) 168 | lhs = tfd.kl_divergence(self.get_dist(post_const), self.get_dist(prior)) 169 | rhs = tfd.kl_divergence(self.get_dist(post), self.get_dist(prior_const)) 170 | return balance * lhs + (1 - balance) * rhs 171 | 172 | 173 | class MLP(common.Module): 174 | def __init__( 175 | self, shape, linear_probe=False, layers=[512, 512, 512, 512], act="elu", norm="none", **out 176 | ): 177 | self._shape = (shape,) if isinstance(shape, int) else shape 178 | self._layers = layers 179 | self._norm = norm 180 | self._act = get_act(act) 181 | self._out = out 182 | self._linear_probe = linear_probe 183 | 184 | def __call__(self, features): 185 | x = tf.cast(features, prec.global_policy().compute_dtype) 186 | x = x.reshape([-1, x.shape[-1]]) 187 | for index, unit in enumerate(self._layers): 188 | x = self.get(f"dense{index}", tfkl.Dense, unit, trainable=(not self._linear_probe))(x) 189 | x = self.get(f"norm{index}", NormLayer, self._norm, trainable=(not self._linear_probe))(x) 190 | x = self._act(x) 191 | if self._linear_probe: 192 | x = x.reshape(features.shape[:-1] + [x.shape[-1]]) 193 | return self.get("out_probe", DistLayer, self._shape, **self._out)(x) 194 | x = x.reshape(features.shape[:-1] + [x.shape[-1]]) 195 | return self.get("out", DistLayer, self._shape, **self._out)(x) 196 | 197 | 198 | class GRUCell(tf.keras.layers.AbstractRNNCell): 199 | def __init__(self, size, norm=True, act="tanh", update_bias=-1, **kwargs): 200 | super().__init__() 201 | self._size = size 202 | self._act = get_act(act) 203 | self._update_bias = update_bias 204 | self._layer = tfkl.Dense(3 * size, **kwargs) 205 | if norm: 206 | self._norm = NormLayer("layer") 207 | else: 208 | self._norm = NormLayer("none") 209 | 210 | @property 211 | def state_size(self): 212 | return self._size 213 | 214 | @tf.function 215 | def call(self, inputs, state): 216 | state = state[0] # Keras wraps the state in a list. 
217 | parts = self._layer(tf.concat([inputs, state], -1)) 218 | parts = self._norm(parts) 219 | reset, cand, update = tf.split(parts, 3, -1) 220 | reset = tf.nn.sigmoid(reset) 221 | cand = self._act(reset * cand) 222 | update = tf.nn.sigmoid(update + self._update_bias) 223 | output = update * cand + (1 - update) * state 224 | return output, [output] 225 | 226 | 227 | class DistLayer(common.Module): 228 | def __init__( 229 | self, 230 | shape, 231 | dist="mse", 232 | outscale=0.1, 233 | min_std=0.1, 234 | max_std=1.0, 235 | ): 236 | self._shape = shape 237 | self._dist = dist 238 | self._min_std = min_std 239 | self._max_std = max_std 240 | self._outscale = outscale 241 | 242 | def __call__(self, inputs): 243 | kw = {} 244 | if self._outscale == 0.0: 245 | kw["kernel_initializer"] = tfki.Zeros() 246 | else: 247 | kw["kernel_initializer"] = tfki.VarianceScaling( 248 | self._outscale, "fan_avg", "uniform" 249 | ) 250 | out = self.get("out", tfkl.Dense, np.prod(self._shape), **kw)(inputs) 251 | out = tf.reshape(out, tf.concat([tf.shape(inputs)[:-1], self._shape], 0)) 252 | out = tf.cast(out, tf.float32) 253 | if self._dist in ("normal", "trunc_normal"): 254 | std = self.get("std", tfkl.Dense, np.prod(self._shape))(inputs) 255 | std = tf.reshape(std, tf.concat([tf.shape(inputs)[:-1], self._shape], 0)) 256 | std = tf.cast(std, tf.float32) 257 | if self._dist == "mse": 258 | return common.MSEDist(out, len(self._shape), "sum") 259 | if self._dist == "symlog": 260 | return common.SymlogDist(out, len(self._shape), "sum") 261 | if self._dist == "nmse": 262 | return common.NormalizedMSEDist(out, len(self._shape), "sum") 263 | if self._dist == "normal": 264 | lo, hi = self._min_std, self._max_std 265 | std = (hi - lo) * tf.nn.sigmoid(std) + lo 266 | dist = tfd.Normal(tf.tanh(out), std) 267 | dist = tfd.Independent(dist, len(self._shape)) 268 | dist.minent = np.prod(self._shape) * tfd.Normal(0.0, lo).entropy() 269 | dist.maxent = np.prod(self._shape) * tfd.Normal(0.0, hi).entropy() 270 | return dist 271 | if self._dist == "binary": 272 | dist = tfd.Bernoulli(out) 273 | return tfd.Independent(dist, len(self._shape)) 274 | if self._dist == "trunc_normal": 275 | lo, hi = self._min_std, self._max_std 276 | std = (hi - lo) * tf.nn.sigmoid(std) + lo 277 | dist = tfd.TruncatedNormal(tf.tanh(out), std, -1, 1) 278 | dist = tfd.Independent(dist, 1) 279 | dist.minent = np.prod(self._shape) * tfd.Normal(0.99, lo).entropy() 280 | dist.maxent = np.prod(self._shape) * tfd.Normal(0.0, hi).entropy() 281 | return dist 282 | if self._dist == "onehot": 283 | dist = common.OneHotDist(out) 284 | if len(self._shape) > 1: 285 | dist = tfd.Independent(dist, len(self._shape) - 1) 286 | dist.minent = 0.0 287 | dist.maxent = np.prod(self._shape[:-1]) * np.log(self._shape[-1]) 288 | return dist 289 | raise NotImplementedError(self._dist) 290 | 291 | 292 | class NormLayer(common.Module, tf.keras.layers.Layer): 293 | def __init__(self, impl, trainable=True): 294 | super().__init__() 295 | self._impl = impl 296 | self._trainable = trainable 297 | 298 | def build(self, input_shape): 299 | if self._impl == "keras": 300 | self.layer = tfkl.LayerNormalization(trainable=self._trainable) 301 | self.layer.build(input_shape) 302 | elif self._impl == "layer": 303 | self.scale = self.add_weight("scale", input_shape[-1], tf.float32, "Ones", trainable=self._trainable) 304 | self.offset = self.add_weight( 305 | "offset", input_shape[-1], tf.float32, "Zeros", trainable=self._trainable 306 | ) 307 | 308 | def call(self, x): 309 | if self._impl == 
"none": 310 | return x 311 | elif self._impl == "keras": 312 | return self.layer(x) 313 | elif self._impl == "layer": 314 | mean, var = tf.nn.moments(x, -1, keepdims=True) 315 | return tf.nn.batch_normalization( 316 | x, mean, var, self.offset, self.scale, 1e-3 317 | ) 318 | else: 319 | raise NotImplementedError(self._impl) 320 | 321 | 322 | class MLPEncoder(common.Module): 323 | def __init__( 324 | self, act="elu", norm="none", layers=[512, 512, 512, 512], batchnorm=False 325 | ): 326 | self._act = get_act(act) 327 | self._layers = layers 328 | self._norm = norm 329 | self._batchnorm = batchnorm 330 | 331 | @tf.function 332 | def __call__(self, x, training=False): 333 | x = x.astype(prec.global_policy().compute_dtype) 334 | if self._batchnorm: 335 | x = self.get(f"batchnorm", tfkl.BatchNormalization)(x, training=training) 336 | for i, unit in enumerate(self._layers): 337 | x = self.get(f"dense{i}", tfkl.Dense, unit)(x) 338 | x = self.get(f"densenorm{i}", NormLayer, self._norm)(x) 339 | x = self._act(x) 340 | return x 341 | 342 | 343 | class CNNEncoder(common.Module): 344 | def __init__( 345 | self, 346 | cnn_depth=64, 347 | cnn_kernels=(4, 4), 348 | act="elu", 349 | ): 350 | self._act = get_act(act) 351 | self._cnn_depth = cnn_depth 352 | self._cnn_kernels = cnn_kernels 353 | 354 | @tf.function 355 | def __call__(self, x): 356 | x = x.astype(prec.global_policy().compute_dtype) 357 | for i, kernel in enumerate(self._cnn_kernels): 358 | depth = 2 ** i * self._cnn_depth 359 | x = self.get(f"conv{i}", tfkl.Conv2D, depth, kernel, 1)(x) 360 | x = self._act(x) 361 | return x 362 | 363 | 364 | class CNNDecoder(common.Module): 365 | def __init__( 366 | self, 367 | out_dim, 368 | cnn_depth=64, 369 | cnn_kernels=(4, 5), 370 | act="elu", 371 | ): 372 | self._out_dim = out_dim 373 | self._act = get_act(act) 374 | self._cnn_depth = cnn_depth 375 | self._cnn_kernels = cnn_kernels 376 | 377 | @tf.function 378 | def __call__(self, x): 379 | x = x.astype(prec.global_policy().compute_dtype) 380 | 381 | x = self.get("convin", tfkl.Dense, 2 * 2 * 2 * self._cnn_depth)(x) 382 | x = tf.reshape(x, [-1, 1, 1, 8 * self._cnn_depth]) 383 | 384 | for i, kernel in enumerate(self._cnn_kernels): 385 | depth = 2 ** (len(self._cnn_kernels) - i - 1) * self._cnn_depth 386 | x = self.get(f"conv{i}", tfkl.Conv2DTranspose, depth, kernel, 1)(x) 387 | x = self._act(x) 388 | x = self.get("convout", tfkl.Dense, self._out_dim)(x) 389 | return x 390 | 391 | 392 | def get_act(name): 393 | if name == "none": 394 | return tf.identity 395 | if name == "mish": 396 | return lambda x: x * tf.math.tanh(tf.nn.softplus(x)) 397 | elif hasattr(tf.nn, name): 398 | return getattr(tf.nn, name) 399 | elif hasattr(tf, name): 400 | return getattr(tf, name) 401 | else: 402 | raise NotImplementedError(name) 403 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import functools 3 | import logging 4 | import os 5 | import pathlib 6 | import re 7 | import sys 8 | import warnings 9 | import torch 10 | import itertools 11 | from omegaconf import OmegaConf 12 | 13 | try: 14 | import rich.traceback 15 | 16 | rich.traceback.install() 17 | except ImportError: 18 | pass 19 | 20 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 21 | logging.getLogger().setLevel("ERROR") 22 | warnings.filterwarnings("ignore", ".*box bound precision lowered.*") 23 | 24 | sys.path.append(str(pathlib.Path(__file__).parent)) 25 | 
sys.path.append(str(pathlib.Path(__file__).parent.parent)) 26 | 27 | import numpy as np 28 | from keras import backend as K 29 | import ruamel.yaml as yaml 30 | 31 | import common 32 | 33 | def process_curriculum_str(lang_curriculum_str): 34 | cfg_lang_prompts = lang_curriculum_str.split("|") 35 | lang_curriculum = [] 36 | for i in range(len(cfg_lang_prompts)): 37 | lang_curriculum.append(cfg_lang_prompts[i].split(",")) 38 | 39 | for i, stage in enumerate(lang_curriculum): 40 | expanded = [] 41 | for entry in stage: 42 | # File-path entries expand to one instruction per line; fixed to avoid popping from the list while iterating it. 43 | expanded.extend(open(entry).read().splitlines() if os.path.exists(entry) else [entry]) 44 | lang_curriculum[i] = expanded 45 | 46 | lang_instructions = list(itertools.chain(*lang_curriculum)) 47 | lang_instructions = list(set(lang_instructions)) 48 | return lang_curriculum, lang_instructions 49 | 50 | def get_lang_info(multi_task_vidlang, task, lang_prompt, synonym_folder, objects): 51 | """ 52 | Returns 53 | final_lang_instructions: list of all possible processed language instructions 54 | lang_to_num: dict mapping language instruction (from above) to number 55 | lang_curriculum: list of non-processed instructions for each stage of curriculum 56 | synonym_dict: dict mapping object to object synonyms 57 | """ 58 | if multi_task_vidlang: 59 | lang_instructions = task.split(',') 60 | return lang_instructions, None, None, None # four values, matching the unpacking at the call sites 61 | else: 62 | lang_curriculum, lang_instructions = process_curriculum_str(lang_prompt) 63 | 64 | _, all_objs = process_curriculum_str(objects) 65 | noun_variations = [] 66 | synonym_dict = {} 67 | if synonym_folder is None: 68 | for obj in all_objs: 69 | synonym_dict[obj] = [obj] # keep values as lists, matching the branch below 70 | noun_variations = all_objs 71 | else: 72 | for obj in all_objs: 73 | with open(os.path.join(synonym_folder, f"synonym_{obj}.txt"), 'r') as f: 74 | synonyms = f.read().splitlines() 75 | synonym_dict[obj] = synonyms 76 | noun_variations.extend(synonyms) 77 | noun_variations = list(set(noun_variations)) 78 | 79 | final_lang_instructions = [] 80 | for lang_instr in lang_instructions: 81 | if "[NOUN]" in lang_instr: 82 | for noun in noun_variations: 83 | final_lang_instructions.append(lang_instr.replace("[NOUN]", noun)) 84 | else: 85 | final_lang_instructions.append(lang_instr) 86 | 87 | lang_nums = range(len(final_lang_instructions)) 88 | lang_to_num = dict(zip(final_lang_instructions, lang_nums)) 89 | return final_lang_instructions, lang_to_num, lang_curriculum, synonym_dict 90 | 91 | def main(): 92 | 93 | configs = yaml.safe_load( 94 | (pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text() 95 | ) 96 | parsed, remaining = common.Flags(configs=["defaults"]).parse(known_only=True) 97 | config = common.Config(configs["defaults"]) 98 | for name in parsed.configs: 99 | config = config.update(configs[name]) 100 | config = common.Flags(config).parse(remaining) 101 | 102 | logdir = pathlib.Path(config.logdir).expanduser() 103 | logdir.mkdir(parents=True, exist_ok=True) 104 | config.save(logdir / "config.yaml") 105 | 106 | print(config, "\n") 107 | print("Logdir", logdir) 108 | 109 | loaddir = pathlib.Path(config.loaddir).expanduser() 110 | print("Loaddir", loaddir) 111 | 112 | import tensorflow as tf 113 | tf.config.experimental_run_functions_eagerly(not config.jit) 114 | message = "No GPU found. To actually train on CPU remove this assert." 
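# Note: training assumes a visible GPU (see the assert below); memory growth is
# then enabled per device so TensorFlow allocates GPU memory on demand instead of
# reserving it all up front, and mixed_float16 is only activated when
# config.precision == 16.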
115 | assert tf.config.experimental.list_physical_devices("GPU"), message 116 | print(tf.config.experimental.list_physical_devices("GPU")) 117 | for gpu in tf.config.experimental.list_physical_devices("GPU"): 118 | tf.config.experimental.set_memory_growth(gpu, True) 119 | assert config.precision in (16, 32), config.precision 120 | if config.precision == 16: 121 | from tensorflow.keras.mixed_precision import experimental as prec 122 | 123 | prec.set_policy(prec.Policy("mixed_float16")) 124 | 125 | train_replay = common.Replay(logdir / "train_episodes", **config.replay) 126 | eval_replay = common.Replay( 127 | logdir / "eval_episodes", 128 | **dict( 129 | capacity=config.replay.capacity // 10, 130 | minlen=config.replay.minlen, 131 | maxlen=config.replay.maxlen, 132 | ), 133 | ) 134 | step = common.Counter(train_replay.stats["total_steps"]) 135 | outputs = [ 136 | common.TerminalOutput(), 137 | common.JSONLOutput(logdir), 138 | common.TensorBoardOutput(logdir), 139 | ] 140 | logger = common.Logger(step, outputs, multiplier=config.action_repeat) 141 | metrics = collections.defaultdict(list) 142 | 143 | should_train = common.Every(config.train_every) 144 | should_train_mae = common.Every(config.train_mae_every) 145 | should_log = common.Every(config.log_every) 146 | 147 | lang_instructions, lang_to_num, lang_curriculum, synonym_dict = get_lang_info(config.multi_task_vidlang, config.task, config.curriculum.lang_prompt, config.curriculum.synonym_folder, config.curriculum.objects) 148 | 149 | if (loaddir / f"variables_{config.ts}.pkl").exists(): 150 | with open(loaddir / 'config.yaml', 'r') as f: 151 | pt_cfg = OmegaConf.create(yaml.safe_load(f)) 152 | pt_lang_instructions, _, _, _ = get_lang_info(pt_cfg.multi_task_vidlang, pt_cfg.task, pt_cfg.curriculum.lang_prompt, pt_cfg.curriculum.synonym_folder, pt_cfg.curriculum.objects) 153 | if (pt_cfg.plan2explore or pt_cfg.rnd): 154 | if (pt_cfg.use_r3m_reward or pt_cfg.use_internvideo_reward or pt_cfg.use_clip_reward): 155 | config = config.update({'mae.state_dim': config.mae.state_dim + 768}) 156 | config = config.update({'pretrain_mode': 'lang_emb'}) 157 | else: 158 | config = config.update({'pretrain_mode': 'no_lang'}) 159 | elif pt_cfg.use_lang_embeddings: 160 | config = config.update({'mae.state_dim': config.mae.state_dim + 768}) 161 | config = config.update({'pretrain_mode': 'lang_emb'}) 162 | else: 163 | config = config.update({'mae.state_dim': config.mae.state_dim + len(pt_lang_instructions)}) 164 | config = config.update({'pretrain_mode': 'one_hot'}) 165 | config = config.update({'num_langs': len(pt_lang_instructions)}) 166 | config = config.update({'train_mode': 'finetune'}) 167 | elif (config.plan2explore or config.rnd) and not (config.use_r3m_reward or config.use_internvideo_reward or config.use_clip_reward): 168 | config = config.update({'train_mode': 'pretrain'}) 169 | else: 170 | config = config.update({'pretrain_mode': None}) 171 | if config.use_lang_embeddings: 172 | config = config.update({'mae.state_dim': config.mae.state_dim + 768}) 173 | else: 174 | config = config.update({'mae.state_dim': config.mae.state_dim + len(lang_instructions)}) 175 | config = config.update({'num_langs': len(lang_instructions)}) 176 | config = config.update({'train_mode': 'pretrain'}) 177 | config.save(logdir / "config_updated.yaml") 178 | 179 | finetune_lang_encoding = None 180 | if config.train_mode == "finetune" and config.pretrain_mode == "lang_emb": 181 | import r3mreward as r3mreward 182 | from r3m import load_r3m 183 | finetune_instruction = 
config.task.replace("_", " ") 184 | model = load_r3m('resnet50').module.eval().to(config.vidlang_model_device) 185 | model = r3mreward.R3MReward(model, [finetune_instruction], config.standardize_rewards, 186 | config.queue_size, config.update_stats_steps, config.num_top_images, config.use_lang_embeddings) 187 | finetune_lang_encoding = model.get_lang_encoding([finetune_instruction]).cpu().numpy()[0] 188 | if config.tune_instruction: 189 | if os.path.exists(config.instructions_file): 190 | print("Loading instructions to tune with...") 191 | with open(config.instructions_file, 'r') as f: 192 | candidate_instructions = f.read().splitlines() 193 | candidate_instructions = candidate_instructions[:config.num_tune_instructions-10] 194 | print(f"Loaded {len(candidate_instructions)} instructions from {config.instructions_file}") 195 | else: 196 | print("Generating instructions to tune with...") 197 | pt_lang_instructions, _, lang_curriculum, _ = get_lang_info(pt_cfg.multi_task_vidlang, pt_cfg.task, pt_cfg.curriculum.lang_prompt, pt_cfg.curriculum.synonym_folder, pt_cfg.curriculum.objects) 198 | candidate_instructions = list(np.random.choice(pt_lang_instructions, config.num_tune_instructions-10, replace=False)) # list(), since ndarray has no extend() 199 | with open(f"prompts/tune_instructions_{config.num_tune_instructions}.txt", 'w') as f: 200 | for instr in candidate_instructions: 201 | f.write(instr + '\n') 202 | print(f"Generated {len(candidate_instructions)} instructions and saved to prompts/tune_instructions_{config.num_tune_instructions}.txt") 203 | with open(f"prompts/{config.task}.txt", 'r') as f: 204 | task_candidate_instructions = f.read().splitlines() 205 | candidate_instructions.extend(task_candidate_instructions) 206 | candidate_instructions_encodings = [] 207 | for instr in candidate_instructions: 208 | candidate_instructions_encodings.append(model.get_lang_encoding([instr]).cpu().numpy()[0]) 209 | candidate_instructions_scores = [0] * len(candidate_instructions) 210 | global score_idx 211 | score_idx = 0 212 | del model 213 | import gc; gc.collect() 214 | lang_instructions = [finetune_instruction] 215 | 216 | if not config.use_r3m_reward and config.train_mode == "pretrain" and config.use_lang_embeddings: 217 | import r3mreward as r3mreward 218 | from r3m import load_r3m 219 | sentences = [t.replace("_", " ") for t in lang_instructions] 220 | model = load_r3m('resnet50').module.eval().to(config.vidlang_model_device) 221 | model = r3mreward.R3MReward(model, sentences, config.standardize_rewards, 222 | config.queue_size, config.update_stats_steps, config.num_top_images, config.use_lang_embeddings) 223 | lang_encodings = model.get_lang_encoding(sentences).cpu().numpy() 224 | lang_to_encoding = dict(zip(lang_instructions, lang_encodings)) 225 | del model 226 | import gc; gc.collect() 227 | 228 | 229 | if config.use_r3m_reward: 230 | import r3mreward as r3mreward 231 | from r3m import load_r3m 232 | sentences = [t.replace("_", " ") for t in lang_instructions] 233 | 234 | model = load_r3m('resnet50').module.eval().to(config.vidlang_model_device) 235 | model = r3mreward.R3MReward(model, sentences, config.standardize_rewards, 236 | config.queue_size, config.update_stats_steps, config.num_top_images, config.use_lang_embeddings) 237 | if config.use_lang_embeddings: 238 | lang_encodings = model.get_lang_encoding(sentences).cpu().numpy() 239 | lang_to_encoding = dict(zip(lang_instructions, lang_encodings)) 240 | 241 | def get_r3m_reward(data, step=0): 242 | with torch.no_grad(): 243 | if config.camera_keys == "image_front|image_wrist" or 
config.camera_keys == "image_overhead|image_wrist": 244 | init_image = np.split(data['init_image'], 2, axis=-2)[0] 245 | image = np.split(data['image'], 2, axis=-2)[0] 246 | else: 247 | init_image = data['init_image'] 248 | image = data['image'] 249 | init_image = torch.from_numpy(init_image).to(config.vidlang_model_device) 250 | image = torch.from_numpy(image).to(config.vidlang_model_device) 251 | lang_num = torch.from_numpy(data['lang_num']).to(config.vidlang_model_device).unsqueeze(-1) 252 | if config.use_lang_embeddings: 253 | lang_embedding = torch.from_numpy(data['lang_embedding']).to(config.vidlang_model_device) 254 | else: 255 | lang_embedding = None 256 | init_image = init_image.permute((0, 3, 1, 2)) 257 | image = image.permute((0, 3, 1, 2)) 258 | reward, _, _ = model.get_reward(init_image, image, lang_num, lang_embedding, step) 259 | return reward.squeeze(-1).cpu().numpy() 260 | 261 | train_replay.reward_relabel(get_r3m_reward) 262 | elif config.use_internvideo_reward: 263 | sys.path.append('InternVideo/Pretrain/Multi-Modalities-Pretraining') 264 | import InternVideo 265 | from InternVideo import video_transform 266 | from torchvision import transforms 267 | print('Loading InternVideo model from path: {}...'.format(config.internvideo_load_dir)) 268 | model = InternVideo.load_model(config.internvideo_load_dir).cuda().to(config.vidlang_model_device) 269 | upsample = torch.nn.Upsample(size=(224,224), mode='bilinear', align_corners=False) 270 | input_mean = [0.48145466, 0.4578275, 0.40821073] 271 | input_std = [0.26862954, 0.26130258, 0.27577711] 272 | trans = transforms.Compose([ 273 | video_transform.ClipToTensor(channel_nb=3), 274 | video_transform.Normalize(mean=input_mean, std=input_std) 275 | ]) 276 | def get_internvideo_reward(data, step=0): 277 | with torch.no_grad(): 278 | if config.camera_keys == "image_front|image_wrist" or config.camera_keys == "image_overhead|image_wrist": 279 | image = np.split(data['image'], 2, axis=-2)[0] 280 | else: 281 | image = data['image'] 282 | videos = trans(image).to(config.vidlang_model_device) 283 | videos = upsample(videos) 284 | text_cand = [lang_instructions[lang_num] for lang_num in data["lang_num"]] 285 | text = InternVideo.tokenize(text_cand).cuda().to(config.vidlang_model_device) 286 | text_features = model.encode_text(text) 287 | 288 | reward = [] 289 | for i in range(videos.shape[1]): 290 | if i < 8: 291 | video = videos[:, :i+1, :, :] 292 | video = torch.cat([video, video[:, -1:, :, :].repeat(1, 8-(i+1), 1, 1)], dim=1).unsqueeze(0) 293 | else: 294 | indices = np.ceil(np.linspace(0, i, 8)).astype(int) 295 | video = videos[:, indices, :, :].unsqueeze(0) 296 | video_features = model.encode_video(video) 297 | video_features = torch.nn.functional.normalize(video_features, dim=1) 298 | text_features = torch.nn.functional.normalize(text_features, dim=1) 299 | t = model.logit_scale.exp() 300 | reward.append((video_features @ text_features[i]).cpu().numpy()) 301 | 302 | return np.array(reward).squeeze(-1) 303 | 304 | train_replay.reward_relabel(get_internvideo_reward) 305 | elif config.use_clip_reward: 306 | import clip 307 | sys.path.append('InternVideo/Pretrain/Multi-Modalities-Pretraining') 308 | import InternVideo 309 | from InternVideo import video_transform 310 | from torchvision import transforms 311 | model, _ = clip.load('ViT-B/32', config.vidlang_model_device) 312 | upsample = torch.nn.Upsample(size=(224,224), mode='bilinear', align_corners=False) 313 | input_mean = [0.48145466, 0.4578275, 0.40821073] 314 | input_std = 
[0.26862954, 0.26130258, 0.27577711] 315 | trans = transforms.Compose([ 316 | video_transform.ClipToTensor(channel_nb=3), 317 | video_transform.Normalize(mean=input_mean, std=input_std) 318 | ]) 319 | def get_clip_reward(data, step=0): 320 | with torch.no_grad(): 321 | if config.camera_keys == "image_front|image_wrist" or config.camera_keys == "image_overhead|image_wrist": 322 | image = np.split(data['image'], 2, axis=-2)[0] 323 | init_image = np.split(data['init_image'], 2, axis=-2)[0] 324 | else: 325 | image = data['image'] 326 | init_image = data['init_image'] 327 | init_videos = trans(init_image).to(config.vidlang_model_device) 328 | init_videos = upsample(init_videos) 329 | videos = trans(image).to(config.vidlang_model_device) 330 | videos = upsample(videos) 331 | text_inputs = torch.cat([clip.tokenize(lang_instructions[lang_num]) for lang_num in data["lang_num"]]).to(config.vidlang_model_device) 332 | init_image_inputs = init_videos.permute((1, 0, 2, 3)) 333 | image_inputs = videos.permute((1, 0, 2, 3)) 334 | init_image_features = model.encode_image(init_image_inputs) 335 | image_features = model.encode_image(image_inputs) 336 | init_text_features = model.encode_text(text_inputs) 337 | text_features = model.encode_text(text_inputs) 338 | init_image_features /= init_image_features.norm(dim=-1, keepdim=True) 339 | image_features /= image_features.norm(dim=-1, keepdim=True) 340 | init_text_features /= init_text_features.norm(dim=-1, keepdim=True) 341 | text_features /= text_features.norm(dim=-1, keepdim=True) 342 | delta_image_features = image_features - init_image_features 343 | delta_text_features = text_features - init_text_features 344 | reward = (delta_image_features * delta_text_features).sum(dim=-1).cpu().numpy() 345 | return reward 346 | train_replay.reward_relabel(get_clip_reward) 347 | else: 348 | print("Training from task reward...") 349 | 350 | if not config.use_lang_embeddings: 351 | lang_to_encoding = None 352 | def make_env(mode, actions_min_max=None): 353 | camera_keys = common.get_camera_keys(config.camera_keys) 354 | task = config.task.split(",") 355 | 356 | if config.franka_kitchen: 357 | env = common.KitchenEnv(task) 358 | elif config.multi_task_vidlang: 359 | env = common.MultiTaskVidLangRLBench(task, 360 | camera_keys, 361 | config.render_size, 362 | shaped_rewards=config.shaped_rewards, 363 | lang_to_num=lang_to_num, 364 | lang_to_encoding=lang_to_encoding, 365 | use_lang_embeddings=config.use_lang_embeddings, 366 | boundary_reward_penalty=config.boundary_reward_penalty, 367 | randomize=config.randomize, 368 | ) 369 | elif config.use_r3m_reward or config.use_internvideo_reward or config.use_clip_reward: 370 | env = common.VidLangRLBench(task[0], 371 | lang_instructions, 372 | camera_keys, 373 | config.render_size, 374 | shaped_rewards=config.shaped_rewards, 375 | lang_to_num=lang_to_num, 376 | lang_to_encoding=lang_to_encoding, 377 | use_lang_embeddings=config.use_lang_embeddings, 378 | boundary_reward_penalty=config.boundary_reward_penalty, 379 | curriculum=config.curriculum, 380 | lang_curriculum=lang_curriculum, 381 | synonym_dict=synonym_dict, 382 | randomize=config.randomize 383 | ) 384 | else: 385 | env = common.RLBench( 386 | lang_instructions, 387 | task[0], 388 | camera_keys, 389 | config.render_size, 390 | shaped_rewards=config.shaped_rewards, 391 | use_lang_embeddings=config.use_lang_embeddings, 392 | randomize=config.randomize, 393 | finetune_lang_encoding=finetune_lang_encoding 394 | ) 395 | if actions_min_max: 396 | 
env.register_min_max(actions_min_max) 397 | 398 | env = common.TimeLimit(env, config.time_limit) 399 | return env 400 | 401 | def per_episode(ep, mode, lang_instructions=None): 402 | length = len(ep["reward"]) - 1 403 | score = float(ep["reward"].astype(np.float64).sum()) 404 | success = float(np.sum(ep["success"]) >= 1.0) 405 | print( 406 | f"{mode.title()} episode has {float(success)} success, {length} steps and return {score:.1f}." 407 | ) 408 | logger.scalar(f"{mode}_success", float(success)) 409 | logger.scalar(f"{mode}_return", score) 410 | logger.scalar(f"{mode}_length", length) 411 | for key, value in ep.items(): 412 | if re.match(config.log_keys_sum, key): 413 | logger.scalar(f"sum_{mode}_{key}", ep[key].sum()) 414 | if re.match(config.log_keys_mean, key): 415 | logger.scalar(f"mean_{mode}_{key}", ep[key].mean()) 416 | if re.match(config.log_keys_max, key): 417 | logger.scalar(f"max_{mode}_{key}", ep[key].max(0).mean()) 418 | replay = dict(train=train_replay, eval=eval_replay)[mode] 419 | logger.add(replay.stats, prefix=mode) 420 | logger.write() 421 | 422 | print("Create envs.") 423 | num_eval_envs = min(config.envs, config.eval_eps) 424 | train_envs = [make_env("train") for _ in range(config.envs)] 425 | 426 | 427 | actions_min_max = None 428 | 429 | act_space = train_envs[0].act_space 430 | obs_space = train_envs[0].obs_space 431 | 432 | import agent as agent 433 | 434 | agnt = agent.Agent(config, obs_space, act_space, step) 435 | eval_policy = lambda *args: agnt.policy(*args, mode="eval") 436 | 437 | if config.tune_instruction: 438 | print("Tuning instruction...") 439 | def tune(ep, candidate_instructions_encodings=candidate_instructions_encodings, candidate_instructions_scores=candidate_instructions_scores): 440 | global score_idx 441 | candidate_instructions_scores[score_idx] = ep["reward"].sum() 442 | score_idx += 1 443 | 444 | tune_driver = common.Driver(train_envs) 445 | tune_driver.on_episode(lambda ep: tune(ep)) 446 | train_envs[0].change_lang_encoding(candidate_instructions_encodings[0]) 447 | tune_driver(eval_policy, episodes=config.num_tune_instructions) 448 | max_score_idx = np.argmax(np.array(candidate_instructions_scores)) 449 | print("Best instruction: {} with reward {}".format(candidate_instructions[max_score_idx], candidate_instructions_scores[max_score_idx])) 450 | train_envs[0].change_lang_encoding(candidate_instructions_encodings[max_score_idx]) 451 | finetune_lang_encoding = candidate_instructions_encodings[max_score_idx] 452 | 453 | make_async_env = lambda mode: common.Async( 454 | functools.partial(make_env, mode, actions_min_max), config.envs_parallel 455 | ) 456 | eval_envs = [make_async_env("eval") for _ in range(num_eval_envs)] 457 | 458 | print("Creating train and eval drivers.") 459 | train_driver = common.Driver(train_envs) 460 | train_driver.on_episode(lambda ep: per_episode(ep, mode="train")) 461 | train_driver.on_step(lambda tran, worker: step.increment()) 462 | train_driver.on_episode(train_replay.add_episode) 463 | eval_driver = common.Driver(eval_envs) 464 | eval_driver.on_episode(lambda ep: per_episode(ep, mode="eval", lang_instructions=lang_instructions)) 465 | eval_driver.on_episode(eval_replay.add_episode) 466 | 467 | prefill = max(0, config.prefill - train_replay.stats["total_steps"]) 468 | if prefill: 469 | print(f"Prefill dataset ({prefill} steps).") 470 | random_agent = common.RandomAgent(act_space) 471 | train_driver(random_agent, steps=prefill, episodes=1) 472 | eval_driver(random_agent, episodes=1) 473 | train_driver.reset() 474 | 
475 | 476 | print("Initialize agent.") 477 | train_dataset = iter(train_replay.dataset(**config.dataset)) 478 | mae_train_dataset = iter(train_replay.dataset(**config.mae_dataset)) 479 | report_dataset = iter(train_replay.dataset(**config.dataset)) 480 | 481 | if not config.use_imagenet_mae: 482 | train_mae = agnt.train_mae 483 | train_agent = common.CarryOverState(agnt.train) 484 | 485 | if not config.use_imagenet_mae: 486 | train_mae(next(mae_train_dataset)) 487 | train_agent(next(train_dataset)) 488 | 489 | if (loaddir / f"variables_{config.ts}.pkl").exists(): 490 | print("Loading agent.") 491 | try: 492 | agnt.load(loaddir / f"variables_{config.ts}.pkl") 493 | except Exception as e: 494 | raise RuntimeError(f"Error loading agent: {e}") from e 495 | else: 496 | assert config.loaddir == '' 497 | 498 | print("Pretrain agent.") 499 | for _ in range(0 if config.use_imagenet_mae else config.mae_pretrain):  # MAE pretraining only applies when the MAE is trained here 500 | data = next(mae_train_dataset) 501 | if config.use_zero_rewards: 502 | data['reward'] = data['reward'] * tf.cast(data['is_terminal'], tf.float32) 503 | train_mae(data) 504 | for _ in range(config.pretrain): 505 | data = next(train_dataset) 506 | if config.use_zero_rewards: 507 | data['reward'] = data['reward'] * tf.cast(data['is_terminal'], tf.float32) 508 | train_agent(data) 509 | 510 | train_policy = lambda *args: agnt.policy(*args, mode="train") 511 | 512 | def train_step(tran, worker): 513 | if not config.use_imagenet_mae: 514 | if should_train_mae(step): 515 | for _ in range(config.train_mae_steps): 516 | data = next(mae_train_dataset) 517 | if config.use_zero_rewards: 518 | data['reward'] = data['reward'] * tf.cast(data['is_terminal'], tf.float32) 519 | mets = train_mae(data) 520 | for key, value in mets.items(): metrics[key].append(value) 521 | if should_train(step): 522 | for _ in range(config.train_steps): 523 | data = next(train_dataset) 524 | if config.use_zero_rewards: 525 | data['reward'] = data['reward'] * tf.cast(data['is_terminal'], tf.float32) 526 | mets = train_agent(data) 527 | for key, value in mets.items(): metrics[key].append(value) 528 | if should_log(step): 529 | for name, values in metrics.items(): 530 | logger.scalar(name, np.array(values, np.float64).mean()) 531 | metrics[name].clear() 532 | logger.add( 533 | agnt.report(next(report_dataset)), 534 | prefix="train", 535 | ) 536 | logger.write(fps=True) 537 | 538 | train_driver.on_step(train_step) 539 | 540 | config.save(logdir / "config_updated.yaml") 541 | while step < config.steps: 542 | logger.write() 543 | print("Start evaluation.") 544 | eval_driver(eval_policy, episodes=config.eval_eps) 545 | print("Start training.") 546 | train_driver(train_policy, steps=config.eval_every) 547 | agnt.save(logdir / f"variables_{step.value}.pkl") 548 | for env in train_envs + eval_envs: 549 | try: 550 | env.close() 551 | except Exception: 552 | pass 553 | agnt.save(logdir / "variables_final.pkl") 554 | 555 | 556 | if __name__ == "__main__": 557 | main() 558 | -------------------------------------------------------------------------------- /common/envs.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import os 3 | import copy 4 | import random 5 | import sys 6 | import threading 7 | import traceback 8 | 9 | from pyrep.const import TextureMappingMode 10 | from pyrep.const import RenderMode 11 | 12 | import cloudpickle 13 | from functools import partial 14 | import gym 15 | import numpy as np 16 | from rlbench.utils import name_to_task_class 17 | 18 | from rlbench import RandomizeEvery 19 | from rlbench import VisualRandomizationConfig 20 | 21 | import time
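# The simulator-specific exception types below are imported defensively:
# they are only referenced at runtime inside RLBench.step(), so this module
# can still be imported on machines without CoppeliaSim/PyRep installed.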
22 | 23 | try: 24 | from pyrep.errors import ConfigurationPathError, IKError 25 | from rlbench.backend.exceptions import InvalidActionError 26 | except ImportError: 27 | pass 28 | 29 | 30 | class RLBench: 31 | def __init__( 32 | self, 33 | langs, 34 | name, 35 | camera_keys, 36 | size=(64, 64), 37 | actions_min_max=None, 38 | shaped_rewards=False, 39 | use_lang_embeddings=False, 40 | boundary_reward_penalty=False, 41 | randomize=False, 42 | finetune_lang_encoding=None, 43 | ): 44 | from rlbench.action_modes.action_mode import MoveArmThenGripper 45 | from rlbench.action_modes.arm_action_modes import ( 46 | EndEffectorPoseViaPlanning, 47 | ) 48 | from rlbench.action_modes.gripper_action_modes import ( 49 | Discrete, 50 | ) 51 | from rlbench.environment import Environment 52 | from rlbench.observation_config import ObservationConfig 53 | from rlbench.tasks import ( 54 | PhoneOnBase, 55 | PickAndLift, 56 | PickUpCup, 57 | PutRubbishInBin, 58 | TakeLidOffSaucepan, 59 | TakeUmbrellaOutOfUmbrellaStand, 60 | MultiTaskMicrofridgesauce, 61 | # MultiTaskBusplanesauce is imported lazily in the reach_for_bus / reach_for_plane branch below 62 | PickShapenetObjects 63 | ) 64 | 65 | # Configure which observations the simulator should return. 66 | obs_config = ObservationConfig() 67 | 68 | ## Camera setups 69 | obs_config.front_camera.set_all(False) 70 | obs_config.wrist_camera.set_all(False) 71 | obs_config.left_shoulder_camera.set_all(False) 72 | obs_config.right_shoulder_camera.set_all(False) 73 | obs_config.overhead_camera.set_all(False) 74 | 75 | if "image_front" in camera_keys: 76 | obs_config.front_camera.rgb = True 77 | obs_config.front_camera.image_size = size 78 | obs_config.front_camera.render_mode = RenderMode.OPENGL 79 | 80 | if "image_wrist" in camera_keys: 81 | obs_config.wrist_camera.rgb = True 82 | obs_config.wrist_camera.image_size = size 83 | obs_config.wrist_camera.render_mode = RenderMode.OPENGL 84 | 85 | if "image_overhead" in camera_keys: 86 | obs_config.overhead_camera.rgb = True 87 | obs_config.overhead_camera.image_size = size 88 | obs_config.overhead_camera.render_mode = RenderMode.OPENGL 89 | 90 | obs_config.joint_forces = False 91 | obs_config.joint_positions = True 92 | obs_config.joint_velocities = True 93 | obs_config.task_low_dim_state = True 94 | obs_config.gripper_touch_forces = False 95 | obs_config.gripper_pose = True 96 | obs_config.gripper_open = True 97 | obs_config.gripper_matrix = False 98 | obs_config.gripper_joint_positions = True 99 | 100 | if randomize: 101 | rand_config = [ 102 | VisualRandomizationConfig(image_directory='common/assets/textures/table', whitelist = ['diningTable_visible']), 103 | VisualRandomizationConfig(image_directory='common/assets/textures/wall', whitelist = ['Wall1', 'Wall2', 'Wall3', 'Wall4']), 104 | VisualRandomizationConfig(image_directory='common/assets/textures/floor', whitelist = ['Floor']) 105 | ] 106 | tex_kwargs = [ 107 | {'mapping_mode': TextureMappingMode.PLANE, 'repeat_along_u': False, 'repeat_along_v': False, 'uv_scaling': [1.6, 1.1]}, 108 | {'mapping_mode': TextureMappingMode.PLANE, 'repeat_along_u': False, 'repeat_along_v': False, 'uv_scaling': [5.0, 3.0]}, 109 | {'mapping_mode': TextureMappingMode.PLANE, 'repeat_along_u': False, 'repeat_along_v': False, 'uv_scaling': [5.0, 5.0]} 110 | ] 111 | randomized_every = RandomizeEvery.EPISODE 112 | else: 113 | rand_config = None 114 | randomized_every = None 115 | tex_kwargs = None 116 | 117 | env = Environment( 118 | action_mode=MoveArmThenGripper( 119 | arm_action_mode=EndEffectorPoseViaPlanning(False), 120 | gripper_action_mode=Discrete(), 121 | ), 122 | obs_config=obs_config, 123 | headless=True, 124 | shaped_rewards=shaped_rewards, 125 | randomize_every=randomized_every, 126 | visual_randomization_config=rand_config, 127 | tex_kwargs=tex_kwargs 128 | ) 129 | env.launch()
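# The composed action mode above defines the 8-D action that step() consumes:
#   [dx, dy, dz, qx, qy, qz, qw, gripper]
# a position delta, a quaternion (held at identity by unnormalize() below),
# and a discrete gripper command; see the shape assertion in unnormalize().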
130 | 131 | if name == "phone_on_base": 132 | task = PhoneOnBase 133 | elif name == "pick_and_lift": 134 | task = PickAndLift 135 | elif name == "pick_up_cup": 136 | task = PickUpCup 137 | elif name == "put_rubbish_in_bin": 138 | task = PutRubbishInBin 139 | elif name == "take_lid_off_saucepan": 140 | task = TakeLidOffSaucepan 141 | elif name == "take_umbrella_out_of_umbrella_stand": 142 | task = TakeUmbrellaOutOfUmbrellaStand 143 | elif name == "multi_task_microfridgesauce": 144 | task = MultiTaskMicrofridgesauce 145 | elif name == "pick_shapenet_objects": 146 | task = PickShapenetObjects 147 | elif name in ["reach_for_bus", "reach_for_plane"]: 148 | from rlbench.tasks import MultiTaskBusplanesauce; task = MultiTaskBusplanesauce  # lazy import: only rlbench forks that ship this task support these names 149 | else: 150 | task = name_to_task_class(name) 151 | self._env = env 152 | self._task = env.get_task(task) 153 | self.task_name = name 154 | 155 | if "pick_shapenet_objects" in name: 156 | try: 157 | n_obj = int(name.split("_")[-1]) 158 | self._task._task.set_num_objects(n_obj) 159 | except ValueError: 160 | self._task._task.set_num_objects(1) 161 | 162 | _, obs = self._task.reset() 163 | 164 | task_low_dim = obs.task_low_dim_state.shape[0] 165 | self._state_dim = obs.get_low_dim_data().shape[0] - 14 - task_low_dim 166 | self._prev_obs, self._prev_reward = None, None 167 | self._ep_success = None 168 | 169 | self._size = size 170 | self._shaped_rewards = shaped_rewards 171 | self._camera_keys = camera_keys 172 | self._use_lang_embeddings = use_lang_embeddings 173 | self.finetune_lang_encoding = finetune_lang_encoding 174 | self._boundary_reward_penalty = boundary_reward_penalty 175 | self.langs = langs 176 | 177 | if actions_min_max: 178 | self.register_min_max(actions_min_max) 179 | else: 180 | self.low = np.array([-0.03, -0.03, -0.03]) 181 | self.high = np.array([0.03, 0.03, 0.03]) 182 | 183 | 184 | self._name = name 185 | 186 | @property 187 | def container_pos(self): 188 | if "pick_shapenet_objects" in self._name: 189 | return self._task._task.large_container.get_position() 190 | return None 191 | 192 | @property 193 | def obs_space(self): 194 | spaces = { 195 | "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), 196 | "is_first": gym.spaces.Box(0, 1, (), dtype=bool), 197 | "is_last": gym.spaces.Box(0, 1, (), dtype=bool), 198 | "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool), 199 | "success": gym.spaces.Box(0, 1, (), dtype=bool), 200 | "state": gym.spaces.Box( 201 | -np.inf, np.inf, (self._state_dim,), dtype=np.float32 202 | ), 203 | "image": gym.spaces.Box( 204 | 0, 205 | 255, 206 | (self._size[0], self._size[1] * len(self._camera_keys), 3), 207 | dtype=np.uint8, 208 | ), 209 | "init_state": gym.spaces.Box( 210 | -np.inf, np.inf, (self._state_dim,), dtype=np.float32 211 | ), 212 | "init_image": gym.spaces.Box( 213 | 0, 214 | 255, 215 | (self._size[0], self._size[1] * len(self._camera_keys), 3), 216 | dtype=np.uint8, 217 | ), 218 | "lang_num": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.uint8), 219 | } 220 | if self._use_lang_embeddings: 221 | spaces["lang_embedding"] = gym.spaces.Box( 222 | -np.inf, np.inf, (768,), dtype=np.float32) 223 | return spaces 224 | 225 | def register_min_max(self, actions_min_max): 226 | self.low, self.high = actions_min_max 227 | 228 | @property 229 | def act_space(self): 230 | assert self.low is not
None 231 | if self.low.shape[0] == 3: 232 | self.low = np.hstack([self.low, [0.0]]) 233 | self.high = np.hstack([self.high, [1.0]]) 234 | action = gym.spaces.Box( 235 | low=self.low, high=self.high, shape=(self.low.shape[0],), dtype=np.float32 236 | ) 237 | return {"action": action} 238 | 239 | def unnormalize(self, a): 240 | # Un-normalize gripper pose normalized to [-1, 1] 241 | assert self.low is not None 242 | pose = a[:3] 243 | pose = (pose + 1) / 2 * (self.high[:3] - self.low[:3]) + self.low[:3] 244 | 245 | # Manual handling of overflow in z axis 246 | curr_pose = self._task._task.robot.arm.get_tip().get_pose()[:3] 247 | curr_z = curr_pose[2] 248 | init_z = self._init_pose[2] 249 | delta_z = pose[2] 250 | 251 | if curr_z + delta_z >= init_z: 252 | pose[2] = 0.0 253 | 254 | # Un-normalize gripper action normalized to [-1, 1] 255 | gripper = a[3:4] 256 | gripper = (gripper + 1) / 2 * (self.high[3:4] - self.low[3:4]) + self.low[3:4] 257 | 258 | target_pose = pose 259 | 260 | # Identity quaternion 261 | quat = np.array([0.0, 0.0, 0.0, 1.0]) 262 | 263 | action = np.hstack([target_pose, quat, gripper]) 264 | assert action.shape[0] == 8 265 | return action 266 | 267 | def step(self, action): 268 | assert np.isfinite(action["action"]).all(), action["action"] 269 | try: 270 | original_action = self.unnormalize(action["action"]) 271 | _obs, _reward, _ = self._task.step(original_action) 272 | terminal = False 273 | success, _ = self._task._task.success() 274 | if success: 275 | self._ep_success = True 276 | self._prev_obs, self._prev_reward = _obs, _reward 277 | if not self._shaped_rewards: 278 | reward = float(self._ep_success) 279 | else: 280 | reward = _reward 281 | except ConfigurationPathError: 282 | print("ConfigurationPathError") 283 | _obs = self._prev_obs 284 | terminal = False 285 | success = False 286 | if not self._shaped_rewards: 287 | reward = float(self._ep_success) 288 | else: 289 | reward = self._prev_reward 290 | except (IKError, InvalidActionError) as e: 291 | # print(e) 292 | _obs = self._prev_obs 293 | success = False 294 | if self._boundary_reward_penalty: 295 | terminal = True 296 | reward = -0.05 297 | else: 298 | terminal = False 299 | if not self._shaped_rewards: 300 | reward = float(self._ep_success) 301 | else: 302 | reward = self._prev_reward 303 | 304 | _obs.joint_velocities = None 305 | _obs.joint_positions = None 306 | _obs.task_low_dim_state = None 307 | 308 | obs = { 309 | "reward": reward, 310 | "is_first": False, 311 | "is_last": terminal, 312 | "is_terminal": terminal, 313 | "success": success, 314 | "state": _obs.get_low_dim_data(), 315 | 'lang_num': 0 316 | } 317 | images = [] 318 | for key in self._camera_keys: 319 | if key == "image_front": 320 | images.append(_obs.front_rgb) 321 | if key == "image_wrist": 322 | images.append(_obs.wrist_rgb) 323 | if key == "image_overhead": 324 | images.append(_obs.overhead_rgb) 325 | obs["image"] = np.concatenate(images, axis=-2) 326 | if self._use_lang_embeddings and self.finetune_lang_encoding is not None: 327 | obs['lang_embedding'] = self.finetune_lang_encoding 328 | self._time_step += 1 329 | return obs 330 | 331 | def reset(self): 332 | self.lang = random.choice(self.langs) 333 | self._task._task.change_reward(self.lang) 334 | _, _obs = self._task.reset() 335 | print(f"Reset in env {self.task_name}.") 336 | self._prev_obs = _obs 337 | self._init_pose = copy.deepcopy( 338 | self._task._task.robot.arm.get_tip().get_pose()[:3] 339 | ) 340 | self._time_step = 0 341 | self._ep_success = False 342 | 343 | 
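# The proprioceptive fields below are blanked so that get_low_dim_data()
# returns only the reduced state vector that self._state_dim was sized for
# in __init__ (full low-dim data minus 14 joint readings and the task state).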
_obs.joint_velocities = None 344 | _obs.joint_positions = None 345 | _obs.task_low_dim_state = None 346 | 347 | obs = { 348 | "reward": 0.0, 349 | "is_first": True, 350 | "is_last": False, 351 | "is_terminal": False, 352 | "success": False, 353 | "state": _obs.get_low_dim_data(), 354 | 'lang_num': 0 355 | } 356 | images = [] 357 | for key in self._camera_keys: 358 | if key == "image_front": 359 | images.append(_obs.front_rgb) 360 | if key == "image_wrist": 361 | images.append(_obs.wrist_rgb) 362 | if key == "image_overhead": 363 | images.append(_obs.overhead_rgb) 364 | obs["image"] = np.concatenate(images, axis=-2) 365 | if self._use_lang_embeddings and self.finetune_lang_encoding is not None: 366 | obs['lang_embedding'] = self.finetune_lang_encoding 367 | return obs 368 | 369 | def change_lang_encoding(self, lang_encoding): 370 | self.finetune_lang_encoding = lang_encoding 371 | 372 | class VidLangRLBench(RLBench): 373 | def __init__(self, name, langs, camera_keys, size=(64, 64), actions_min_max=None, 374 | shaped_rewards=False, lang_to_num=None, lang_to_encoding=None, 375 | use_lang_embeddings=False, boundary_reward_penalty=False, 376 | curriculum=None, lang_curriculum=None, synonym_dict=None, 377 | randomize=False): 378 | self.lang_to_num = lang_to_num 379 | self.lang_to_encoding = lang_to_encoding 380 | self.langs = langs 381 | self.init_langs = copy.deepcopy(langs) 382 | self.episode_count = 0  # counts reset() calls; drives the curriculum schedule 383 | self.use_curriculum = curriculum is not None and curriculum.use 384 | if curriculum is not None: 385 | self.curriculum = curriculum 386 | n_episodes = [int(x) for x in curriculum.num_episodes.split("|")] 387 | # Prefix sum over the per-stage episode budgets: stage i begins once 388 | # sum(n_episodes[:i]) episodes have completed. The final stage's budget 389 | # is never consulted because training runs until config.steps. 390 | self.curriculum_benchmarks = [0 for _ in range(len(n_episodes))] 391 | for i in range(len(n_episodes)): 392 | for j in range(i): 393 | self.curriculum_benchmarks[i] += n_episodes[j] 394 | self.lang_prompts = lang_curriculum 395 | self.synonym_dict = synonym_dict 396 | 397 | self.objects = curriculum.objects.split("|") 398 | for i in range(len(self.objects)): 399 | self.objects[i] = self.objects[i].split(",") 400 | self.num_objects = curriculum.num_objects.split("|") 401 | self.num_objects = [int(x) for x in self.num_objects] 402 | self.num_unique_per_class = curriculum.num_unique_per_class.split("|") 403 | self.num_unique_per_class = [int(x) for x in self.num_unique_per_class] 404 | 405 | super(VidLangRLBench, self).__init__(langs, name, camera_keys, size, actions_min_max, shaped_rewards, use_lang_embeddings, boundary_reward_penalty, randomize)
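# Worked example of the schedule built above: with curriculum.num_episodes
# set to "10|20|30", curriculum_benchmarks becomes [0, 10, 30], so stage 0
# starts immediately, stage 1 after 10 episodes, and stage 2 after 30; the
# last stage then runs until config.steps is exhausted.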
406 | 407 | def reset_curriculum(self): 408 | # todo: this only works for the shapenet env for now. 409 | # Advance the curriculum stage whenever episode_count hits one of the benchmarks. 410 | if self.episode_count in self.curriculum_benchmarks: 411 | index = self.curriculum_benchmarks.index(self.episode_count) 412 | self.init_langs, self.langs = self.lang_prompts[index].copy(), self.lang_prompts[index].copy() 413 | self._task._task.reset_samplers(self.objects[index], self.objects[index], self.num_unique_per_class[index], self.num_unique_per_class[index]) 414 | self._task._task.set_num_objects(self.num_objects[index]) 415 | print(f"[curriculum] lang prompts: {self.langs}") 416 | print(f"[curriculum] objects: {self.num_objects[index]} of {self.objects[index]}") 417 | 418 | def reset(self): 419 | if self.use_curriculum: 420 | self.reset_curriculum() 421 | if self.init_langs: 422 | self.lang = random.choice(self.init_langs) 423 | self.init_langs.remove(self.lang) 424 | else: 425 | self.lang = random.choice(self.langs) 426 | 427 | time_step = super(VidLangRLBench, self).reset() 428 | if "[NOUN]" in self.lang: 429 | curr_obj = self._task._task.bin_objects_meta[0] 430 | if isinstance(self.synonym_dict[curr_obj], list): 431 | synonym = random.choice(self.synonym_dict[curr_obj]) 432 | else: 433 | synonym = self.synonym_dict[curr_obj] 434 | self.lang = self.lang.replace("[NOUN]", synonym) 435 | self._task._task.change_reward(self.lang) 436 | 437 | lang_num = self.lang_to_num[self.lang] 438 | print(f"Collecting for {self.lang} language instruction.") 439 | print(f"Reward is for {self._task._task.reward_lang}.") 440 | vidlang_time_step = time_step.copy() 441 | vidlang_time_step['init_image'] = vidlang_time_step['image'] 442 | vidlang_time_step['init_state'] = vidlang_time_step['state'] 443 | vidlang_time_step['lang_num'] = lang_num 444 | if self._use_lang_embeddings: 445 | vidlang_time_step['lang_embedding'] = self.lang_to_encoding[self.lang] 446 | self._init_vidlang_time_step = vidlang_time_step 447 | self.episode_count += 1 448 | return vidlang_time_step 449 | 450 | def step(self, action): 451 | time_step = super(VidLangRLBench, self).step(action) 452 | lang_num = self.lang_to_num[self.lang] 453 | vidlang_time_step = time_step.copy() 454 | vidlang_time_step['init_image'] = self._init_vidlang_time_step['init_image'] 455 | vidlang_time_step['init_state'] = self._init_vidlang_time_step['init_state'] 456 | vidlang_time_step['lang_num'] = lang_num 457 | if self._use_lang_embeddings: 458 | vidlang_time_step['lang_embedding'] = self.lang_to_encoding[self.lang] 459 | return vidlang_time_step 460 | 461 | class MultiTaskRLBench(RLBench): 462 | def __init__(self, task_names, camera_keys, size=(64, 64), actions_min_max=None, shaped_rewards=False, lang_to_num=None, lang_to_encoding=None, use_lang_embeddings=False, boundary_reward_penalty=False): 463 | self.task_names = task_names 464 | self.init_tasks = self.task_names.copy() 465 | self.task_name = random.choice(self.init_tasks) 466 | 467 | super(MultiTaskRLBench, self).__init__(self.task_names, self.task_name, camera_keys, size, actions_min_max, shaped_rewards, use_lang_embeddings, boundary_reward_penalty)  # task names double as the langs RLBench expects 468 | self.name_to_class = {} 469 | for t in self.task_names: 470 | self.name_to_class[t] = name_to_task_class(t) 471 | 472 | def reset(self): 473 | if self.init_tasks: 474 | self.task_name = random.choice(self.init_tasks) 475 | self.init_tasks.remove(self.task_name) 476 | else: 477 | self.task_name = random.choice(self.task_names) 478 | self._task = self._env.get_task(self.name_to_class[self.task_name]) 479 | return super(MultiTaskRLBench, self).reset() 480 |
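# A minimal usage sketch of the single-task wrapper stack in this file
# (illustrative only; the task name and prompt are placeholders):
#
#   env = RLBench(langs=["Pick up the cup"], name="pick_up_cup",
#                 camera_keys=["image_front"], size=(64, 64))
#   env = TimeLimit(env, duration=100)
#   obs = env.reset()                        # obs["image"]: (64, 64, 3) uint8
#   act = env.act_space["action"].sample()
#   obs = env.step({"action": act})          # obs["is_last"] is True once truncated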
481 | class MultiTaskVidLangRLBench(MultiTaskRLBench): 482 | def __init__(self, task_names, camera_keys, size=(64, 64), actions_min_max=None, shaped_rewards=False, lang_to_num=None, lang_to_encoding=None, use_lang_embeddings=False, boundary_reward_penalty=False): 483 | self.lang_to_num = lang_to_num 484 | self.lang_to_encoding = lang_to_encoding 485 | super(MultiTaskVidLangRLBench, self).__init__(task_names, camera_keys, size, actions_min_max, shaped_rewards, use_lang_embeddings=use_lang_embeddings, boundary_reward_penalty=boundary_reward_penalty) 486 | 487 | def reset(self): 488 | time_step = super(MultiTaskVidLangRLBench, self).reset() 489 | lang_num = self.lang_to_num[self.task_name] 490 | vidlang_time_step = time_step.copy() 491 | vidlang_time_step['init_image'] = vidlang_time_step['image'] 492 | vidlang_time_step['init_state'] = vidlang_time_step['state'] 493 | vidlang_time_step['lang_num'] = lang_num 494 | self._init_vidlang_time_step = vidlang_time_step 495 | if self._use_lang_embeddings: 496 | vidlang_time_step['lang_embedding'] = self.lang_to_encoding[self.lang] 497 | return vidlang_time_step 498 | 499 | def step(self, action): 500 | time_step = super(MultiTaskVidLangRLBench, self).step(action) 501 | lang_num = self.lang_to_num[self.task_name] 502 | vidlang_time_step = time_step.copy() 503 | vidlang_time_step['init_image'] = self._init_vidlang_time_step['init_image'] 504 | vidlang_time_step['init_state'] = self._init_vidlang_time_step['init_state'] 505 | vidlang_time_step['lang_num'] = lang_num 506 | if self._use_lang_embeddings: 507 | vidlang_time_step['lang_embedding'] = self.lang_to_encoding[self.lang] 508 | return vidlang_time_step 509 | 510 | class TimeLimit: 511 | def __init__(self, env, duration): 512 | self._env = env 513 | self._duration = duration 514 | self._step = None 515 | 516 | def __getattr__(self, name): 517 | if name.startswith("__"): 518 | raise AttributeError(name) 519 | try: 520 | return getattr(self._env, name) 521 | except AttributeError: 522 | raise ValueError(name) 523 | 524 | def step(self, action): 525 | assert self._step is not None, "Must reset environment."
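# self._step is None between episodes, so stepping a finished environment
# without calling reset() first fails loudly here.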
526 | obs = self._env.step(action) 527 | self._step += 1 528 | if self._duration and self._step >= self._duration: 529 | obs["is_last"] = True 530 | self._step = None 531 | return obs 532 | 533 | def reset(self): 534 | self._step = 0 535 | return self._env.reset() 536 | 537 | 538 | class ResizeImage: 539 | def __init__(self, env, size=(64, 64)): 540 | self._env = env 541 | self._size = size 542 | self._keys = [ 543 | k 544 | for k, v in env.obs_space.items() 545 | if len(v.shape) > 1 and v.shape[:2] != size 546 | ] 547 | print(f'Resizing keys {",".join(self._keys)} to {self._size}.') 548 | if self._keys: 549 | from PIL import Image 550 | 551 | self._Image = Image 552 | 553 | def __getattr__(self, name): 554 | if name.startswith("__"): 555 | raise AttributeError(name) 556 | try: 557 | return getattr(self._env, name) 558 | except AttributeError: 559 | raise ValueError(name) 560 | 561 | @property 562 | def obs_space(self): 563 | spaces = self._env.obs_space 564 | for key in self._keys: 565 | shape = self._size + spaces[key].shape[2:] 566 | spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8) 567 | return spaces 568 | 569 | def step(self, action): 570 | obs = self._env.step(action) 571 | for key in self._keys: 572 | obs[key] = self._resize(obs[key]) 573 | return obs 574 | 575 | def reset(self): 576 | obs = self._env.reset() 577 | for key in self._keys: 578 | obs[key] = self._resize(obs[key]) 579 | return obs 580 | 581 | def _resize(self, image): 582 | image = self._Image.fromarray(image) 583 | image = image.resize(self._size, self._Image.NEAREST) 584 | image = np.array(image) 585 | return image 586 | 587 | 588 | class RenderImage: 589 | def __init__(self, env, key="image"): 590 | self._env = env 591 | self._key = key 592 | self._shape = self._env.render().shape 593 | 594 | def __getattr__(self, name): 595 | if name.startswith("__"): 596 | raise AttributeError(name) 597 | try: 598 | return getattr(self._env, name) 599 | except AttributeError: 600 | raise ValueError(name) 601 | 602 | @property 603 | def obs_space(self): 604 | spaces = self._env.obs_space 605 | spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8) 606 | return spaces 607 | 608 | def step(self, action): 609 | obs = self._env.step(action) 610 | obs[self._key] = self._env.render("rgb_array") 611 | return obs 612 | 613 | def reset(self): 614 | obs = self._env.reset() 615 | obs[self._key] = self._env.render("rgb_array") 616 | return obs 617 | 618 | 619 | class Async: 620 | 621 | # Message types for communication via the pipe. 622 | _ACCESS = 1 623 | _CALL = 2 624 | _RESULT = 3 625 | _CLOSE = 4 626 | _EXCEPTION = 5 627 | 628 | def __init__(self, constructor, strategy="thread"): 629 | self._pickled_ctor = cloudpickle.dumps(constructor) 630 | if strategy == "process": 631 | import multiprocessing as mp 632 | 633 | context = mp.get_context("spawn") 634 | elif strategy == "thread": 635 | import multiprocessing.dummy as context 636 | else: 637 | raise NotImplementedError(strategy) 638 | self._strategy = strategy 639 | self._conn, conn = context.Pipe() 640 | self._process = context.Process(target=self._worker, args=(conn,)) 641 | atexit.register(self.close) 642 | self._process.start() 643 | self._receive() # Ready. 
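# The worker sends a (_RESULT, None) handshake as soon as the environment has
# been constructed in the child process, so returning from _receive() here
# means the remote environment is ready to serve access/call requests.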
644 | self._obs_space = None 645 | self._act_space = None 646 | 647 | def access(self, name): 648 | self._conn.send((self._ACCESS, name)) 649 | return self._receive 650 | 651 | def call(self, name, *args, **kwargs): 652 | payload = name, args, kwargs 653 | self._conn.send((self._CALL, payload)) 654 | return self._receive 655 | 656 | def close(self): 657 | try: 658 | self._conn.send((self._CLOSE, None)) 659 | self._conn.close() 660 | except IOError: 661 | pass # The connection was already closed. 662 | self._process.join(5) 663 | 664 | @property 665 | def obs_space(self): 666 | if not self._obs_space: 667 | self._obs_space = self.access("obs_space")() 668 | return self._obs_space 669 | 670 | @property 671 | def act_space(self): 672 | if not self._act_space: 673 | self._act_space = self.access("act_space")() 674 | return self._act_space 675 | 676 | def step(self, action, blocking=False): 677 | promise = self.call("step", action) 678 | if blocking: 679 | return promise() 680 | else: 681 | return promise 682 | 683 | def reset(self, blocking=False): 684 | promise = self.call("reset") 685 | if blocking: 686 | return promise() 687 | else: 688 | return promise 689 | 690 | def _receive(self): 691 | try: 692 | message, payload = self._conn.recv() 693 | except (OSError, EOFError): 694 | raise RuntimeError("Lost connection to environment worker.") 695 | # Re-raise exceptions in the main process. 696 | if message == self._EXCEPTION: 697 | stacktrace = payload 698 | raise Exception(stacktrace) 699 | if message == self._RESULT: 700 | return payload 701 | raise KeyError("Received message of unexpected type {}".format(message)) 702 | 703 | def _worker(self, conn): 704 | try: 705 | ctor = cloudpickle.loads(self._pickled_ctor) 706 | env = ctor() 707 | conn.send((self._RESULT, None)) # Ready. 708 | while True: 709 | try: 710 | # Only block for short times to have keyboard exceptions be raised. 711 | if not conn.poll(0.1): 712 | continue 713 | message, payload = conn.recv() 714 | except (EOFError, KeyboardInterrupt): 715 | break 716 | if message == self._ACCESS: 717 | name = payload 718 | result = getattr(env, name) 719 | conn.send((self._RESULT, result)) 720 | continue 721 | if message == self._CALL: 722 | name, args, kwargs = payload 723 | result = getattr(env, name)(*args, **kwargs) 724 | conn.send((self._RESULT, result)) 725 | continue 726 | if message == self._CLOSE: 727 | break 728 | raise KeyError("Received message of unknown type {}".format(message)) 729 | except Exception: 730 | stacktrace = "".join(traceback.format_exception(*sys.exc_info())) 731 | print("Error in environment process: {}".format(stacktrace)) 732 | conn.send((self._EXCEPTION, stacktrace)) 733 | finally: 734 | try: 735 | conn.close() 736 | except IOError: 737 | pass # The connection was already closed. 738 | --------------------------------------------------------------------------------