├── README.md └── main └── ota_updater.py /README.md: -------------------------------------------------------------------------------- 1 | # Micropython-ESP32-OTA 2 | ------------------------------- 3 | Example code for updating firmware over the air 4 | -------------------------------------------------------------------------------- /main/ota_updater.py: -------------------------------------------------------------------------------- 1 | 2 | import usocket 3 | import os 4 | import gc 5 | import machine 6 | 7 | 8 | class OTAUpdater: 9 | 10 | def __init__(self, github_repo, module='', main_dir='main'): 11 | self.http_client = HttpClient() 12 | self.github_repo = github_repo.rstrip('/').replace('https://github.com', 'https://api.github.com/repos') 13 | self.main_dir = main_dir 14 | self.module = module.rstrip('/') 15 | 16 | @staticmethod 17 | def using_network(ssid, password): 18 | import network 19 | sta_if = network.WLAN(network.STA_IF) 20 | if not sta_if.isconnected(): 21 | print('connecting to network...') 22 | sta_if.active(True) 23 | sta_if.connect(ssid, password) 24 | while not sta_if.isconnected(): 25 | pass 26 | print('network config:', sta_if.ifconfig()) 27 | 28 | def check_for_update_to_install_during_next_reboot(self): 29 | current_version = self.get_version(self.modulepath(self.main_dir)) 30 | latest_version = self.get_latest_version() 31 | 32 | print('Checking version... ') 33 | print('\tCurrent version: ', current_version) 34 | print('\tLatest version: ', latest_version) 35 | if latest_version > current_version: 36 | print('New version available, will download and install on next reboot') 37 | os.mkdir(self.modulepath('next')) 38 | with open(self.modulepath('next/.version_on_reboot'), 'w') as versionfile: 39 | versionfile.write(latest_version) 40 | versionfile.close() 41 | 42 | def download_and_install_update_if_available(self, ssid, password): 43 | if 'next' in os.listdir(self.module): 44 | if '.version_on_reboot' in os.listdir(self.modulepath('next')): 45 | latest_version = self.get_version(self.modulepath('next'), '.version_on_reboot') 46 | print('New update found: ', latest_version) 47 | self._download_and_install_update(latest_version, ssid, password) 48 | else: 49 | print('No new updates found...') 50 | 51 | def _download_and_install_update(self, latest_version, ssid, password): 52 | OTAUpdater.using_network(ssid, password) 53 | 54 | self.download_all_files(self.github_repo + '/contents/' + self.main_dir, latest_version) 55 | self.rmtree(self.modulepath(self.main_dir)) 56 | os.rename(self.modulepath('next/.version_on_reboot'), self.modulepath('next/.version')) 57 | os.rename(self.modulepath('next'), self.modulepath(self.main_dir)) 58 | print('Update installed (', latest_version, '), will reboot now') 59 | machine.reset() 60 | 61 | def apply_pending_updates_if_available(self): 62 | if 'next' in os.listdir(self.module): 63 | if '.version' in os.listdir(self.modulepath('next')): 64 | pending_update_version = self.get_version(self.modulepath('next')) 65 | print('Pending update found: ', pending_update_version) 66 | self.rmtree(self.modulepath(self.main_dir)) 67 | os.rename(self.modulepath('next'), self.modulepath(self.main_dir)) 68 | print('Update applied (', pending_update_version, '), ready to rock and roll') 69 | else: 70 | print('Corrupt pending update found, discarding...') 71 | self.rmtree(self.modulepath('next')) 72 | else: 73 | print('No pending update found') 74 | 75 | def download_updates_if_available(self): 76 | current_version = self.get_version(self.modulepath(self.main_dir)) 77 | latest_version = self.get_latest_version() 78 | 79 | print('Checking version... ') 80 | print('\tCurrent version: ', current_version) 81 | print('\tLatest version: ', latest_version) 82 | if latest_version > current_version: 83 | print('Updating...') 84 | os.mkdir(self.modulepath('next')) 85 | self.download_all_files(self.github_repo + '/contents/' + self.main_dir, latest_version) 86 | with open(self.modulepath('next/.version'), 'w') as versionfile: 87 | versionfile.write(latest_version) 88 | versionfile.close() 89 | 90 | return True 91 | return False 92 | 93 | def rmtree(self, directory): 94 | for entry in os.ilistdir(directory): 95 | is_dir = entry[1] == 0x4000 96 | if is_dir: 97 | self.rmtree(directory + '/' + entry[0]) 98 | 99 | else: 100 | os.remove(directory + '/' + entry[0]) 101 | os.rmdir(directory) 102 | 103 | def get_version(self, directory, version_file_name='.version'): 104 | if version_file_name in os.listdir(directory): 105 | f = open(directory + '/' + version_file_name) 106 | version = f.read() 107 | f.close() 108 | return version 109 | return '0.0' 110 | 111 | def get_latest_version(self): 112 | latest_release = self.http_client.get(self.github_repo + '/releases/latest') 113 | version = latest_release.json()['tag_name'] 114 | latest_release.close() 115 | return version 116 | 117 | def download_all_files(self, root_url, version): 118 | file_list = self.http_client.get(root_url + '?ref=refs/tags/' + version) 119 | for file in file_list.json(): 120 | if file['type'] == 'file': 121 | download_url = file['download_url'] 122 | download_path = self.modulepath('next/' + file['path'].replace(self.main_dir + '/', '')) 123 | self.download_file(download_url.replace('refs/tags/', ''), download_path) 124 | elif file['type'] == 'dir': 125 | path = self.modulepath('next/' + file['path'].replace(self.main_dir + '/', '')) 126 | os.mkdir(path) 127 | self.download_all_files(root_url + '/' + file['name'], version) 128 | 129 | file_list.close() 130 | 131 | def download_file(self, url, path): 132 | print('\tDownloading: ', path) 133 | with open(path, 'w') as outfile: 134 | try: 135 | response = self.http_client.get(url) 136 | outfile.write(response.text) 137 | finally: 138 | response.close() 139 | outfile.close() 140 | gc.collect() 141 | 142 | def modulepath(self, path): 143 | return self.module + '/' + path if self.module else path 144 | 145 | 146 | class Response: 147 | 148 | def __init__(self, f): 149 | self.raw = f 150 | self.encoding = 'utf-8' 151 | self._cached = None 152 | 153 | def close(self): 154 | if self.raw: 155 | self.raw.close() 156 | self.raw = None 157 | self._cached = None 158 | 159 | @property 160 | def content(self): 161 | if self._cached is None: 162 | try: 163 | self._cached = self.raw.read() 164 | finally: 165 | self.raw.close() 166 | self.raw = None 167 | return self._cached 168 | 169 | @property 170 | def text(self): 171 | return str(self.content, self.encoding) 172 | 173 | def json(self): 174 | import ujson 175 | return ujson.loads(self.content) 176 | 177 | 178 | class HttpClient: 179 | 180 | def request(self, method, url, data=None, json=None, headers={}, stream=None): 181 | try: 182 | proto, dummy, host, path = url.split('/', 3) 183 | except ValueError: 184 | proto, dummy, host = url.split('/', 2) 185 | path = '' 186 | if proto == 'http:': 187 | port = 80 188 | elif proto == 'https:': 189 | import ussl 190 | port = 443 191 | else: 192 | raise ValueError('Unsupported protocol: ' + proto) 193 | 194 | if ':' in host: 195 | host, port = host.split(':', 1) 196 | port = int(port) 197 | 198 | ai = usocket.getaddrinfo(host, port, 0, usocket.SOCK_STREAM) 199 | ai = ai[0] 200 | 201 | s = usocket.socket(ai[0], ai[1], ai[2]) 202 | try: 203 | s.connect(ai[-1]) 204 | if proto == 'https:': 205 | s = ussl.wrap_socket(s, server_hostname=host) 206 | s.write(b'%s /%s HTTP/1.0\r\n' % (method, path)) 207 | if not 'Host' in headers: 208 | s.write(b'Host: %s\r\n' % host) 209 | # Iterate over keys to avoid tuple alloc 210 | for k in headers: 211 | s.write(k) 212 | s.write(b': ') 213 | s.write(headers[k]) 214 | s.write(b'\r\n') 215 | # add user agent 216 | s.write('User-Agent') 217 | s.write(b': ') 218 | s.write('MicroPython OTAUpdater') 219 | s.write(b'\r\n') 220 | if json is not None: 221 | assert data is None 222 | import ujson 223 | data = ujson.dumps(json) 224 | s.write(b'Content-Type: application/json\r\n') 225 | if data: 226 | s.write(b'Content-Length: %d\r\n' % len(data)) 227 | s.write(b'\r\n') 228 | if data: 229 | s.write(data) 230 | 231 | l = s.readline() 232 | # print(l) 233 | l = l.split(None, 2) 234 | status = int(l[1]) 235 | reason = '' 236 | if len(l) > 2: 237 | reason = l[2].rstrip() 238 | while True: 239 | l = s.readline() 240 | if not l or l == b'\r\n': 241 | break 242 | # print(l) 243 | if l.startswith(b'Transfer-Encoding:'): 244 | if b'chunked' in l: 245 | raise ValueError('Unsupported ' + l) 246 | elif l.startswith(b'Location:') and not 200 <= status <= 299: 247 | raise NotImplementedError('Redirects not yet supported') 248 | except OSError: 249 | s.close() 250 | raise 251 | 252 | resp = Response(s) 253 | resp.status_code = status 254 | resp.reason = reason 255 | return resp 256 | 257 | def head(self, url, **kw): 258 | return self.request('HEAD', url, **kw) 259 | 260 | def get(self, url, **kw): 261 | return self.request('GET', url, **kw) 262 | 263 | def post(self, url, **kw): 264 | return self.request('POST', url, **kw) 265 | 266 | def put(self, url, **kw): 267 | return self.request('PUT', url, **kw) 268 | 269 | def patch(self, url, **kw): 270 | return self.request('PATCH', url, **kw) 271 | 272 | def delete(self, url, **kw): 273 | return self.request('DELETE', url, **kw) 274 | --------------------------------------------------------------------------------