├── README.md
├── README图片
│   ├── 关系矩阵图.png
│   ├── 序列分布情况.png
│   ├── 插值结果.png
│   ├── 热力图.png
│   ├── 箱状图.png
│   ├── 船舶子轨迹.png
│   └── 轨迹提取MMSI.png
├── S-Transformer
│   ├── LICENSE
│   ├── README.md
│   ├── Untitled.ipynb
│   ├── config_trAISformer.py
│   ├── datasets.py
│   ├── figures
│   │   └── t18_3.png
│   ├── models.py
│   ├── requirements.txt
│   ├── requirements.yml
│   ├── trAISformer.py
│   ├── trainers.py
│   └── utils.py
├── 关系矩阵_热力图.py
├── 插值处理.py
├── 数据清洗.py
├── 箱形图.py
├── 轨迹平稳性检验.py
└── 轨迹提取.py

/README.md:
--------------------------------------------------------------------------------

## Ship Trajectory Data Preprocessing and Analysis

#### Introduction

This project preprocesses and analyses AIS data in three parts. (Note: before running, download the CSV files into a single folder, or adapt the paths in the code yourself. The dataset used in this project is the one provided by the U.S. MarineCadastre.gov website.)

- Part 1: data cleaning and trajectory extraction
- Part 2: feature correlation and trajectory stationarity analysis
- Part 3: trajectory repair

#### Part 1: Data Cleaning and Trajectory Extraction

##### 1. Data cleaning (数据清洗.py, 箱形图.py)

AIS messages are inevitably interrupted or lost during transmission and reception, and real-world AIS data also contain redundant and anomalous records, so the data are cleaned first.

The cleaning steps, and the reasons for them, are as follows (a compact pandas sketch of these rules appears in the appendix at the end of this README):

(1) Sort the AIS records by MMSI and then by time, in ascending order. Ordering the raw data this way simplifies the later cleaning steps.

(2) Delete records whose MMSI is not 9 digits long. After sorting, some vessels turn out to have MMSIs of other lengths; since a valid MMSI identifier consists of exactly 9 digits, non-conforming records are removed.

(3) Keep only records whose navigation status corresponds to a vessel under way. The possible status values include under way using engine, at anchor, not under command, constrained by draught, moored, aground, engaged in fishing, and under way sailing; only the statuses that indicate normal navigation are retained.

(4) Delete records of vessels with length less than 3 or beam less than 2. The tracks of very small craft are more strongly affected by currents, wind, waves, and other environmental factors, and keeping them may also push a model towards overfitting, so such vessels are removed.

(5) Delete records whose longitude, latitude, speed over ground, or course over ground falls outside the valid range. Values outside the valid decimal ranges below are meaningless and are discarded:

Longitude (LON): -180.00000 to 180.00000

Latitude (LAT): -90.00000 to 90.00000

Speed over ground (SOG): 0 to 51.2

Course over ground (COG): -204.7 to 204.8

(6) Delete points where the speed over ground is 0 for five or more consecutive records. As a first approximation, a point whose SOG stays at 0 for five consecutive records is treated as a berthing point; if such points are not removed, they distort the subsequent trajectory analysis.

The cleaned data are written to a new folder. You can then run **箱形图.py** to draw box plots of the AIS data and inspect their distribution.

##### 2. Trajectory extraction (轨迹提取.py)

All AIS records are grouped by MMSI, and the points within each group are sorted by increasing BaseDateTime, which yields the trajectory of each vessel (one MMSI) over the observation period.

Whenever the time difference between two consecutive points of the same vessel exceeds 30 min, those two points are treated as the end of one track and the start of the next, splitting each vessel trajectory into several sub-trajectories.

Running the script produces an output folder containing one sub-folder per MMSI, each holding that vessel's extracted sub-trajectories.

#### Part 2: Feature Correlation and Trajectory Stationarity Analysis

##### 1. Feature correlation (关系矩阵_热力图.py)

Reads the vessel sub-trajectories and outputs a scatter-matrix plot of the features and a heat map of the correlation coefficients between the feature values.

##### 2. Trajectory stationarity (轨迹平稳性检验.py)

A vessel trajectory is a typical multivariate time series. Before applying time-series analysis tools, it is usually necessary to check whether the series is stationary in order to choose a suitable trajectory prediction method. This script visualizes the sequence distributions of the four trajectory features.

#### Part 3: Trajectory Repair (插值处理.py)

Some trajectory points may be discarded during cleaning and extraction, leaving uneven time intervals within a track; in addition, AIS transmitters report at different rates, so the sampling interval also differs markedly between trajectories. If the intervals in a track are too long, the training of downstream models suffers considerably. To smooth the trajectory points, cubic-spline interpolation is therefore applied, with weighted-moving-average interpolation as a reference method (see the appendix for a sketch of both).

插值处理.py writes the interpolation results of both methods to new files and then draws scatter plots for a visual comparison.

If you need AIS trajectory plots or time vs. latitude/longitude plots, head over to my other repository ([here](https://github.com/axyqdm/Track-visualization)).
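
#### Appendix: Code Sketches

The cleaning and segmentation rules of Part 1 can be written compactly in pandas. The sketch below is illustrative rather than the exact logic of 数据清洗.py / 轨迹提取.py: the column names follow the MarineCadastre.gov CSV layout, the output paths are placeholders, and the navigation-status filter (step 3) is omitted because its encoding varies between yearly datasets.

```python
import glob
import os

import pandas as pd

# Assumed MarineCadastre.gov column names; adjust if your download differs.
COLS = ["MMSI", "BaseDateTime", "LAT", "LON", "SOG", "COG", "Status", "Length", "Width"]

# Load every CSV from one folder (see the note in the introduction).
df = pd.concat(
    (pd.read_csv(f, usecols=COLS, parse_dates=["BaseDateTime"])
     for f in glob.glob("./csv/*.csv")),
    ignore_index=True,
)

df = df.sort_values(["MMSI", "BaseDateTime"])                       # step (1)
df = df[df["MMSI"].astype(str).str.len() == 9]                      # step (2)
df = df[(df["Length"] >= 3) & (df["Width"] >= 2)]                   # step (4)
df = df[df["LAT"].between(-90, 90) & df["LON"].between(-180, 180)   # step (5)
        & df["SOG"].between(0, 51.2) & df["COG"].between(-204.7, 204.8)]

# Step (6): drop runs of five or more consecutive zero-SOG points per vessel.
zero = df["SOG"].eq(0)
run_id = (zero != zero.shift()).cumsum()
run_len = zero.groupby([df["MMSI"], run_id]).transform("size")
df = df[~(zero & (run_len >= 5))]

# Trajectory extraction: start a new sub-trajectory whenever the gap between
# consecutive points of the same vessel exceeds 30 minutes.
gap = df.groupby("MMSI")["BaseDateTime"].diff()
sub_id = (gap > pd.Timedelta(minutes=30)).groupby(df["MMSI"]).cumsum()
for (mmsi, k), track in df.groupby([df["MMSI"], sub_id]):
    os.makedirs(f"./tracks/{mmsi}", exist_ok=True)
    track.to_csv(f"./tracks/{mmsi}/{k}.csv", index=False)
```

For Part 3, the following sketch shows the two reference interpolators, a cubic spline (SciPy) and a simple weighted moving average; the uniform one-minute grid and the inverse-time weights are assumptions, not necessarily the parameters used in 插值处理.py.

```python
import numpy as np
from scipy.interpolate import CubicSpline

def resample_subtrack(t, lat, lon, step_s=60.0):
    """Resample one sub-trajectory onto a uniform time grid.

    t: strictly increasing unix timestamps of one sub-trajectory.
    Returns both a cubic-spline and a weighted-moving-average estimate
    so that the two methods can be compared.
    """
    grid = np.arange(t[0], t[-1], step_s)

    lat_cs = CubicSpline(t, lat)(grid)
    lon_cs = CubicSpline(t, lon)(grid)

    # Weighted moving average with inverse time-distance weights over the
    # original samples (a simple reference scheme, assumed here).
    w = 1.0 / (np.abs(grid[:, None] - t[None, :]) + 1e-6)
    w /= w.sum(axis=1, keepdims=True)
    lat_wma = w @ lat
    lon_wma = w @ lon

    return grid, (lat_cs, lon_cs), (lat_wma, lon_wma)
```

--------------------------------------------------------------------------------
/README图片/关系矩阵图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/关系矩阵图.png
--------------------------------------------------------------------------------
/README图片/序列分布情况.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/序列分布情况.png
--------------------------------------------------------------------------------
/README图片/插值结果.png:
--------------------------------------------------------------------------------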
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/插值结果.png -------------------------------------------------------------------------------- /README图片/热力图.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/热力图.png -------------------------------------------------------------------------------- /README图片/箱状图.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/箱状图.png -------------------------------------------------------------------------------- /README图片/船舶子轨迹.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/船舶子轨迹.png -------------------------------------------------------------------------------- /README图片/轨迹提取MMSI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/轨迹提取MMSI.png -------------------------------------------------------------------------------- /S-Transformer/LICENSE: -------------------------------------------------------------------------------- 1 | CeCILL-C FREE SOFTWARE LICENSE AGREEMENT 2 | 3 | 4 | Notice 5 | 6 | This Agreement is a Free Software license agreement that is the result 7 | of discussions between its authors in order to ensure compliance with 8 | the two main principles guiding its drafting: 9 | 10 | * firstly, compliance with the principles governing the distribution 11 | of Free Software: access to source code, broad rights granted to 12 | users, 13 | * secondly, the election of a governing law, French law, with which 14 | it is conformant, both as regards the law of torts and 15 | intellectual property law, and the protection that it offers to 16 | both authors and holders of the economic rights over software. 17 | 18 | The authors of the CeCILL-C (for Ce[a] C[nrs] I[nria] L[ogiciel] L[ibre]) 19 | license are: 20 | 21 | Commissariat à l'Energie Atomique - CEA, a public scientific, technical 22 | and industrial research establishment, having its principal place of 23 | business at 25 rue Leblanc, immeuble Le Ponant D, 75015 Paris, France. 24 | 25 | Centre National de la Recherche Scientifique - CNRS, a public scientific 26 | and technological establishment, having its principal place of business 27 | at 3 rue Michel-Ange, 75794 Paris cedex 16, France. 28 | 29 | Institut National de Recherche en Informatique et en Automatique - 30 | INRIA, a public scientific and technological establishment, having its 31 | principal place of business at Domaine de Voluceau, Rocquencourt, BP 32 | 105, 78153 Le Chesnay cedex, France. 33 | 34 | 35 | Preamble 36 | 37 | The purpose of this Free Software license agreement is to grant users 38 | the right to modify and re-use the software governed by this license. 
39 | 40 | The exercising of this right is conditional upon the obligation to make 41 | available to the community the modifications made to the source code of 42 | the software so as to contribute to its evolution. 43 | 44 | In consideration of access to the source code and the rights to copy, 45 | modify and redistribute granted by the license, users are provided only 46 | with a limited warranty and the software's author, the holder of the 47 | economic rights, and the successive licensors only have limited liability. 48 | 49 | In this respect, the risks associated with loading, using, modifying 50 | and/or developing or reproducing the software by the user are brought to 51 | the user's attention, given its Free Software status, which may make it 52 | complicated to use, with the result that its use is reserved for 53 | developers and experienced professionals having in-depth computer 54 | knowledge. Users are therefore encouraged to load and test the 55 | suitability of the software as regards their requirements in conditions 56 | enabling the security of their systems and/or data to be ensured and, 57 | more generally, to use and operate it in the same conditions of 58 | security. This Agreement may be freely reproduced and published, 59 | provided it is not altered, and that no provisions are either added or 60 | removed herefrom. 61 | 62 | This Agreement may apply to any or all software for which the holder of 63 | the economic rights decides to submit the use thereof to its provisions. 64 | 65 | 66 | Article 1 - DEFINITIONS 67 | 68 | For the purpose of this Agreement, when the following expressions 69 | commence with a capital letter, they shall have the following meaning: 70 | 71 | Agreement: means this license agreement, and its possible subsequent 72 | versions and annexes. 73 | 74 | Software: means the software in its Object Code and/or Source Code form 75 | and, where applicable, its documentation, "as is" when the Licensee 76 | accepts the Agreement. 77 | 78 | Initial Software: means the Software in its Source Code and possibly its 79 | Object Code form and, where applicable, its documentation, "as is" when 80 | it is first distributed under the terms and conditions of the Agreement. 81 | 82 | Modified Software: means the Software modified by at least one 83 | Integrated Contribution. 84 | 85 | Source Code: means all the Software's instructions and program lines to 86 | which access is required so as to modify the Software. 87 | 88 | Object Code: means the binary files originating from the compilation of 89 | the Source Code. 90 | 91 | Holder: means the holder(s) of the economic rights over the Initial 92 | Software. 93 | 94 | Licensee: means the Software user(s) having accepted the Agreement. 95 | 96 | Contributor: means a Licensee having made at least one Integrated 97 | Contribution. 98 | 99 | Licensor: means the Holder, or any other individual or legal entity, who 100 | distributes the Software under the Agreement. 101 | 102 | Integrated Contribution: means any or all modifications, corrections, 103 | translations, adaptations and/or new functions integrated into the 104 | Source Code by any or all Contributors. 105 | 106 | Related Module: means a set of sources files including their 107 | documentation that, without modification to the Source Code, enables 108 | supplementary functions or services in addition to those offered by the 109 | Software. 110 | 111 | Derivative Software: means any combination of the Software, modified or 112 | not, and of a Related Module. 
113 | 114 | Parties: mean both the Licensee and the Licensor. 115 | 116 | These expressions may be used both in singular and plural form. 117 | 118 | 119 | Article 2 - PURPOSE 120 | 121 | The purpose of the Agreement is the grant by the Licensor to the 122 | Licensee of a non-exclusive, transferable and worldwide license for the 123 | Software as set forth in Article 5 hereinafter for the whole term of the 124 | protection granted by the rights over said Software. 125 | 126 | 127 | Article 3 - ACCEPTANCE 128 | 129 | 3.1 The Licensee shall be deemed as having accepted the terms and 130 | conditions of this Agreement upon the occurrence of the first of the 131 | following events: 132 | 133 | * (i) loading the Software by any or all means, notably, by 134 | downloading from a remote server, or by loading from a physical 135 | medium; 136 | * (ii) the first time the Licensee exercises any of the rights 137 | granted hereunder. 138 | 139 | 3.2 One copy of the Agreement, containing a notice relating to the 140 | characteristics of the Software, to the limited warranty, and to the 141 | fact that its use is restricted to experienced users has been provided 142 | to the Licensee prior to its acceptance as set forth in Article 3.1 143 | hereinabove, and the Licensee hereby acknowledges that it has read and 144 | understood it. 145 | 146 | 147 | Article 4 - EFFECTIVE DATE AND TERM 148 | 149 | 150 | 4.1 EFFECTIVE DATE 151 | 152 | The Agreement shall become effective on the date when it is accepted by 153 | the Licensee as set forth in Article 3.1. 154 | 155 | 156 | 4.2 TERM 157 | 158 | The Agreement shall remain in force for the entire legal term of 159 | protection of the economic rights over the Software. 160 | 161 | 162 | Article 5 - SCOPE OF RIGHTS GRANTED 163 | 164 | The Licensor hereby grants to the Licensee, who accepts, the following 165 | rights over the Software for any or all use, and for the term of the 166 | Agreement, on the basis of the terms and conditions set forth hereinafter. 167 | 168 | Besides, if the Licensor owns or comes to own one or more patents 169 | protecting all or part of the functions of the Software or of its 170 | components, the Licensor undertakes not to enforce the rights granted by 171 | these patents against successive Licensees using, exploiting or 172 | modifying the Software. If these patents are transferred, the Licensor 173 | undertakes to have the transferees subscribe to the obligations set 174 | forth in this paragraph. 175 | 176 | 177 | 5.1 RIGHT OF USE 178 | 179 | The Licensee is authorized to use the Software, without any limitation 180 | as to its fields of application, with it being hereinafter specified 181 | that this comprises: 182 | 183 | 1. permanent or temporary reproduction of all or part of the Software 184 | by any or all means and in any or all form. 185 | 186 | 2. loading, displaying, running, or storing the Software on any or 187 | all medium. 188 | 189 | 3. entitlement to observe, study or test its operation so as to 190 | determine the ideas and principles behind any or all constituent 191 | elements of said Software. This shall apply when the Licensee 192 | carries out any or all loading, displaying, running, transmission 193 | or storage operation as regards the Software, that it is entitled 194 | to carry out hereunder. 
195 | 196 | 197 | 5.2 RIGHT OF MODIFICATION 198 | 199 | The right of modification includes the right to translate, adapt, 200 | arrange, or make any or all modifications to the Software, and the right 201 | to reproduce the resulting software. It includes, in particular, the 202 | right to create a Derivative Software. 203 | 204 | The Licensee is authorized to make any or all modification to the 205 | Software provided that it includes an explicit notice that it is the 206 | author of said modification and indicates the date of the creation thereof. 207 | 208 | 209 | 5.3 RIGHT OF DISTRIBUTION 210 | 211 | In particular, the right of distribution includes the right to publish, 212 | transmit and communicate the Software to the general public on any or 213 | all medium, and by any or all means, and the right to market, either in 214 | consideration of a fee, or free of charge, one or more copies of the 215 | Software by any means. 216 | 217 | The Licensee is further authorized to distribute copies of the modified 218 | or unmodified Software to third parties according to the terms and 219 | conditions set forth hereinafter. 220 | 221 | 222 | 5.3.1 DISTRIBUTION OF SOFTWARE WITHOUT MODIFICATION 223 | 224 | The Licensee is authorized to distribute true copies of the Software in 225 | Source Code or Object Code form, provided that said distribution 226 | complies with all the provisions of the Agreement and is accompanied by: 227 | 228 | 1. a copy of the Agreement, 229 | 230 | 2. a notice relating to the limitation of both the Licensor's 231 | warranty and liability as set forth in Articles 8 and 9, 232 | 233 | and that, in the event that only the Object Code of the Software is 234 | redistributed, the Licensee allows effective access to the full Source 235 | Code of the Software at a minimum during the entire period of its 236 | distribution of the Software, it being understood that the additional 237 | cost of acquiring the Source Code shall not exceed the cost of 238 | transferring the data. 239 | 240 | 241 | 5.3.2 DISTRIBUTION OF MODIFIED SOFTWARE 242 | 243 | When the Licensee makes an Integrated Contribution to the Software, the 244 | terms and conditions for the distribution of the resulting Modified 245 | Software become subject to all the provisions of this Agreement. 246 | 247 | The Licensee is authorized to distribute the Modified Software, in 248 | source code or object code form, provided that said distribution 249 | complies with all the provisions of the Agreement and is accompanied by: 250 | 251 | 1. a copy of the Agreement, 252 | 253 | 2. a notice relating to the limitation of both the Licensor's 254 | warranty and liability as set forth in Articles 8 and 9, 255 | 256 | and that, in the event that only the object code of the Modified 257 | Software is redistributed, the Licensee allows effective access to the 258 | full source code of the Modified Software at a minimum during the entire 259 | period of its distribution of the Modified Software, it being understood 260 | that the additional cost of acquiring the source code shall not exceed 261 | the cost of transferring the data. 262 | 263 | 264 | 5.3.3 DISTRIBUTION OF DERIVATIVE SOFTWARE 265 | 266 | When the Licensee creates Derivative Software, this Derivative Software 267 | may be distributed under a license agreement other than this Agreement, 268 | subject to compliance with the requirement to include a notice 269 | concerning the rights over the Software as defined in Article 6.4. 
270 | In the event the creation of the Derivative Software required modification 271 | of the Source Code, the Licensee undertakes that: 272 | 273 | 1. the resulting Modified Software will be governed by this Agreement, 274 | 2. the Integrated Contributions in the resulting Modified Software 275 | will be clearly identified and documented, 276 | 3. the Licensee will allow effective access to the source code of the 277 | Modified Software, at a minimum during the entire period of 278 | distribution of the Derivative Software, such that such 279 | modifications may be carried over in a subsequent version of the 280 | Software; it being understood that the additional cost of 281 | purchasing the source code of the Modified Software shall not 282 | exceed the cost of transferring the data. 283 | 284 | 285 | 5.3.4 COMPATIBILITY WITH THE CeCILL LICENSE 286 | 287 | When a Modified Software contains an Integrated Contribution subject to 288 | the CeCILL license agreement, or when a Derivative Software contains a 289 | Related Module subject to the CeCILL license agreement, the provisions 290 | set forth in the third item of Article 6.4 are optional. 291 | 292 | 293 | Article 6 - INTELLECTUAL PROPERTY 294 | 295 | 296 | 6.1 OVER THE INITIAL SOFTWARE 297 | 298 | The Holder owns the economic rights over the Initial Software. Any or 299 | all use of the Initial Software is subject to compliance with the terms 300 | and conditions under which the Holder has elected to distribute its work 301 | and no one shall be entitled to modify the terms and conditions for the 302 | distribution of said Initial Software. 303 | 304 | The Holder undertakes that the Initial Software will remain ruled at 305 | least by this Agreement, for the duration set forth in Article 4.2. 306 | 307 | 308 | 6.2 OVER THE INTEGRATED CONTRIBUTIONS 309 | 310 | The Licensee who develops an Integrated Contribution is the owner of the 311 | intellectual property rights over this Contribution as defined by 312 | applicable law. 313 | 314 | 315 | 6.3 OVER THE RELATED MODULES 316 | 317 | The Licensee who develops a Related Module is the owner of the 318 | intellectual property rights over this Related Module as defined by 319 | applicable law and is free to choose the type of agreement that shall 320 | govern its distribution under the conditions defined in Article 5.3.3. 321 | 322 | 323 | 6.4 NOTICE OF RIGHTS 324 | 325 | The Licensee expressly undertakes: 326 | 327 | 1. not to remove, or modify, in any manner, the intellectual property 328 | notices attached to the Software; 329 | 330 | 2. to reproduce said notices, in an identical manner, in the copies 331 | of the Software modified or not; 332 | 333 | 3. to ensure that use of the Software, its intellectual property 334 | notices and the fact that it is governed by the Agreement is 335 | indicated in a text that is easily accessible, specifically from 336 | the interface of any Derivative Software. 337 | 338 | The Licensee undertakes not to directly or indirectly infringe the 339 | intellectual property rights of the Holder and/or Contributors on the 340 | Software and to take, where applicable, vis-à-vis its staff, any and all 341 | measures required to ensure respect of said intellectual property rights 342 | of the Holder and/or Contributors. 343 | 344 | 345 | Article 7 - RELATED SERVICES 346 | 347 | 7.1 Under no circumstances shall the Agreement oblige the Licensor to 348 | provide technical assistance or maintenance services for the Software. 
349 | 350 | However, the Licensor is entitled to offer this type of services. The 351 | terms and conditions of such technical assistance, and/or such 352 | maintenance, shall be set forth in a separate instrument. Only the 353 | Licensor offering said maintenance and/or technical assistance services 354 | shall incur liability therefor. 355 | 356 | 7.2 Similarly, any Licensor is entitled to offer to its licensees, under 357 | its sole responsibility, a warranty, that shall only be binding upon 358 | itself, for the redistribution of the Software and/or the Modified 359 | Software, under terms and conditions that it is free to decide. Said 360 | warranty, and the financial terms and conditions of its application, 361 | shall be subject of a separate instrument executed between the Licensor 362 | and the Licensee. 363 | 364 | 365 | Article 8 - LIABILITY 366 | 367 | 8.1 Subject to the provisions of Article 8.2, the Licensee shall be 368 | entitled to claim compensation for any direct loss it may have suffered 369 | from the Software as a result of a fault on the part of the relevant 370 | Licensor, subject to providing evidence thereof. 371 | 372 | 8.2 The Licensor's liability is limited to the commitments made under 373 | this Agreement and shall not be incurred as a result of in particular: 374 | (i) loss due the Licensee's total or partial failure to fulfill its 375 | obligations, (ii) direct or consequential loss that is suffered by the 376 | Licensee due to the use or performance of the Software, and (iii) more 377 | generally, any consequential loss. In particular the Parties expressly 378 | agree that any or all pecuniary or business loss (i.e. loss of data, 379 | loss of profits, operating loss, loss of customers or orders, 380 | opportunity cost, any disturbance to business activities) or any or all 381 | legal proceedings instituted against the Licensee by a third party, 382 | shall constitute consequential loss and shall not provide entitlement to 383 | any or all compensation from the Licensor. 384 | 385 | 386 | Article 9 - WARRANTY 387 | 388 | 9.1 The Licensee acknowledges that the scientific and technical 389 | state-of-the-art when the Software was distributed did not enable all 390 | possible uses to be tested and verified, nor for the presence of 391 | possible defects to be detected. In this respect, the Licensee's 392 | attention has been drawn to the risks associated with loading, using, 393 | modifying and/or developing and reproducing the Software which are 394 | reserved for experienced users. 395 | 396 | The Licensee shall be responsible for verifying, by any or all means, 397 | the suitability of the product for its requirements, its good working 398 | order, and for ensuring that it shall not cause damage to either persons 399 | or properties. 400 | 401 | 9.2 The Licensor hereby represents, in good faith, that it is entitled 402 | to grant all the rights over the Software (including in particular the 403 | rights set forth in Article 5). 404 | 405 | 9.3 The Licensee acknowledges that the Software is supplied "as is" by 406 | the Licensor without any other express or tacit warranty, other than 407 | that provided for in Article 9.2 and, in particular, without any warranty 408 | as to its commercial value, its secured, safe, innovative or relevant 409 | nature. 
410 | 411 | Specifically, the Licensor does not warrant that the Software is free 412 | from any error, that it will operate without interruption, that it will 413 | be compatible with the Licensee's own equipment and software 414 | configuration, nor that it will meet the Licensee's requirements. 415 | 416 | 9.4 The Licensor does not either expressly or tacitly warrant that the 417 | Software does not infringe any third party intellectual property right 418 | relating to a patent, software or any other property right. Therefore, 419 | the Licensor disclaims any and all liability towards the Licensee 420 | arising out of any or all proceedings for infringement that may be 421 | instituted in respect of the use, modification and redistribution of the 422 | Software. Nevertheless, should such proceedings be instituted against 423 | the Licensee, the Licensor shall provide it with technical and legal 424 | assistance for its defense. Such technical and legal assistance shall be 425 | decided on a case-by-case basis between the relevant Licensor and the 426 | Licensee pursuant to a memorandum of understanding. The Licensor 427 | disclaims any and all liability as regards the Licensee's use of the 428 | name of the Software. No warranty is given as regards the existence of 429 | prior rights over the name of the Software or as regards the existence 430 | of a trademark. 431 | 432 | 433 | Article 10 - TERMINATION 434 | 435 | 10.1 In the event of a breach by the Licensee of its obligations 436 | hereunder, the Licensor may automatically terminate this Agreement 437 | thirty (30) days after notice has been sent to the Licensee and has 438 | remained ineffective. 439 | 440 | 10.2 A Licensee whose Agreement is terminated shall no longer be 441 | authorized to use, modify or distribute the Software. However, any 442 | licenses that it may have granted prior to termination of the Agreement 443 | shall remain valid subject to their having been granted in compliance 444 | with the terms and conditions hereof. 445 | 446 | 447 | Article 11 - MISCELLANEOUS 448 | 449 | 450 | 11.1 EXCUSABLE EVENTS 451 | 452 | Neither Party shall be liable for any or all delay, or failure to 453 | perform the Agreement, that may be attributable to an event of force 454 | majeure, an act of God or an outside cause, such as defective 455 | functioning or interruptions of the electricity or telecommunications 456 | networks, network paralysis following a virus attack, intervention by 457 | government authorities, natural disasters, water damage, earthquakes, 458 | fire, explosions, strikes and labor unrest, war, etc. 459 | 460 | 11.2 Any failure by either Party, on one or more occasions, to invoke 461 | one or more of the provisions hereof, shall under no circumstances be 462 | interpreted as being a waiver by the interested Party of its right to 463 | invoke said provision(s) subsequently. 464 | 465 | 11.3 The Agreement cancels and replaces any or all previous agreements, 466 | whether written or oral, between the Parties and having the same 467 | purpose, and constitutes the entirety of the agreement between said 468 | Parties concerning said purpose. No supplement or modification to the 469 | terms and conditions hereof shall be effective as between the Parties 470 | unless it is made in writing and signed by their duly authorized 471 | representatives. 
472 | 473 | 11.4 In the event that one or more of the provisions hereof were to 474 | conflict with a current or future applicable act or legislative text, 475 | said act or legislative text shall prevail, and the Parties shall make 476 | the necessary amendments so as to comply with said act or legislative 477 | text. All other provisions shall remain effective. Similarly, invalidity 478 | of a provision of the Agreement, for any reason whatsoever, shall not 479 | cause the Agreement as a whole to be invalid. 480 | 481 | 482 | 11.5 LANGUAGE 483 | 484 | The Agreement is drafted in both French and English and both versions 485 | are deemed authentic. 486 | 487 | 488 | Article 12 - NEW VERSIONS OF THE AGREEMENT 489 | 490 | 12.1 Any person is authorized to duplicate and distribute copies of this 491 | Agreement. 492 | 493 | 12.2 So as to ensure coherence, the wording of this Agreement is 494 | protected and may only be modified by the authors of the License, who 495 | reserve the right to periodically publish updates or new versions of the 496 | Agreement, each with a separate number. These subsequent versions may 497 | address new issues encountered by Free Software. 498 | 499 | 12.3 Any Software distributed under a given version of the Agreement may 500 | only be subsequently distributed under the same version of the Agreement 501 | or a subsequent version. 502 | 503 | 504 | Article 13 - GOVERNING LAW AND JURISDICTION 505 | 506 | 13.1 The Agreement is governed by French law. The Parties agree to 507 | endeavor to seek an amicable solution to any disagreements or disputes 508 | that may arise during the performance of the Agreement. 509 | 510 | 13.2 Failing an amicable solution within two (2) months as from their 511 | occurrence, and unless emergency proceedings are necessary, the 512 | disagreements or disputes shall be referred to the Paris Courts having 513 | jurisdiction, by the more diligent Party. 514 | -------------------------------------------------------------------------------- /S-Transformer/README.md: -------------------------------------------------------------------------------- 1 | # TrAISformer 2 | 3 | Pytorch implementation of TrAISformer---A generative transformer for AIS trajectory prediction (https://arxiv.org/abs/2109.03958). 4 | 5 | The transformer part is adapted from: https://github.com/karpathy/minGPT 6 | 7 | --- 8 |

9 | 
10 | <img src="./figures/t18_3.png" alt="TrAISformer prediction example">
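
A minimal data-loading sketch (not part of the original repository): it wraps the processed pickle described in the **Datasets** section below with `AISDataset` from `datasets.py`; the path and sequence length follow `config_trAISformer.py`.

```python
import pickle

from torch.utils.data import DataLoader

from datasets import AISDataset

# Each pickle entry is a dict {"mmsi": ..., "traj": [[lat, lon, sog, cog, timestamp], ...]},
# with lat/lon/sog/cog already scaled to [0, 1) (see datasets.py).
with open("./data/ct_dma/ct_dma_train.pkl", "rb") as f:
    l_data = pickle.load(f)

dataset = AISDataset(l_data, max_seqlen=120)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# seq: (32, 120, 4) padded sequences; mask marks the valid (non-padding) steps.
seq, mask, seqlen, mmsi, time_start = next(iter(loader))
```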
11 | 12 | 13 | #### Requirements: 14 | See requirements.yml 15 | 16 | ### Datasets: 17 | 18 | The data used in this paper are provided by the [Danish Maritime Authority (DMA)](https://dma.dk/safety-at-sea/navigational-information/ais-data). 19 | Please refer to [the paper](https://arxiv.org/abs/2109.03958) for the details of the pre-processing step. The code is available here: https://github.com/CIA-Oceanix/GeoTrackNet/blob/master/data/csv2pkl.py 20 | 21 | A processed dataset can be found in `./data/ct_dma/` 22 | (the format is `[lat, log, sog, cog, unix_timestamp, mmsi]`). 23 | 24 | ### Run 25 | 26 | Run `trAISformer.py` to train and evaluate the model. 27 | (Please note that the values given by the code are in km, while the values presented in the paper were converted to nautical mile.) 28 | 29 | 30 | ### License 31 | 32 | See `LICENSE` 33 | 34 | ### Contact 35 | For any questions, please open an issue and assign it to @dnguyengithub. 36 | 37 | -------------------------------------------------------------------------------- /S-Transformer/Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "id": "6a7477f1-1db4-46cf-8e22-7894b8bd4eb5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import matplotlib.pyplot as plt\n", 11 | "import cartopy.crs as ccrs\n", 12 | "import cartopy.feature as cfeature\n", 13 | "import pandas as pd\n", 14 | "import xlrd" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 4, 20 | "id": "9a5ffab3-fbc5-454e-b6bc-773ccceae510", 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Collecting xlrd\n", 28 | " Downloading xlrd-2.0.1-py2.py3-none-any.whl.metadata (3.4 kB)\n", 29 | "Downloading xlrd-2.0.1-py2.py3-none-any.whl (96 kB)\n", 30 | " ---------------------------------------- 0.0/96.5 kB ? eta -:--:--\n", 31 | " ---- ----------------------------------- 10.2/96.5 kB ? eta -:--:--\n", 32 | " -------- ------------------------------- 20.5/96.5 kB 330.3 kB/s eta 0:00:01\n", 33 | " ---------------- ----------------------- 41.0/96.5 kB 279.3 kB/s eta 0:00:01\n", 34 | " ------------------------- -------------- 61.4/96.5 kB 328.2 kB/s eta 0:00:01\n", 35 | " ----------------------------- ---------- 71.7/96.5 kB 326.8 kB/s eta 0:00:01\n", 36 | " ---------------------------------------- 96.5/96.5 kB 425.1 kB/s eta 0:00:00\n", 37 | "Installing collected packages: xlrd\n", 38 | "Successfully installed xlrd-2.0.1\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "! 
pip install xlrd" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 9, 49 | "id": "7d99710b-9312-4ba1-8211-a62a6dd6d1fe", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "ename": "XLRDError", 54 | "evalue": "Unsupported format, or corrupt file: Expected BOF record; found b'SeNo\\tBat'", 55 | "output_type": "error", 56 | "traceback": [ 57 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 58 | "\u001b[1;31mXLRDError\u001b[0m Traceback (most recent call last)", 59 | "Cell \u001b[1;32mIn[9], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# 打开 Excel 文件\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m workbook \u001b[38;5;241m=\u001b[39m \u001b[43mxlrd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen_workbook\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mD:/AA_work/AIS数据/18日.xls\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4\u001b[0m \u001b[38;5;66;03m# 获取第一个工作表\u001b[39;00m\n\u001b[0;32m 5\u001b[0m sheet \u001b[38;5;241m=\u001b[39m workbook\u001b[38;5;241m.\u001b[39msheet_by_index(\u001b[38;5;241m0\u001b[39m)\n", 60 | "File \u001b[1;32mD:\\anaconda3\\envs\\Deep_learn_GPU\\Lib\\site-packages\\xlrd\\__init__.py:172\u001b[0m, in \u001b[0;36mopen_workbook\u001b[1;34m(filename, logfile, verbosity, use_mmap, file_contents, encoding_override, formatting_info, on_demand, ragged_rows, ignore_workbook_corruption)\u001b[0m\n\u001b[0;32m 169\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file_format \u001b[38;5;129;01mand\u001b[39;00m file_format \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mxls\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 170\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m XLRDError(FILE_FORMAT_DESCRIPTIONS[file_format]\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m; not supported\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m--> 172\u001b[0m bk \u001b[38;5;241m=\u001b[39m \u001b[43mopen_workbook_xls\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 173\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogfile\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogfile\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbosity\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbosity\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_mmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_mmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 177\u001b[0m \u001b[43m \u001b[49m\u001b[43mfile_contents\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfile_contents\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 178\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding_override\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoding_override\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 179\u001b[0m \u001b[43m \u001b[49m\u001b[43mformatting_info\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mformatting_info\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 180\u001b[0m \u001b[43m \u001b[49m\u001b[43mon_demand\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mon_demand\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 181\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mragged_rows\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mragged_rows\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 182\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_workbook_corruption\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_workbook_corruption\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 183\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 185\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m bk\n", 61 | "File \u001b[1;32mD:\\anaconda3\\envs\\Deep_learn_GPU\\Lib\\site-packages\\xlrd\\book.py:79\u001b[0m, in \u001b[0;36mopen_workbook_xls\u001b[1;34m(filename, logfile, verbosity, use_mmap, file_contents, encoding_override, formatting_info, on_demand, ragged_rows, ignore_workbook_corruption)\u001b[0m\n\u001b[0;32m 77\u001b[0m t1 \u001b[38;5;241m=\u001b[39m perf_counter()\n\u001b[0;32m 78\u001b[0m bk\u001b[38;5;241m.\u001b[39mload_time_stage_1 \u001b[38;5;241m=\u001b[39m t1 \u001b[38;5;241m-\u001b[39m t0\n\u001b[1;32m---> 79\u001b[0m biff_version \u001b[38;5;241m=\u001b[39m \u001b[43mbk\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgetbof\u001b[49m\u001b[43m(\u001b[49m\u001b[43mXL_WORKBOOK_GLOBALS\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 80\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m biff_version:\n\u001b[0;32m 81\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m XLRDError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt determine file\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms BIFF version\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", 62 | "File \u001b[1;32mD:\\anaconda3\\envs\\Deep_learn_GPU\\Lib\\site-packages\\xlrd\\book.py:1284\u001b[0m, in \u001b[0;36mBook.getbof\u001b[1;34m(self, rqd_stream)\u001b[0m\n\u001b[0;32m 1282\u001b[0m bof_error(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mExpected BOF record; met end of file\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 1283\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m opcode \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m bofcodes:\n\u001b[1;32m-> 1284\u001b[0m \u001b[43mbof_error\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mExpected BOF record; found \u001b[39;49m\u001b[38;5;132;43;01m%r\u001b[39;49;00m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m%\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmem\u001b[49m\u001b[43m[\u001b[49m\u001b[43msavpos\u001b[49m\u001b[43m:\u001b[49m\u001b[43msavpos\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1285\u001b[0m length \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget2bytes()\n\u001b[0;32m 1286\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m length \u001b[38;5;241m==\u001b[39m MY_EOF:\n", 63 | "File \u001b[1;32mD:\\anaconda3\\envs\\Deep_learn_GPU\\Lib\\site-packages\\xlrd\\book.py:1278\u001b[0m, in \u001b[0;36mBook.getbof..bof_error\u001b[1;34m(msg)\u001b[0m\n\u001b[0;32m 1277\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbof_error\u001b[39m(msg):\n\u001b[1;32m-> 1278\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m XLRDError(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mUnsupported format, or corrupt file: \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m msg)\n", 64 | "\u001b[1;31mXLRDError\u001b[0m: Unsupported format, or corrupt file: Expected BOF 
record; found b'SeNo\\tBat'" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "\n", 70 | "# 打开 Excel 文件\n", 71 | "workbook = xlrd.open_workbook('D:/AA_work/AIS数据/18日.xls')\n", 72 | "\n", 73 | "# 获取第一个工作表\n", 74 | "sheet = workbook.sheet_by_index(0)\n", 75 | "\n", 76 | "# 获取行数和列数\n", 77 | "num_rows = sheet.nrows\n", 78 | "num_cols = sheet.ncols\n", 79 | "num_cols" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "d47e2e20-c3ea-418a-b7d5-d2e6ad900687", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [] 89 | } 90 | ], 91 | "metadata": { 92 | "kernelspec": { 93 | "display_name": "Python 3 (ipykernel)", 94 | "language": "python", 95 | "name": "python3" 96 | }, 97 | "language_info": { 98 | "codemirror_mode": { 99 | "name": "ipython", 100 | "version": 3 101 | }, 102 | "file_extension": ".py", 103 | "mimetype": "text/x-python", 104 | "name": "python", 105 | "nbconvert_exporter": "python", 106 | "pygments_lexer": "ipython3", 107 | "version": "3.11.7" 108 | } 109 | }, 110 | "nbformat": 4, 111 | "nbformat_minor": 5 112 | } 113 | -------------------------------------------------------------------------------- /S-Transformer/config_trAISformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021, Duong Nguyen 3 | # 4 | # Licensed under the CECILL-C License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.cecill.info 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Configuration flags to run the main script. 
17 | """ 18 | 19 | import os 20 | import pickle 21 | import torch 22 | 23 | 24 | class Config(): 25 | retrain = True 26 | tb_log = False 27 | device = torch.device("cuda:0") 28 | # device = torch.device("cpu") 29 | 30 | max_epochs = 50 31 | batch_size = 32 32 | n_samples = 16 33 | 34 | init_seqlen = 18 35 | max_seqlen = 120 36 | min_seqlen = 36 37 | 38 | dataset_name = "ct_dma" 39 | 40 | if dataset_name == "ct_dma": #============================== 41 | 42 | # When mode == "grad" or "pos_grad", sog and cog are actually dlat and 43 | # dlon 44 | lat_size = 250 45 | lon_size = 270 46 | sog_size = 30 47 | cog_size = 72 48 | 49 | 50 | n_lat_embd = 256 51 | n_lon_embd = 256 52 | n_sog_embd = 128 53 | n_cog_embd = 128 54 | 55 | lat_min = 55.5 56 | lat_max = 58.0 57 | lon_min = 10.3 58 | lon_max = 13 59 | 60 | 61 | #=========================================================================== 62 | # Model and sampling flags 63 | mode = "pos" #"pos", "pos_grad", "mlp_pos", "mlpgrid_pos", "velo", "grid_l2", "grid_l1", 64 | # "ce_vicinity", "gridcont_grid", "gridcont_real", "gridcont_gridsin", "gridcont_gridsigmoid" 65 | sample_mode = "pos_vicinity" # "pos", "pos_vicinity" or "velo" 66 | top_k = 10 # int or None 67 | r_vicinity = 40 # int 68 | 69 | # Blur flags 70 | #=================================================== 71 | blur = True 72 | blur_learnable = False 73 | blur_loss_w = 1.0 74 | blur_n = 2 75 | if not blur: 76 | blur_n = 0 77 | blur_loss_w = 0 78 | 79 | # Data flags 80 | #=================================================== 81 | datadir = f"./data/{dataset_name}/" 82 | trainset_name = f"{dataset_name}_train.pkl" 83 | validset_name = f"{dataset_name}_valid.pkl" 84 | testset_name = f"{dataset_name}_test.pkl" 85 | 86 | 87 | # model parameters 88 | #=================================================== 89 | n_head = 8 90 | n_layer = 8 91 | full_size = lat_size + lon_size + sog_size + cog_size 92 | n_embd = n_lat_embd + n_lon_embd + n_sog_embd + n_cog_embd 93 | # base GPT config, params common to all GPT versions 94 | embd_pdrop = 0.1 95 | resid_pdrop = 0.1 96 | attn_pdrop = 0.1 97 | 98 | # optimization parameters 99 | #=================================================== 100 | learning_rate = 6e-4 # 6e-4 101 | betas = (0.9, 0.95) 102 | grad_norm_clip = 1.0 103 | weight_decay = 0.1 # only applied on matmul weights 104 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original 105 | lr_decay = True 106 | warmup_tokens = 512*20 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere 107 | final_tokens = 260e9 # (at what point we reach 10% of original LR) 108 | num_workers = 4 # for DataLoader 109 | 110 | filename = f"{dataset_name}"\ 111 | + f"-{mode}-{sample_mode}-{top_k}-{r_vicinity}"\ 112 | + f"-blur-{blur}-{blur_learnable}-{blur_n}-{blur_loss_w}"\ 113 | + f"-data_size-{lat_size}-{lon_size}-{sog_size}-{cog_size}"\ 114 | + f"-embd_size-{n_lat_embd}-{n_lon_embd}-{n_sog_embd}-{n_cog_embd}"\ 115 | + f"-head-{n_head}-{n_layer}"\ 116 | + f"-bs-{batch_size}"\ 117 | + f"-lr-{learning_rate}"\ 118 | + f"-seqlen-{init_seqlen}-{max_seqlen}" 119 | savedir = "./results/"+filename+"/" 120 | 121 | ckpt_path = os.path.join(savedir,"model.pt") -------------------------------------------------------------------------------- /S-Transformer/datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021, Duong Nguyen 3 | # 4 | # Licensed under the CECILL-C License; 5 | # you may not use this 
file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.cecill.info 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 这两个数据集类可以帮助用户加载和处理 AIS 船舶轨迹数据,并将其转换为 PyTorch 模型所需的张量格式。 16 | """Customized Pytorch Dataset. 17 | """ 18 | 19 | import numpy as np 20 | import os 21 | import pickle 22 | 23 | import torch 24 | from torch.utils.data import Dataset, DataLoader 25 | 26 | class AISDataset(Dataset): 27 | """Customized Pytorch dataset. 28 | """ 29 | def __init__(self, 30 | l_data, 31 | max_seqlen=96, 32 | dtype=torch.float32, 33 | device=torch.device("cpu")): 34 | """ 35 | Args 36 | l_data: list of dictionaries, each element is an AIS trajectory. 37 | l_data[idx]["mmsi"]: vessel's MMSI. 38 | l_data[idx]["traj"]: a matrix whose columns are 39 | [LAT, LON, SOG, COG, TIMESTAMP] 40 | lat, lon, sog, and cod have been standardized, i.e. range = [0,1). 41 | max_seqlen: (optional) max sequence length. Default is 42 | """ 43 | 44 | self.max_seqlen = max_seqlen 45 | self.device = device 46 | 47 | self.l_data = l_data 48 | 49 | def __len__(self): 50 | return len(self.l_data) 51 | 52 | def __getitem__(self, idx): 53 | """Gets items. 54 | 55 | Returns: 56 | seq: Tensor of (max_seqlen, [lat,lon,sog,cog]). 57 | mask: Tensor of (max_seqlen, 1). mask[i] = 0.0 if x[i] is a 58 | padding. 59 | seqlen: sequence length. 60 | mmsi: vessel's MMSI. 61 | time_start: timestamp of the starting time of the trajectory. 62 | """ 63 | V = self.l_data[idx] 64 | m_v = V["traj"][:,:4] # lat, lon, sog, cog 65 | # m_v[m_v==1] = 0.9999 66 | m_v[m_v>0.9999] = 0.9999 67 | seqlen = min(len(m_v), self.max_seqlen) 68 | seq = np.zeros((self.max_seqlen,4)) 69 | seq[:seqlen,:] = m_v[:seqlen,:] 70 | seq = torch.tensor(seq, dtype=torch.float32) 71 | 72 | mask = torch.zeros(self.max_seqlen) 73 | mask[:seqlen] = 1. 74 | 75 | seqlen = torch.tensor(seqlen, dtype=torch.int) 76 | mmsi = torch.tensor(V["mmsi"], dtype=torch.int) 77 | time_start = torch.tensor(V["traj"][0,4], dtype=torch.int) 78 | 79 | return seq , mask, seqlen, mmsi, time_start 80 | 81 | class AISDataset_grad(Dataset): 82 | """Customized Pytorch dataset. 83 | Return the positions and the gradient of the positions. 84 | """ 85 | def __init__(self, 86 | l_data, 87 | dlat_max=0.04, 88 | dlon_max=0.04, 89 | max_seqlen=96, 90 | dtype=torch.float32, 91 | device=torch.device("cpu")): 92 | """ 93 | Args 94 | l_data: list of dictionaries, each element is an AIS trajectory. 95 | l_data[idx]["mmsi"]: vessel's MMSI. 96 | l_data[idx]["traj"]: a matrix whose columns are 97 | [LAT, LON, SOG, COG, TIMESTAMP] 98 | lat, lon, sog, and cod have been standardized, i.e. range = [0,1). 99 | dlat_max, dlon_max: the maximum value of the gradient of the positions. 100 | dlat_max = max(lat[idx+1]-lat[idx]) for all idx. 101 | max_seqlen: (optional) max sequence length. Default is 102 | """ 103 | 104 | self.dlat_max = dlat_max 105 | self.dlon_max = dlon_max 106 | self.dpos_max = np.array([dlat_max, dlon_max]) 107 | self.max_seqlen = max_seqlen 108 | self.device = device 109 | 110 | self.l_data = l_data 111 | 112 | def __len__(self): 113 | return len(self.l_data) 114 | 115 | def __getitem__(self, idx): 116 | """Gets items. 
117 | 118 | Returns: 119 | seq: Tensor of (max_seqlen, [lat,lon,sog,cog]). 120 | mask: Tensor of (max_seqlen, 1). mask[i] = 0.0 if x[i] is a 121 | padding. 122 | seqlen: sequence length. 123 | mmsi: vessel's MMSI. 124 | time_start: timestamp of the starting time of the trajectory. 125 | """ 126 | V = self.l_data[idx] 127 | m_v = V["traj"][:,:4] # lat, lon, sog, cog 128 | m_v[m_v==1] = 0.9999 129 | seqlen = min(len(m_v), self.max_seqlen) 130 | seq = np.zeros((self.max_seqlen,4)) 131 | # lat and lon 132 | seq[:seqlen,:2] = m_v[:seqlen,:2] 133 | # dlat and dlon 134 | dpos = (m_v[1:,:2]-m_v[:-1,:2]+self.dpos_max )/(2*self.dpos_max ) 135 | dpos = np.concatenate((dpos[:1,:],dpos),axis=0) 136 | dpos[dpos>=1] = 0.9999 137 | dpos[dpos<=0] = 0.0 138 | seq[:seqlen,2:] = dpos[:seqlen,:2] 139 | 140 | # convert to Tensor 141 | seq = torch.tensor(seq, dtype=torch.float32) 142 | 143 | mask = torch.zeros(self.max_seqlen) 144 | mask[:seqlen] = 1. 145 | 146 | seqlen = torch.tensor(seqlen, dtype=torch.int) 147 | mmsi = torch.tensor(V["mmsi"], dtype=torch.int) 148 | time_start = torch.tensor(V["traj"][0,4], dtype=torch.int) 149 | 150 | return seq , mask, seqlen, mmsi, time_start -------------------------------------------------------------------------------- /S-Transformer/figures/t18_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/S-Transformer/figures/t18_3.png -------------------------------------------------------------------------------- /S-Transformer/models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021, Duong Nguyen 3 | # 4 | # Licensed under the CECILL-C License; 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.cecill.info 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Models for TrAISformer. 17 | https://arxiv.org/abs/2109.03958 18 | 19 | The code is built upon: 20 | https://github.com/karpathy/minGPT 21 | """ 22 | 23 | import math 24 | import logging 25 | import pdb 26 | 27 | 28 | import torch 29 | import torch.nn as nn 30 | from torch.nn import functional as F 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | class CausalSelfAttention(nn.Module): 36 | """ 37 | A vanilla multi-head masked self-attention layer with a projection at the end. 38 | It is possible to use torch.nn.MultiheadAttention here but I am including an 39 | explicit implementation here to show that there is nothing too scary here. 
40 | """ 41 | 42 | def __init__(self, config): 43 | super().__init__() 44 | assert config.n_embd % config.n_head == 0 45 | # key, query, value projections for all heads 46 | self.key = nn.Linear(config.n_embd, config.n_embd) 47 | self.query = nn.Linear(config.n_embd, config.n_embd) 48 | self.value = nn.Linear(config.n_embd, config.n_embd) 49 | # regularization 50 | self.attn_drop = nn.Dropout(config.attn_pdrop) 51 | self.resid_drop = nn.Dropout(config.resid_pdrop) 52 | # output projection 53 | self.proj = nn.Linear(config.n_embd, config.n_embd) 54 | # causal mask to ensure that attention is only applied to the left in the input sequence 55 | self.register_buffer("mask", torch.tril(torch.ones(config.max_seqlen, config.max_seqlen)) 56 | .view(1, 1, config.max_seqlen, config.max_seqlen)) 57 | self.n_head = config.n_head 58 | 59 | def forward(self, x, layer_past=None): 60 | B, T, C = x.size() 61 | 62 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 63 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 64 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 65 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 66 | 67 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 68 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 69 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 70 | att = F.softmax(att, dim=-1) 71 | att = self.attn_drop(att) 72 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 73 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 74 | 75 | # output projection 76 | y = self.resid_drop(self.proj(y)) 77 | return y 78 | 79 | class Block(nn.Module): 80 | """ an unassuming Transformer block """ 81 | 82 | def __init__(self, config): 83 | super().__init__() 84 | self.ln1 = nn.LayerNorm(config.n_embd) 85 | self.ln2 = nn.LayerNorm(config.n_embd) 86 | self.attn = CausalSelfAttention(config) 87 | self.mlp = nn.Sequential( 88 | nn.Linear(config.n_embd, 4 * config.n_embd), 89 | nn.GELU(), 90 | nn.Linear(4 * config.n_embd, config.n_embd), 91 | nn.Dropout(config.resid_pdrop), 92 | ) 93 | 94 | def forward(self, x): 95 | x = x + self.attn(self.ln1(x)) 96 | x = x + self.mlp(self.ln2(x)) 97 | return x 98 | 99 | class TrAISformer(nn.Module): 100 | """Transformer for AIS trajectories.""" 101 | 102 | def __init__(self, config, partition_model = None): 103 | super().__init__() 104 | 105 | self.lat_size = config.lat_size 106 | self.lon_size = config.lon_size 107 | self.sog_size = config.sog_size 108 | self.cog_size = config.cog_size 109 | self.full_size = config.full_size 110 | self.n_lat_embd = config.n_lat_embd 111 | self.n_lon_embd = config.n_lon_embd 112 | self.n_sog_embd = config.n_sog_embd 113 | self.n_cog_embd = config.n_cog_embd 114 | self.register_buffer( 115 | "att_sizes", 116 | torch.tensor([config.lat_size, config.lon_size, config.sog_size, config.cog_size])) 117 | self.register_buffer( 118 | "emb_sizes", 119 | torch.tensor([config.n_lat_embd, config.n_lon_embd, config.n_sog_embd, config.n_cog_embd])) 120 | 121 | if hasattr(config,"partition_mode"): 122 | self.partition_mode = config.partition_mode 123 | else: 124 | self.partition_mode = "uniform" 125 | self.partition_model = partition_model 126 | 127 | if hasattr(config,"blur"): 128 | self.blur = config.blur 129 | self.blur_learnable = 
config.blur_learnable 130 | self.blur_loss_w = config.blur_loss_w 131 | self.blur_n = config.blur_n 132 | if self.blur: 133 | self.blur_module = nn.Conv1d(1, 1, 3, padding = 1, padding_mode = 'replicate', groups=1, bias=False) 134 | if not self.blur_learnable: 135 | for params in self.blur_module.parameters(): 136 | params.requires_grad = False 137 | params.fill_(1/3) 138 | else: 139 | self.blur_module = None 140 | 141 | 142 | if hasattr(config,"lat_min"): # the ROI is provided. 143 | self.lat_min = config.lat_min 144 | self.lat_max = config.lat_max 145 | self.lon_min = config.lon_min 146 | self.lon_max = config.lon_max 147 | self.lat_range = config.lat_max-config.lat_min 148 | self.lon_range = config.lon_max-config.lon_min 149 | self.sog_range = 30. 150 | 151 | if hasattr(config,"mode"): # mode: "pos" or "velo". 152 | # "pos": predict directly the next positions. 153 | # "velo": predict the velocities, use them to 154 | # calculate the next positions. 155 | self.mode = config.mode 156 | else: 157 | self.mode = "pos" 158 | 159 | 160 | # Passing from the 4-D space to a high-dimentional space 161 | self.lat_emb = nn.Embedding(self.lat_size, config.n_lat_embd) 162 | self.lon_emb = nn.Embedding(self.lon_size, config.n_lon_embd) 163 | self.sog_emb = nn.Embedding(self.sog_size, config.n_sog_embd) 164 | self.cog_emb = nn.Embedding(self.cog_size, config.n_cog_embd) 165 | 166 | 167 | self.pos_emb = nn.Parameter(torch.zeros(1, config.max_seqlen, config.n_embd)) 168 | self.drop = nn.Dropout(config.embd_pdrop) 169 | 170 | # transformer 171 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 172 | 173 | 174 | # decoder head 175 | self.ln_f = nn.LayerNorm(config.n_embd) 176 | if self.mode in ("mlp_pos","mlp"): 177 | self.head = nn.Linear(config.n_embd, config.n_embd, bias=False) 178 | else: 179 | self.head = nn.Linear(config.n_embd, self.full_size, bias=False) # Classification head 180 | 181 | self.max_seqlen = config.max_seqlen 182 | self.apply(self._init_weights) 183 | 184 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) 185 | 186 | def get_max_seqlen(self): 187 | return self.max_seqlen 188 | 189 | def _init_weights(self, module): 190 | if isinstance(module, (nn.Linear, nn.Embedding)): 191 | module.weight.data.normal_(mean=0.0, std=0.02) 192 | if isinstance(module, nn.Linear) and module.bias is not None: 193 | module.bias.data.zero_() 194 | elif isinstance(module, nn.LayerNorm): 195 | module.bias.data.zero_() 196 | module.weight.data.fill_(1.0) 197 | 198 | def configure_optimizers(self, train_config): 199 | """ 200 | This long function is unfortunately doing something very simple and is being very defensive: 201 | We are separating out all parameters of the model into two buckets: those that will experience 202 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 203 | We are then returning the PyTorch optimizer object. 
204 | """ 205 | 206 | # separate out all parameters to those that will and won't experience regularizing weight decay 207 | decay = set() 208 | no_decay = set() 209 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d) 210 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 211 | for mn, m in self.named_modules(): 212 | for pn, p in m.named_parameters(): 213 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 214 | 215 | if pn.endswith('bias'): 216 | # all biases will not be decayed 217 | no_decay.add(fpn) 218 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 219 | # weights of whitelist modules will be weight decayed 220 | decay.add(fpn) 221 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 222 | # weights of blacklist modules will NOT be weight decayed 223 | no_decay.add(fpn) 224 | 225 | # special case the position embedding parameter in the root GPT module as not decayed 226 | no_decay.add('pos_emb') 227 | 228 | # validate that we considered every parameter 229 | param_dict = {pn: p for pn, p in self.named_parameters()} 230 | inter_params = decay & no_decay 231 | union_params = decay | no_decay 232 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 233 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 234 | % (str(param_dict.keys() - union_params), ) 235 | 236 | # create the pytorch optimizer object 237 | optim_groups = [ 238 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, 239 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 240 | ] 241 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) 242 | return optimizer 243 | 244 | 245 | def to_indexes(self, x, mode="uniform"): 246 | """Convert tokens to indexes. 247 | 248 | Args: 249 | x: a Tensor of size (batchsize, seqlen, 4). x has been truncated 250 | to [0,1). 251 | model: currenly only supports "uniform". 252 | 253 | Returns: 254 | idxs: a Tensor (dtype: Long) of indexes. 255 | """ 256 | bs, seqlen, data_dim = x.shape 257 | if mode == "uniform": 258 | idxs = (x*self.att_sizes).long() 259 | return idxs, idxs 260 | elif mode in ("freq", "freq_uniform"): 261 | 262 | idxs = (x*self.att_sizes).long() 263 | idxs_uniform = idxs.clone() 264 | discrete_lats, discrete_lons, lat_ids, lon_ids = self.partition_model(x[:,:,:2]) 265 | # pdb.set_trace() 266 | idxs[:,:,0] = torch.round(lat_ids.reshape((bs,seqlen))).long() 267 | idxs[:,:,1] = torch.round(lon_ids.reshape((bs,seqlen))).long() 268 | return idxs, idxs_uniform 269 | 270 | 271 | def forward(self, x, masks = None, with_targets=False, return_loss_tuple=False): 272 | """ 273 | Args: 274 | x: a Tensor of size (batchsize, seqlen, 4). x has been truncated 275 | to [0,1). 276 | masks: a Tensor of the same size of x. masks[idx] = 0. if 277 | x[idx] is a padding. 278 | with_targets: if True, inputs = x[:,:-1,:], targets = x[:,1:,:], 279 | otherwise inputs = x. 280 | Returns: 281 | logits, loss 282 | """ 283 | 284 | if self.mode in ("mlp_pos","mlp",): 285 | idxs, idxs_uniform = x, x # use the real-values of x. 
286 | else: 287 | # Convert to indexes 288 | idxs, idxs_uniform = self.to_indexes(x, mode=self.partition_mode) 289 | 290 | if with_targets: 291 | inputs = idxs[:,:-1,:].contiguous() 292 | targets = idxs[:,1:,:].contiguous() 293 | targets_uniform = idxs_uniform[:,1:,:].contiguous() 294 | inputs_real = x[:,:-1,:].contiguous() 295 | targets_real = x[:,1:,:].contiguous() 296 | else: 297 | inputs_real = x 298 | inputs = idxs 299 | targets = None 300 | 301 | batchsize, seqlen, _ = inputs.size() 302 | assert seqlen <= self.max_seqlen, "Cannot forward, model block size is exhausted." 303 | 304 | # forward the GPT model 305 | lat_embeddings = self.lat_emb(inputs[:,:,0]) # (bs, seqlen, lat_size) 306 | lon_embeddings = self.lon_emb(inputs[:,:,1]) 307 | sog_embeddings = self.sog_emb(inputs[:,:,2]) 308 | cog_embeddings = self.cog_emb(inputs[:,:,3]) 309 | token_embeddings = torch.cat((lat_embeddings, lon_embeddings, sog_embeddings, cog_embeddings),dim=-1) 310 | 311 | position_embeddings = self.pos_emb[:, :seqlen, :] # each position maps to a (learnable) vector (1, seqlen, n_embd) 312 | fea = self.drop(token_embeddings + position_embeddings) 313 | fea = self.blocks(fea) 314 | fea = self.ln_f(fea) # (bs, seqlen, n_embd) 315 | logits = self.head(fea) # (bs, seqlen, full_size) or (bs, seqlen, n_embd) 316 | 317 | lat_logits, lon_logits, sog_logits, cog_logits =\ 318 | torch.split(logits, (self.lat_size, self.lon_size, self.sog_size, self.cog_size), dim=-1) 319 | 320 | # Calculate the loss 321 | loss = None 322 | loss_tuple = None 323 | if targets is not None: 324 | 325 | sog_loss = F.cross_entropy(sog_logits.view(-1, self.sog_size), 326 | targets[:,:,2].view(-1), 327 | reduction="none").view(batchsize,seqlen) 328 | cog_loss = F.cross_entropy(cog_logits.view(-1, self.cog_size), 329 | targets[:,:,3].view(-1), 330 | reduction="none").view(batchsize,seqlen) 331 | lat_loss = F.cross_entropy(lat_logits.view(-1, self.lat_size), 332 | targets[:,:,0].view(-1), 333 | reduction="none").view(batchsize,seqlen) 334 | lon_loss = F.cross_entropy(lon_logits.view(-1, self.lon_size), 335 | targets[:,:,1].view(-1), 336 | reduction="none").view(batchsize,seqlen) 337 | 338 | if self.blur: 339 | lat_probs = F.softmax(lat_logits, dim=-1) 340 | lon_probs = F.softmax(lon_logits, dim=-1) 341 | sog_probs = F.softmax(sog_logits, dim=-1) 342 | cog_probs = F.softmax(cog_logits, dim=-1) 343 | 344 | for _ in range(self.blur_n): 345 | blurred_lat_probs = self.blur_module(lat_probs.reshape(-1,1,self.lat_size)).reshape(lat_probs.shape) 346 | blurred_lon_probs = self.blur_module(lon_probs.reshape(-1,1,self.lon_size)).reshape(lon_probs.shape) 347 | blurred_sog_probs = self.blur_module(sog_probs.reshape(-1,1,self.sog_size)).reshape(sog_probs.shape) 348 | blurred_cog_probs = self.blur_module(cog_probs.reshape(-1,1,self.cog_size)).reshape(cog_probs.shape) 349 | 350 | blurred_lat_loss = F.nll_loss(blurred_lat_probs.view(-1, self.lat_size), 351 | targets[:,:,0].view(-1), 352 | reduction="none").view(batchsize,seqlen) 353 | blurred_lon_loss = F.nll_loss(blurred_lon_probs.view(-1, self.lon_size), 354 | targets[:,:,1].view(-1), 355 | reduction="none").view(batchsize,seqlen) 356 | blurred_sog_loss = F.nll_loss(blurred_sog_probs.view(-1, self.sog_size), 357 | targets[:,:,2].view(-1), 358 | reduction="none").view(batchsize,seqlen) 359 | blurred_cog_loss = F.nll_loss(blurred_cog_probs.view(-1, self.cog_size), 360 | targets[:,:,3].view(-1), 361 | reduction="none").view(batchsize,seqlen) 362 | 363 | lat_loss += self.blur_loss_w*blurred_lat_loss 364 | lon_loss 
+= self.blur_loss_w*blurred_lon_loss 365 | sog_loss += self.blur_loss_w*blurred_sog_loss 366 | cog_loss += self.blur_loss_w*blurred_cog_loss 367 | 368 | lat_probs = blurred_lat_probs 369 | lon_probs = blurred_lon_probs 370 | sog_probs = blurred_sog_probs 371 | cog_probs = blurred_cog_probs 372 | 373 | 374 | loss_tuple = (lat_loss, lon_loss, sog_loss, cog_loss) 375 | loss = sum(loss_tuple) 376 | 377 | if masks is not None: 378 | loss = (loss*masks).sum(dim=1)/masks.sum(dim=1) 379 | 380 | loss = loss.mean() 381 | 382 | if return_loss_tuple: 383 | return logits, loss, loss_tuple 384 | else: 385 | return logits, loss 386 | 387 | -------------------------------------------------------------------------------- /S-Transformer/requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | - _libgcc_mutex=0.1 3 | - _py-xgboost-mutex=2.0 4 | - aiohttp=3.7.4.post0 5 | - argon2-cffi=20.1.0 6 | - async-timeout=3.0.1 7 | - async_generator=1.10 8 | - attrs=20.3.0 9 | - blas=1.0 10 | - bleach=3.3.0 11 | - blinker=1.4 12 | - bokeh=2.3.3 13 | - bottleneck=1.3.2 14 | - brotlipy=0.7.0 15 | - c-ares=1.17.1 16 | - ca-certificates=2021.5.30 17 | - catalogue=1.0.0 18 | - catboost=0.26 19 | - certifi=2021.5.30 20 | - cffi=1.14.5 21 | - cftime=1.5.0 22 | - chardet=4.0.0 23 | - cloudpickle=1.6.0 24 | - confuse=1.4.0 25 | - cryptography=3.4.7 26 | - cudatoolkit=9.2 27 | - curl=7.71.1 28 | - cycler=0.10.0 29 | - cymem=2.0.5 30 | - cython-blis=0.7.4 31 | - cytoolz=0.9.0.1 32 | - dask=2021.7.2 33 | - dask-core=2021.7.2 34 | - dataclasses=0.8 35 | - decorator=5.0.5 36 | - defusedxml=0.7.1 37 | - dill=0.2.9 38 | - distributed=2021.7.2 39 | - entrypoints=0.3 40 | - freetype=2.10.4 41 | - fsspec=2021.7.0 42 | - future=0.18.2 43 | - hdf4=4.2.13 44 | - hdf5=1.10.6 45 | - heapdict=1.0.1 46 | - htmlmin=0.1.12 47 | - idna=2.10 48 | - imagehash=4.2.0 49 | - imbalanced-learn=0.8.0 50 | - importlib-metadata=3.7.3 51 | - importlib_metadata=3.7.3 52 | - intel-openmp=2020.2 53 | - ipykernel=5.3.4 54 | - ipython=5.8.0 55 | - ipython_genutils=0.2.0 56 | - ipywidgets=7.5.1 57 | - jinja2=2.11.3 58 | - jpeg=9b 59 | - json5=0.9.5 60 | - jsonschema=3.0.2 61 | - jupyter_client=6.1.12 62 | - jupyter_core=4.7.1 63 | - jupyterlab=2.2.6 64 | - jupyterlab_pygments=0.1.2 65 | - jupyterlab_server=1.2.0 66 | - kiwisolver=1.3.1 67 | - krb5=1.18.2 68 | - lcms2=2.12 69 | - libcurl=7.71.1 70 | - libffi=3.3 71 | - libllvm10=10.0.1 72 | - libnetcdf=4.6.1 73 | - libpng=1.6.37 74 | - libprotobuf=3.17.2 75 | - libsodium=1.0.18 76 | - libssh2=1.9.0 77 | - libtiff=4.1.0 78 | - libxgboost=1.3.3 79 | - lightgbm=3.1.1 80 | - llvmlite=0.36.0 81 | - locket=0.2.0 82 | - lz4-c=1.9.3 83 | - markupsafe=1.1.1 84 | - matplotlib=3.3.2 85 | - matplotlib-base=3.3.2 86 | - missingno=0.4.2 87 | - mistune=0.8.4 88 | - mkl=2020.2 89 | - mkl-service=2.3.0 90 | - mkl_fft=1.3.0 91 | - mkl_random=1.1.1 92 | - msgpack-numpy=0.4.7.1 93 | - msgpack-python=1.0.2 94 | - multidict=5.1.0 95 | - nb_conda=2.2.1 96 | - nb_conda_kernels=2.3.1 97 | - nbclient=0.5.3 98 | - nbconvert=6.0.7 99 | - nbformat=5.1.3 100 | - nest-asyncio=1.5.1 101 | - netcdf4=1.5.7 102 | - networkx=2.5 103 | - ninja=1.10.2 104 | - notebook=6.3.0 105 | - numba=0.53.1 106 | - numpy=1.19.2 107 | - numpy-base=1.19.2 108 | - olefile=0.46 109 | - openssl=1.1.1k 110 | - packaging=20.9 111 | - pandas=1.2.3 112 | - pandas-profiling=2.9.0 113 | - pandoc=2.12 114 | - pandocfilters=1.4.3 115 | - partd=1.2.0 116 | - pexpect=4.8.0 117 | - phik=0.11.2 118 | - pickleshare=0.7.5 119 | - 
pillow=8.2.0 120 | - pip=21.0.1 121 | - plac=1.1.0 122 | - preshed=3.0.2 123 | - proj=7.0.1 124 | - prometheus_client=0.10.0 125 | - prompt_toolkit=1.0.15 126 | - psutil=5.8.0 127 | - ptyprocess=0.7.0 128 | - pyasn1=0.4.8 129 | - pycparser=2.20 130 | - pydeprecate=0.3.1 131 | - pygments=2.8.1 132 | - pyjwt=2.1.0 133 | - pyopenssl=20.0.1 134 | - pyparsing=2.4.7 135 | - pyproj=2.6.1.post1 136 | - pyrsistent=0.17.3 137 | - pysocks=1.7.1 138 | - python=3.7.9 139 | - python-dateutil=2.8.1 140 | - python_abi=3.7 141 | - pytorch=1.6.0 142 | - pytorch-lightning=1.4.4 143 | - pytz=2021.1 144 | - pyu2f=0.1.5 145 | - pywavelets=1.1.1 146 | - pyzmq=20.0.0 147 | - requests=2.25.1 148 | - requests-oauthlib=1.3.0 149 | - seaborn=0.11.1 150 | - send2trash=1.5.0 151 | - setuptools=52.0.0 152 | - simplegeneric=0.8.1 153 | - six=1.15.0 154 | - sortedcontainers=2.4.0 155 | - spacy=2.3.5 156 | - sqlite=3.35.4 157 | - srsly=1.0.5 158 | - tangled-up-in-unicode=0.1.0 159 | - tbb=2020.3 160 | - tblib=1.7.0 161 | - tensorboard-data-server=0.6.0 162 | - terminado=0.9.4 163 | - testpath=0.4.4 164 | - thinc=7.4.5 165 | - threadpoolctl=2.1.0 166 | - tk=8.6.10 167 | - toolz=0.11.1 168 | - torchmetrics=0.5.0 169 | - torchtext=0.7.0 170 | - torchvision=0.7.0 171 | - tornado=6.1 172 | - traitlets=5.0.5 173 | - typing-extensions=3.10.0.0 174 | - typing_extensions=3.10.0.0 175 | - ujson=4.0.2 176 | - visions=0.5.0 177 | - wasabi=0.8.2 178 | - wcwidth=0.2.5 179 | - webencodings=0.5.1 180 | - wheel=0.36.2 181 | - widgetsnbextension=3.5.1 182 | - wrapt=1.12.1 183 | - xarray=0.19.0 184 | - xgboost=1.3.3 185 | - xz=5.2.5 186 | - yaml=0.2.5 187 | - yarl=1.6.3 188 | - zeromq=4.3.4 189 | - zict=2.0.0 190 | - zlib=1.2.11 191 | - zstd=1.4.9 192 | - pip: 193 | - readline==8.1 194 | - ncurses==6.2 195 | - libgomp==9.3.0 196 | - libstdcxx-ng==9.1.0 197 | - libgcc-ng==9.3.0 198 | - libgfortran-ng==7.3.0 199 | - py-xgboost==1.3.30 200 | - ld_impl_linux-64==2.33.1 201 | - murmurhash==1.0.50 202 | - libedit==3.1.20210216 203 | - _openmp_mutex==4.5 204 | - absl-py==0.10.0 205 | - addict==2.3.0 206 | - aiohttp-cors==0.7.0 207 | - aioredis==1.3.1 208 | - blessings==1.7 209 | - boltons==20.2.1 210 | - cachetools==4.1.1 211 | - click==7.1.2 212 | - colorama==0.4.4 213 | - colorful==0.5.4 214 | - de-core-news-sm==2.0.0 215 | - dm-tree==0.1.5 216 | - einops==0.3.0 217 | - en-core-web-sm==2.0.0 218 | - filelock==3.0.12 219 | - fire==0.4.0 220 | - google-api-core==1.26.3 221 | - google-auth==1.22.1 222 | - google-auth-oauthlib==0.4.1 223 | - googleapis-common-protos==1.53.0 224 | - gpustat==0.6.0 225 | - grpcio==1.32.0 226 | - hiredis==2.0.0 227 | - joblib==0.17.0 228 | - lambda-networks==0.4.0 229 | - markdown==3.3 230 | - nltk==3.5 231 | - nvidia-ml-py3==7.352.0 232 | - oauthlib==3.1.0 233 | - opencensus==0.7.12 234 | - opencensus-context==0.1.2 235 | - opencv-python==4.4.0.44 236 | - pathspec==0.8.0 237 | - protobuf==3.13.0 238 | - py-spy==0.3.5 239 | - pyasn1-modules==0.2.8 240 | - pyyaml==5.4.1 241 | - ray==1.2.0 242 | - redis==3.5.3 243 | - regex==2020.11.13 244 | - rsa==4.6 245 | - scikit-learn==0.23.2 246 | - scipy==1.5.2 247 | - sdeint==0.2.1 248 | - sklearn==0.0 249 | - tabulate==0.8.9 250 | - tensorboard==2.3.0 251 | - tensorboard-plugin-wit==1.7.0 252 | - termcolor==1.1.0 253 | - torchsde==0.2.1 254 | - tqdm==4.50.2 255 | - trampoline==0.1.2 256 | - urllib3==1.25.10 257 | - werkzeug==1.0.1 258 | - zipp==3.3.0 259 | 260 | -------------------------------------------------------------------------------- /S-Transformer/requirements.yml: 
-------------------------------------------------------------------------------- 1 | name: TrAISS 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1 8 | - _py-xgboost-mutex=2.0 9 | - aiohttp=3.7.4.post0 10 | - argon2-cffi=20.1.0 11 | - async-timeout=3.0.1 12 | - async_generator=1.10 13 | - attrs=20.3.0 14 | - blas=1.0 15 | - bleach=3.3.0 16 | - blinker=1.4 17 | - bokeh=2.3.3 18 | - bottleneck=1.3.2 19 | - brotlipy=0.7.0 20 | - c-ares=1.17.1 21 | - ca-certificates=2021.5.30 22 | - catalogue=1.0.0 23 | - catboost=0.26 24 | - certifi=2021.5.30 25 | - cffi=1.14.5 26 | - cftime=1.5.0 27 | - chardet=4.0.0 28 | - cloudpickle=1.6.0 29 | - confuse=1.4.0 30 | - cryptography=3.4.7 31 | - cudatoolkit=9.2 32 | - curl=7.71.1 33 | - cycler=0.10.0 34 | - cymem=2.0.5 35 | - cython-blis=0.7.4 36 | - cytoolz=0.9.0.1 37 | - dask=2021.7.2 38 | - dask-core=2021.7.2 39 | - dataclasses=0.8 40 | - decorator=5.0.5 41 | - defusedxml=0.7.1 42 | - dill=0.2.9 43 | - distributed=2021.7.2 44 | - entrypoints=0.3 45 | - freetype=2.10.4 46 | - fsspec=2021.7.0 47 | - future=0.18.2 48 | - hdf4=4.2.13 49 | - hdf5=1.10.6 50 | - heapdict=1.0.1 51 | - htmlmin=0.1.12 52 | - idna=2.10 53 | - imagehash=4.2.0 54 | - imbalanced-learn=0.8.0 55 | - importlib-metadata=3.7.3 56 | - importlib_metadata=3.7.3 57 | - intel-openmp=2020.2 58 | - ipykernel=5.3.4 59 | - ipython=5.8.0 60 | - ipython_genutils=0.2.0 61 | - ipywidgets=7.5.1 62 | - jinja2=2.11.3 63 | - jpeg=9b 64 | - json5=0.9.5 65 | - jsonschema=3.0.2 66 | - jupyter_client=6.1.12 67 | - jupyter_core=4.7.1 68 | - jupyterlab=2.2.6 69 | - jupyterlab_pygments=0.1.2 70 | - jupyterlab_server=1.2.0 71 | - kiwisolver=1.3.1 72 | - krb5=1.18.2 73 | - lcms2=2.12 74 | - libcurl=7.71.1 75 | - libffi=3.3 76 | - libllvm10=10.0.1 77 | - libnetcdf=4.6.1 78 | - libpng=1.6.37 79 | - libprotobuf=3.17.2 80 | - libsodium=1.0.18 81 | - libssh2=1.9.0 82 | - libtiff=4.1.0 83 | - libxgboost=1.3.3 84 | - lightgbm=3.1.1 85 | - llvmlite=0.36.0 86 | - locket=0.2.0 87 | - lz4-c=1.9.3 88 | - markupsafe=1.1.1 89 | - matplotlib=3.3.2 90 | - matplotlib-base=3.3.2 91 | - missingno=0.4.2 92 | - mistune=0.8.4 93 | - mkl=2020.2 94 | - mkl-service=2.3.0 95 | - mkl_fft=1.3.0 96 | - mkl_random=1.1.1 97 | - msgpack-numpy=0.4.7.1 98 | - msgpack-python=1.0.2 99 | - multidict=5.1.0 100 | - nb_conda=2.2.1 101 | - nb_conda_kernels=2.3.1 102 | - nbclient=0.5.3 103 | - nbconvert=6.0.7 104 | - nbformat=5.1.3 105 | - nest-asyncio=1.5.1 106 | - netcdf4=1.5.7 107 | - networkx=2.5 108 | - ninja=1.10.2 109 | - notebook=6.3.0 110 | - numba=0.53.1 111 | - numpy=1.19.2 112 | - numpy-base=1.19.2 113 | - olefile=0.46 114 | - openssl=1.1.1k 115 | - packaging=20.9 116 | - pandas=1.2.3 117 | - pandas-profiling=2.9.0 118 | - pandoc=2.12 119 | - pandocfilters=1.4.3 120 | - partd=1.2.0 121 | - pexpect=4.8.0 122 | - phik=0.11.2 123 | - pickleshare=0.7.5 124 | - pillow=8.2.0 125 | - pip=21.0.1 126 | - plac=1.1.0 127 | - preshed=3.0.2 128 | - proj=7.0.1 129 | - prometheus_client=0.10.0 130 | - prompt_toolkit=1.0.15 131 | - psutil=5.8.0 132 | - ptyprocess=0.7.0 133 | - pyasn1=0.4.8 134 | - pycparser=2.20 135 | - pydeprecate=0.3.1 136 | - pygments=2.8.1 137 | - pyjwt=2.1.0 138 | - pyopenssl=20.0.1 139 | - pyparsing=2.4.7 140 | - pyproj=2.6.1.post1 141 | - pyrsistent=0.17.3 142 | - pysocks=1.7.1 143 | - python=3.7.9 144 | - python-dateutil=2.8.1 145 | - python_abi=3.7 146 | - pytorch=1.6.0 147 | - pytorch-lightning=1.4.4 148 | - pytz=2021.1 149 | - pyu2f=0.1.5 150 | - pywavelets=1.1.1 151 | - 
pyzmq=20.0.0 152 | - requests=2.25.1 153 | - requests-oauthlib=1.3.0 154 | - seaborn=0.11.1 155 | - send2trash=1.5.0 156 | - setuptools=52.0.0 157 | - simplegeneric=0.8.1 158 | - six=1.15.0 159 | - sortedcontainers=2.4.0 160 | - spacy=2.3.5 161 | - sqlite=3.35.4 162 | - srsly=1.0.5 163 | - tangled-up-in-unicode=0.1.0 164 | - tbb=2020.3 165 | - tblib=1.7.0 166 | - tensorboard-data-server=0.6.0 167 | - terminado=0.9.4 168 | - testpath=0.4.4 169 | - thinc=7.4.5 170 | - threadpoolctl=2.1.0 171 | - tk=8.6.10 172 | - toolz=0.11.1 173 | - torchmetrics=0.5.0 174 | - torchtext=0.7.0 175 | - torchvision=0.7.0 176 | - tornado=6.1 177 | - traitlets=5.0.5 178 | - typing-extensions=3.10.0.0 179 | - typing_extensions=3.10.0.0 180 | - ujson=4.0.2 181 | - visions=0.5.0 182 | - wasabi=0.8.2 183 | - wcwidth=0.2.5 184 | - webencodings=0.5.1 185 | - wheel=0.36.2 186 | - widgetsnbextension=3.5.1 187 | - wrapt=1.12.1 188 | - xarray=0.19.0 189 | - xgboost=1.3.3 190 | - xz=5.2.5 191 | - yaml=0.2.5 192 | - yarl=1.6.3 193 | - zeromq=4.3.4 194 | - zict=2.0.0 195 | - zlib=1.2.11 196 | - zstd=1.4.9 197 | - pip: 198 | - readline==8.1 199 | - ncurses==6.2 200 | - libgomp==9.3.0 201 | - libstdcxx-ng==9.1.0 202 | - libgcc-ng==9.3.0 203 | - libgfortran-ng==7.3.0 204 | - py-xgboost==1.3.30 205 | - ld_impl_linux-64==2.33.1 206 | - murmurhash==1.0.50 207 | - libedit==3.1.20210216 208 | - _openmp_mutex==4.5 209 | - absl-py==0.10.0 210 | - addict==2.3.0 211 | - aiohttp-cors==0.7.0 212 | - aioredis==1.3.1 213 | - blessings==1.7 214 | - boltons==20.2.1 215 | - cachetools==4.1.1 216 | - click==7.1.2 217 | - colorama==0.4.4 218 | - colorful==0.5.4 219 | - de-core-news-sm==2.0.0 220 | - dm-tree==0.1.5 221 | - einops==0.3.0 222 | - en-core-web-sm==2.0.0 223 | - filelock==3.0.12 224 | - fire==0.4.0 225 | - google-api-core==1.26.3 226 | - google-auth==1.22.1 227 | - google-auth-oauthlib==0.4.1 228 | - googleapis-common-protos==1.53.0 229 | - gpustat==0.6.0 230 | - grpcio==1.32.0 231 | - hiredis==2.0.0 232 | - joblib==0.17.0 233 | - lambda-networks==0.4.0 234 | - markdown==3.3 235 | - nltk==3.5 236 | - nvidia-ml-py3==7.352.0 237 | - oauthlib==3.1.0 238 | - opencensus==0.7.12 239 | - opencensus-context==0.1.2 240 | - opencv-python==4.4.0.44 241 | - pathspec==0.8.0 242 | - protobuf==3.13.0 243 | - py-spy==0.3.5 244 | - pyasn1-modules==0.2.8 245 | - pyyaml==5.4.1 246 | - ray==1.2.0 247 | - redis==3.5.3 248 | - regex==2020.11.13 249 | - rsa==4.6 250 | - scikit-learn==0.23.2 251 | - scipy==1.5.2 252 | - sdeint==0.2.1 253 | - sklearn==0.0 254 | - tabulate==0.8.9 255 | - tensorboard==2.3.0 256 | - tensorboard-plugin-wit==1.7.0 257 | - termcolor==1.1.0 258 | - torchsde==0.2.1 259 | - tqdm==4.50.2 260 | - trampoline==0.1.2 261 | - urllib3==1.25.10 262 | - werkzeug==1.0.1 263 | - zipp==3.3.0 264 | 265 | -------------------------------------------------------------------------------- /S-Transformer/trAISformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # coding=utf-8 4 | # Copyright 2021, Duong Nguyen 5 | # 6 | # Licensed under the CECILL-C License; 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.cecill.info 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | 
18 | """Pytorch implementation of TrAISformer---A generative transformer for
19 | AIS trajectory prediction
20 | 
21 | https://arxiv.org/abs/2109.03958
22 | 
23 | """
24 | import numpy as np
25 | from numpy import linalg
26 | import matplotlib.pyplot as plt
27 | import os
28 | import sys
29 | import pickle
30 | from tqdm import tqdm
31 | import math
32 | import logging
33 | import pdb
34 | 
35 | import torch
36 | import torch.nn as nn
37 | from torch.nn import functional as F
38 | import torch.optim as optim
39 | from torch.optim.lr_scheduler import LambdaLR
40 | from torch.utils.data import Dataset, DataLoader
41 | 
42 | import models, trainers, datasets, utils
43 | from config_trAISformer import Config
44 | 
45 | cf = Config()
46 | TB_LOG = cf.tb_log
47 | if TB_LOG:
48 | from torch.utils.tensorboard import SummaryWriter
49 | 
50 | tb = SummaryWriter()
51 | 
52 | # make deterministic
53 | utils.set_seed(42)
54 | torch.pi = torch.acos(torch.zeros(1)).item() * 2
55 | 
56 | if __name__ == "__main__":
57 | 
58 | device = cf.device
59 | init_seqlen = cf.init_seqlen
60 | 
61 | ## Logging
62 | # ===============================
63 | if not os.path.isdir(cf.savedir):
64 | os.makedirs(cf.savedir)
65 | print('======= Create directory to store trained models: ' + cf.savedir)
66 | else:
67 | print('======= Directory to store trained models: ' + cf.savedir)
68 | utils.new_log(cf.savedir, "log")
69 | 
70 | ## Data
71 | # ===============================
72 | moving_threshold = 0.05
73 | l_pkl_filenames = [cf.trainset_name, cf.validset_name, cf.testset_name]
74 | Data, aisdatasets, aisdls = {}, {}, {}
75 | for phase, filename in zip(("train", "valid", "test"), l_pkl_filenames):
76 | datapath = os.path.join(cf.datadir, filename)
77 | print(f"Loading {datapath}...")
78 | with open(datapath, "rb") as f:
79 | l_pred_errors = pickle.load(f)
80 | for V in l_pred_errors:
81 | try:
82 | moving_idx = np.where(V["traj"][:, 2] > moving_threshold)[0][0]
83 | except:
84 | moving_idx = len(V["traj"]) - 1 # This track will be removed
85 | V["traj"] = V["traj"][moving_idx:, :]
86 | Data[phase] = [x for x in l_pred_errors if not np.isnan(x["traj"]).any() and len(x["traj"]) > cf.min_seqlen]
87 | print(len(l_pred_errors), len(Data[phase]))
88 | print(f"Length: {len(Data[phase])}")
89 | print("Creating pytorch dataset...")
90 | # Later in this script, we will use inputs = x[:-1], targets = x[1:], hence
91 | # max_seqlen = cf.max_seqlen + 1.
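# (e.g. if cf.max_seqlen were 120, each dataset item would carry 121
# points, so that inputs = x[:, :-1, :] and targets = x[:, 1:, :] in
# models.py are both 120 steps long, aligned one step apart.)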
92 | if cf.mode in ("pos_grad", "grad"): 93 | aisdatasets[phase] = datasets.AISDataset_grad(Data[phase], 94 | max_seqlen=cf.max_seqlen + 1, 95 | device=cf.device) 96 | else: 97 | aisdatasets[phase] = datasets.AISDataset(Data[phase], 98 | max_seqlen=cf.max_seqlen + 1, 99 | device=cf.device) 100 | if phase == "test": 101 | shuffle = False 102 | else: 103 | shuffle = True 104 | aisdls[phase] = DataLoader(aisdatasets[phase], 105 | batch_size=cf.batch_size, 106 | shuffle=shuffle) 107 | cf.final_tokens = 2 * len(aisdatasets["train"]) * cf.max_seqlen 108 | 109 | ## Model 110 | # =============================== 111 | model = models.TrAISformer(cf, partition_model=None) 112 | 113 | ## Trainer 114 | # =============================== 115 | trainer = trainers.Trainer( 116 | model, aisdatasets["train"], aisdatasets["valid"], cf, savedir=cf.savedir, device=cf.device, aisdls=aisdls, INIT_SEQLEN=init_seqlen) 117 | 118 | ## Training 119 | # =============================== 120 | if cf.retrain: 121 | trainer.train() 122 | 123 | ## Evaluation 124 | # =============================== 125 | # Load the best model 126 | model.load_state_dict(torch.load(cf.ckpt_path)) 127 | 128 | v_ranges = torch.tensor([2, 3, 0, 0]).to(cf.device) 129 | v_roi_min = torch.tensor([model.lat_min, -7, 0, 0]).to(cf.device) 130 | max_seqlen = init_seqlen + 6 * 4 131 | 132 | model.eval() 133 | l_min_errors, l_mean_errors, l_masks = [], [], [] 134 | pbar = tqdm(enumerate(aisdls["test"]), total=len(aisdls["test"])) 135 | with torch.no_grad(): 136 | for it, (seqs, masks, seqlens, mmsis, time_starts) in pbar: 137 | seqs_init = seqs[:, :init_seqlen, :].to(cf.device) 138 | masks = masks[:, :max_seqlen].to(cf.device) 139 | batchsize = seqs.shape[0] 140 | error_ens = torch.zeros((batchsize, max_seqlen - cf.init_seqlen, cf.n_samples)).to(cf.device) 141 | for i_sample in range(cf.n_samples): 142 | preds = trainers.sample(model, 143 | seqs_init, 144 | max_seqlen - init_seqlen, 145 | temperature=1.0, 146 | sample=True, 147 | sample_mode=cf.sample_mode, 148 | r_vicinity=cf.r_vicinity, 149 | top_k=cf.top_k) 150 | inputs = seqs[:, :max_seqlen, :].to(cf.device) 151 | input_coords = (inputs * v_ranges + v_roi_min) * torch.pi / 180 152 | pred_coords = (preds * v_ranges + v_roi_min) * torch.pi / 180 153 | d = utils.haversine(input_coords, pred_coords) * masks 154 | error_ens[:, :, i_sample] = d[:, cf.init_seqlen:] 155 | # Accumulation through batches 156 | l_min_errors.append(error_ens.min(dim=-1)) 157 | l_mean_errors.append(error_ens.mean(dim=-1)) 158 | l_masks.append(masks[:, cf.init_seqlen:]) 159 | 160 | l_min = [x.values for x in l_min_errors] 161 | m_masks = torch.cat(l_masks, dim=0) 162 | min_errors = torch.cat(l_min, dim=0) * m_masks 163 | pred_errors = min_errors.sum(dim=0) / m_masks.sum(dim=0) 164 | pred_errors = pred_errors.detach().cpu().numpy() 165 | 166 | ## Plot 167 | # =============================== 168 | plt.figure(figsize=(9, 6), dpi=150) 169 | v_times = np.arange(len(pred_errors)) / 6 170 | plt.plot(v_times, pred_errors) 171 | 172 | timestep = 6 173 | plt.plot(1, pred_errors[timestep], "o") 174 | plt.plot([1, 1], [0, pred_errors[timestep]], "r") 175 | plt.plot([0, 1], [pred_errors[timestep], pred_errors[timestep]], "r") 176 | plt.text(1.12, pred_errors[timestep] - 0.5, "{:.4f}".format(pred_errors[timestep]), fontsize=10) 177 | 178 | timestep = 12 179 | plt.plot(2, pred_errors[timestep], "o") 180 | plt.plot([2, 2], [0, pred_errors[timestep]], "r") 181 | plt.plot([0, 2], [pred_errors[timestep], pred_errors[timestep]], "r") 182 | 
plt.text(2.12, pred_errors[timestep] - 0.5, "{:.4f}".format(pred_errors[timestep]), fontsize=10)
183 | 
184 | timestep = 18
185 | plt.plot(3, pred_errors[timestep], "o")
186 | plt.plot([3, 3], [0, pred_errors[timestep]], "r")
187 | plt.plot([0, 3], [pred_errors[timestep], pred_errors[timestep]], "r")
188 | plt.text(3.12, pred_errors[timestep] - 0.5, "{:.4f}".format(pred_errors[timestep]), fontsize=10)
189 | plt.xlabel("Time (hours)")
190 | plt.ylabel("Prediction errors (km)")
191 | plt.xlim([0, 12])
192 | plt.ylim([0, 20])
193 | # plt.ylim([0,pred_errors.max()+0.5])
194 | plt.savefig(cf.savedir + "prediction_error.png")
195 | 
196 | # Yeah, done!!!
--------------------------------------------------------------------------------
/S-Transformer/trainers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021, Duong Nguyen
3 | #
4 | # Licensed under the CECILL-C License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.cecill.info
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Boilerplate for training a neural network.
17 | 
18 | References:
19 | https://github.com/karpathy/minGPT
20 | """
21 | 
22 | import os
23 | import math
24 | import logging
25 | 
26 | from tqdm import tqdm
27 | import numpy as np
28 | import matplotlib.pyplot as plt
29 | 
30 | import torch
31 | import torch.optim as optim
32 | from torch.optim.lr_scheduler import LambdaLR
33 | from torch.utils.data.dataloader import DataLoader
34 | from torch.nn import functional as F
35 | import utils
36 | 
37 | from trAISformer import TB_LOG
38 | 
39 | logger = logging.getLogger(__name__)
40 | 
41 | 
42 | @torch.no_grad()
43 | def sample(model,
44 | seqs,
45 | steps,
46 | temperature=1.0,
47 | sample=False,
48 | sample_mode="pos_vicinity",
49 | r_vicinity=20,
50 | top_k=None):
51 | """
52 | Take a conditioning sequence of AIS observations seq and predict the next observation,
53 | feed the predictions back into the model each time.
54 | """ 55 | max_seqlen = model.get_max_seqlen() 56 | model.eval() 57 | for k in range(steps): 58 | seqs_cond = seqs if seqs.size(1) <= max_seqlen else seqs[:, -max_seqlen:] # crop context if needed 59 | 60 | # logits.shape: (batch_size, seq_len, data_size) 61 | logits, _ = model(seqs_cond) 62 | d2inf_pred = torch.zeros((logits.shape[0], 4)).to(seqs.device) + 0.5 63 | 64 | # pluck the logits at the final step and scale by temperature 65 | logits = logits[:, -1, :] / temperature # (batch_size, data_size) 66 | 67 | lat_logits, lon_logits, sog_logits, cog_logits = \ 68 | torch.split(logits, (model.lat_size, model.lon_size, model.sog_size, model.cog_size), dim=-1) 69 | 70 | # optionally crop probabilities to only the top k options 71 | if sample_mode in ("pos_vicinity",): 72 | idxs, idxs_uniform = model.to_indexes(seqs_cond[:, -1:, :]) 73 | lat_idxs, lon_idxs = idxs_uniform[:, 0, 0:1], idxs_uniform[:, 0, 1:2] 74 | lat_logits = utils.top_k_nearest_idx(lat_logits, lat_idxs, r_vicinity) 75 | lon_logits = utils.top_k_nearest_idx(lon_logits, lon_idxs, r_vicinity) 76 | 77 | if top_k is not None: 78 | lat_logits = utils.top_k_logits(lat_logits, top_k) 79 | lon_logits = utils.top_k_logits(lon_logits, top_k) 80 | sog_logits = utils.top_k_logits(sog_logits, top_k) 81 | cog_logits = utils.top_k_logits(cog_logits, top_k) 82 | 83 | # apply softmax to convert to probabilities 84 | lat_probs = F.softmax(lat_logits, dim=-1) 85 | lon_probs = F.softmax(lon_logits, dim=-1) 86 | sog_probs = F.softmax(sog_logits, dim=-1) 87 | cog_probs = F.softmax(cog_logits, dim=-1) 88 | 89 | # sample from the distribution or take the most likely 90 | if sample: 91 | lat_ix = torch.multinomial(lat_probs, num_samples=1) # (batch_size, 1) 92 | lon_ix = torch.multinomial(lon_probs, num_samples=1) 93 | sog_ix = torch.multinomial(sog_probs, num_samples=1) 94 | cog_ix = torch.multinomial(cog_probs, num_samples=1) 95 | else: 96 | _, lat_ix = torch.topk(lat_probs, k=1, dim=-1) 97 | _, lon_ix = torch.topk(lon_probs, k=1, dim=-1) 98 | _, sog_ix = torch.topk(sog_probs, k=1, dim=-1) 99 | _, cog_ix = torch.topk(cog_probs, k=1, dim=-1) 100 | 101 | ix = torch.cat((lat_ix, lon_ix, sog_ix, cog_ix), dim=-1) 102 | # convert to x (range: [0,1)) 103 | x_sample = (ix.float() + d2inf_pred) / model.att_sizes 104 | 105 | # append to the sequence and continue 106 | seqs = torch.cat((seqs, x_sample.unsqueeze(1)), dim=1) 107 | 108 | return seqs 109 | 110 | 111 | class TrainerConfig: 112 | # optimization parameters 113 | max_epochs = 10 114 | batch_size = 64 115 | learning_rate = 3e-4 116 | betas = (0.9, 0.95) 117 | grad_norm_clip = 1.0 118 | weight_decay = 0.1 # only applied on matmul weights 119 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original 120 | lr_decay = False 121 | warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere 122 | final_tokens = 260e9 # (at what point we reach 10% of original LR) 123 | # checkpoint settings 124 | ckpt_path = None 125 | num_workers = 0 # for DataLoader 126 | 127 | def __init__(self, **kwargs): 128 | for k, v in kwargs.items(): 129 | setattr(self, k, v) 130 | 131 | 132 | class Trainer: 133 | 134 | def __init__(self, model, train_dataset, test_dataset, config, savedir=None, device=torch.device("cpu"), aisdls={}, 135 | INIT_SEQLEN=0): 136 | self.train_dataset = train_dataset 137 | self.test_dataset = test_dataset 138 | self.config = config 139 | self.savedir = savedir 140 | 141 | self.device = device 142 | self.model = 
model.to(device) 143 | self.aisdls = aisdls 144 | self.INIT_SEQLEN = INIT_SEQLEN 145 | 146 | def save_checkpoint(self, best_epoch): 147 | # DataParallel wrappers keep raw model object in .module attribute 148 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 149 | # logging.info("saving %s", self.config.ckpt_path) 150 | logging.info(f"Best epoch: {best_epoch:03d}, saving model to {self.config.ckpt_path}") 151 | torch.save(raw_model.state_dict(), self.config.ckpt_path) 152 | 153 | def train(self): 154 | model, config, aisdls, INIT_SEQLEN, = self.model, self.config, self.aisdls, self.INIT_SEQLEN 155 | raw_model = model.module if hasattr(self.model, "module") else model 156 | optimizer = raw_model.configure_optimizers(config) 157 | if model.mode in ("gridcont_gridsin", "gridcont_gridsigmoid", "gridcont2_gridsigmoid",): 158 | return_loss_tuple = True 159 | else: 160 | return_loss_tuple = False 161 | 162 | def run_epoch(split, epoch=0): 163 | is_train = split == 'Training' 164 | model.train(is_train) 165 | data = self.train_dataset if is_train else self.test_dataset 166 | loader = DataLoader(data, shuffle=True, pin_memory=True, 167 | batch_size=config.batch_size, 168 | num_workers=config.num_workers) 169 | 170 | losses = [] 171 | n_batches = len(loader) 172 | pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader) 173 | d_loss, d_reg_loss, d_n = 0, 0, 0 174 | for it, (seqs, masks, seqlens, mmsis, time_starts) in pbar: 175 | 176 | # place data on the correct device 177 | seqs = seqs.to(self.device) 178 | masks = masks[:, :-1].to(self.device) 179 | 180 | # forward the model 181 | with torch.set_grad_enabled(is_train): 182 | if return_loss_tuple: 183 | logits, loss, loss_tuple = model(seqs, 184 | masks=masks, 185 | with_targets=True, 186 | return_loss_tuple=return_loss_tuple) 187 | else: 188 | logits, loss = model(seqs, masks=masks, with_targets=True) 189 | loss = loss.mean() # collapse all losses if they are scattered on multiple gpus 190 | losses.append(loss.item()) 191 | 192 | d_loss += loss.item() * seqs.shape[0] 193 | if return_loss_tuple: 194 | reg_loss = loss_tuple[-1] 195 | reg_loss = reg_loss.mean() 196 | d_reg_loss += reg_loss.item() * seqs.shape[0] 197 | d_n += seqs.shape[0] 198 | if is_train: 199 | 200 | # backprop and update the parameters 201 | model.zero_grad() 202 | loss.backward() 203 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 204 | optimizer.step() 205 | 206 | # decay the learning rate based on our progress 207 | if config.lr_decay: 208 | self.tokens += ( 209 | seqs >= 0).sum() # number of tokens processed this step (i.e. label is not -100) 210 | if self.tokens < config.warmup_tokens: 211 | # linear warmup 212 | lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens)) 213 | else: 214 | # cosine learning rate decay 215 | progress = float(self.tokens - config.warmup_tokens) / float( 216 | max(1, config.final_tokens - config.warmup_tokens)) 217 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) 218 | lr = config.learning_rate * lr_mult 219 | for param_group in optimizer.param_groups: 220 | param_group['lr'] = lr 221 | else: 222 | lr = config.learning_rate 223 | 224 | # report progress 225 | pbar.set_description(f"epoch {epoch + 1} iter {it}: loss {loss.item():.5f}. 
lr {lr:e}") 226 | 227 | # tb logging 228 | if TB_LOG: 229 | tb.add_scalar("loss", 230 | loss.item(), 231 | epoch * n_batches + it) 232 | tb.add_scalar("lr", 233 | lr, 234 | epoch * n_batches + it) 235 | 236 | for name, params in model.head.named_parameters(): 237 | tb.add_histogram(f"head.{name}", params, epoch * n_batches + it) 238 | tb.add_histogram(f"head.{name}.grad", params.grad, epoch * n_batches + it) 239 | if model.mode in ("gridcont_real",): 240 | for name, params in model.res_pred.named_parameters(): 241 | tb.add_histogram(f"res_pred.{name}", params, epoch * n_batches + it) 242 | tb.add_histogram(f"res_pred.{name}.grad", params.grad, epoch * n_batches + it) 243 | 244 | if is_train: 245 | if return_loss_tuple: 246 | logging.info( 247 | f"{split}, epoch {epoch + 1}, loss {d_loss / d_n:.5f}, {d_reg_loss / d_n:.5f}, lr {lr:e}.") 248 | else: 249 | logging.info(f"{split}, epoch {epoch + 1}, loss {d_loss / d_n:.5f}, lr {lr:e}.") 250 | else: 251 | if return_loss_tuple: 252 | logging.info(f"{split}, epoch {epoch + 1}, loss {d_loss / d_n:.5f}.") 253 | else: 254 | logging.info(f"{split}, epoch {epoch + 1}, loss {d_loss / d_n:.5f}.") 255 | 256 | if not is_train: 257 | test_loss = float(np.mean(losses)) 258 | # logging.info("test loss: %f", test_loss) 259 | return test_loss 260 | 261 | best_loss = float('inf') 262 | self.tokens = 0 # counter used for learning rate decay 263 | best_epoch = 0 264 | 265 | for epoch in range(config.max_epochs): 266 | 267 | run_epoch('Training', epoch=epoch) 268 | if self.test_dataset is not None: 269 | test_loss = run_epoch('Valid', epoch=epoch) 270 | 271 | # supports early stopping based on the test loss, or just save always if no test set is provided 272 | good_model = self.test_dataset is None or test_loss < best_loss 273 | if self.config.ckpt_path is not None and good_model: 274 | best_loss = test_loss 275 | best_epoch = epoch 276 | self.save_checkpoint(best_epoch + 1) 277 | 278 | ## SAMPLE AND PLOT 279 | # ========================================================================================== 280 | # ========================================================================================== 281 | raw_model = model.module if hasattr(self.model, "module") else model 282 | seqs, masks, seqlens, mmsis, time_starts = iter(aisdls["test"]).next() 283 | n_plots = 7 284 | init_seqlen = INIT_SEQLEN 285 | seqs_init = seqs[:n_plots, :init_seqlen, :].to(self.device) 286 | preds = sample(raw_model, 287 | seqs_init, 288 | 96 - init_seqlen, 289 | temperature=1.0, 290 | sample=True, 291 | sample_mode=self.config.sample_mode, 292 | r_vicinity=self.config.r_vicinity, 293 | top_k=self.config.top_k) 294 | 295 | img_path = os.path.join(self.savedir, f'epoch_{epoch + 1:03d}.jpg') 296 | plt.figure(figsize=(9, 6), dpi=150) 297 | cmap = plt.cm.get_cmap("jet") 298 | preds_np = preds.detach().cpu().numpy() 299 | inputs_np = seqs.detach().cpu().numpy() 300 | for idx in range(n_plots): 301 | c = cmap(float(idx) / (n_plots)) 302 | try: 303 | seqlen = seqlens[idx].item() 304 | except: 305 | continue 306 | plt.plot(inputs_np[idx][:init_seqlen, 1], inputs_np[idx][:init_seqlen, 0], color=c) 307 | plt.plot(inputs_np[idx][:init_seqlen, 1], inputs_np[idx][:init_seqlen, 0], "o", markersize=3, color=c) 308 | plt.plot(inputs_np[idx][:seqlen, 1], inputs_np[idx][:seqlen, 0], linestyle="-.", color=c) 309 | plt.plot(preds_np[idx][init_seqlen:, 1], preds_np[idx][init_seqlen:, 0], "x", markersize=4, color=c) 310 | plt.xlim([-0.05, 1.05]) 311 | plt.ylim([-0.05, 1.05]) 312 | plt.savefig(img_path, 
dpi=150)
313 | plt.close()
314 | 
315 | # Final state
316 | raw_model = self.model.module if hasattr(self.model, "module") else self.model
317 | # logging.info("saving %s", self.config.ckpt_path)
318 | logging.info(f"Last epoch: {epoch:03d}, saving model to {self.config.ckpt_path}")
319 | save_path = self.config.ckpt_path.replace("model.pt", f"model_{epoch + 1:03d}.pt")
320 | torch.save(raw_model.state_dict(), save_path)
321 | 
--------------------------------------------------------------------------------
/S-Transformer/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021, Duong Nguyen
3 | #
4 | # Licensed under the CECILL-C License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.cecill.info
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Utility functions for GPTrajectory.
17 | 
18 | References:
19 | https://github.com/karpathy/minGPT
20 | """
21 | import numpy as np
22 | import os
23 | import math
24 | import logging
25 | import random
26 | import datetime
27 | import socket
28 | 
29 | 
30 | import torch
31 | import torch.nn as nn
32 | from torch.nn import functional as F
33 | torch.pi = torch.acos(torch.zeros(1)).item()*2
34 | 
35 | 
36 | def set_seed(seed):
37 | random.seed(seed)
38 | np.random.seed(seed)
39 | torch.manual_seed(seed)
40 | torch.cuda.manual_seed_all(seed)
41 | torch.backends.cudnn.deterministic = True
42 | 
43 | 
44 | def new_log(logdir,filename):
45 | """Defines logging format.
46 | """
47 | filename = os.path.join(logdir,
48 | datetime.datetime.now().strftime("log_%Y-%m-%d-%H-%M-%S_"+socket.gethostname()+"_"+filename+".log"))
49 | logging.basicConfig(level=logging.INFO,
50 | filename=filename,
51 | format="%(asctime)s - %(name)s - %(message)s",
52 | filemode="w")
53 | console = logging.StreamHandler()
54 | console.setLevel(logging.INFO)
55 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(message)s")
56 | console.setFormatter(formatter)
57 | logging.getLogger('').addHandler(console)
58 | 
59 | def haversine(input_coords,
60 | pred_coords):
61 | """ Calculate the haversine distances between input_coords and pred_coords.
62 | 
63 | Args:
64 | input_coords, pred_coords: Tensors of size (...,N), where (...,0) and (...,1) are
65 | the latitude and longitude in radians.
66 | 
67 | Returns:
68 | The haversine distances between input_coords and pred_coords.
69 | """
70 | R = 6371
71 | lat_errors = pred_coords[...,0] - input_coords[...,0]
72 | lon_errors = pred_coords[...,1] - input_coords[...,1]
73 | a = torch.sin(lat_errors/2)**2\
74 | +torch.cos(input_coords[:,:,0])*torch.cos(pred_coords[:,:,0])*torch.sin(lon_errors/2)**2
75 | c = 2*torch.atan2(torch.sqrt(a),torch.sqrt(1-a))
76 | d = R*c
77 | return d
78 | 
79 | def top_k_logits(logits, k):
80 | v, ix = torch.topk(logits, k)
81 | out = logits.clone()
82 | out[out < v[:, [-1]]] = -float('Inf')
83 | return out
84 | 
85 | def top_k_nearest_idx(att_logits, att_idxs, r_vicinity):
86 | """Keep only k values nearest the current idx.
87 | 
88 | Args:
89 | att_logits: a Tensor of shape (batchsize, data_size).
90 | att_idxs: a Tensor of shape (batchsize, 1), indicates
91 | the current idxs.
92 | r_vicinity: number of values to be kept.
93 | """
94 | device = att_logits.device
95 | idx_range = torch.arange(att_logits.shape[-1]).to(device).repeat(att_logits.shape[0],1)
96 | idx_dists = torch.abs(idx_range - att_idxs)
97 | out = att_logits.clone()
98 | out[idx_dists >= r_vicinity/2] = -float('Inf')
99 | return out
--------------------------------------------------------------------------------
/关系矩阵_热力图.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | 
5 | # Read the data
6 | df = pd.read_csv('轨迹提取文件夹/205609000/subtrack_12.csv')
7 | 
8 | # Features selected for analysis
9 | features = ['LAT', 'LON', 'SOG', 'COG', 'Heading']
10 | 
11 | # Use Seaborn's pairplot to create the pairwise relation plot
12 | sns.pairplot(df[features])
13 | plt.show() # Show the pairwise relation plot
14 | 
15 | # Compute the Spearman correlation coefficients
16 | spearman_corr = df[features].corr(method='spearman')
17 | 
18 | # Create the heatmap with the 'GnBu' colormap
19 | plt.figure(figsize=(10, 8)) # Adjust the figure size to your needs
20 | sns.heatmap(spearman_corr, annot=True, cmap='GnBu', square=True, linewidths=.5)
21 | plt.title('Spearman Correlation Heatmap')
22 | plt.show() # Show the heatmap
23 | """
24 | Available heatmap colormaps:
25 | 'Blues', 'BuGn', 'BuPu'
26 | 'GnBu', 'Greens', 'Greys'
27 | 'Oranges', 'OrRd', 'PuBu'
28 | 'PuBuGn', 'PuRd', 'Purples'
29 | 'RdPu', 'Reds', 'YlGn'
30 | 'YlGnBu', 'YlOrBr', 'YlOrRd'
31 | """
--------------------------------------------------------------------------------
/插值处理.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from scipy.interpolate import lagrange
5 | plt.rcParams['font.sans-serif'] = ['SimHei'] # Set the font to SimHei (for CJK glyphs)
6 | plt.rcParams['axes.unicode_minus'] = False
7 | # Read the raw data
8 | df = pd.read_csv('轨迹提取文件夹/210959000/subtrack_5.csv', parse_dates=['BaseDateTime'])
9 | df.sort_values('BaseDateTime', inplace=True)
10 | df.set_index('BaseDateTime', inplace=True)
11 | 
12 | # Build a new time index
13 | new_index = pd.date_range(start=df.index.min(), end=df.index.max(), freq='2T')
14 | 
15 | # Cubic spline interpolation function
16 | def cubic_spline_interpolation(df):
17 | df_resampled = df.resample('2T').mean().interpolate('cubic')
18 | return df_resampled
19 | 
20 | # Apply cubic spline interpolation
21 | df_cubic = cubic_spline_interpolation(df[['LAT', 'LON']])
22 | 
23 | # Weighted moving average interpolation function
24 | def weighted_moving_average_interpolation(df, k, weights):
25 | if len(weights) != 2 * k:
26 | raise ValueError("The number of weights must be 2k.")
27 | 
28 | df_wma = pd.DataFrame(index=df.index, columns=df.columns)
29 | for col in df.columns:
30 | interpolated_values = []
31 | for t in df.index:
32 | window = df[col].rolling(window=2*k, center=True).apply(lambda x: np.dot(x, weights) / sum(weights), raw=True)
33 | window = window.reindex(df.index, method='nearest')
34 | interpolated_values.append(window.loc[t])
35 | df_wma[col] = interpolated_values
36 | return df_wma
37 | 
38 | # Set k and the weights
39 | k = 2
40 | weights = [1, 2, 2, 1]
41 | 
42 | # Apply weighted moving average interpolation
43 | df_weighted = weighted_moving_average_interpolation(df[['LAT', 'LON']], k, weights)
44 | 
45 | # Save the interpolation results
46 | df_weighted.to_csv('插值文件夹/均值插值.csv')
47 | df_cubic.to_csv('插值文件夹/三次样条.csv')
48 | 
49 | # Draw scatter plots for comparison
50 | fig, axs = plt.subplots(2, 2, figsize=(10, 18)) # Create three subplots
51 | # Adjust the subplot layout
52 | plt.subplots_adjust(top=1.5, bottom=0.1, left=0.1, right=0.9, hspace=0.5, wspace=0.5)
53 | # Scatter plot of the original data
54 | axs[0, 0].scatter(df['LON'], df['LAT'], alpha=0.5, color='blue')
55 | axs[0, 0].set_title('Original data')
56 | axs[0, 0].set_xlabel('LON')
57 | axs[0, 0].set_ylabel('LAT')
58 | axs[0, 0].grid(True)
59 | 
60 | # Scatter plot of the cubic-spline-interpolated data
61 | axs[0, 1].scatter(df_cubic['LON'], df_cubic['LAT'], alpha=0.5, color='green')
62 | axs[0, 1].set_title('Cubic spline interpolation')
63 | axs[0, 1].set_xlabel('LON')
64 | axs[0, 1].set_ylabel('LAT')
65 | axs[0, 1].grid(True)
66 | 
67 | # Scatter plot of the weighted-moving-average-interpolated data
68 | axs[1, 0].scatter(df_weighted['LON'], df_weighted['LAT'], alpha=0.5, color='red')
69 | axs[1, 0].set_title('Weighted moving average interpolation')
70 | axs[1, 0].set_xlabel('LON')
71 | axs[1, 0].set_ylabel('LAT')
72 | axs[1, 0].grid(True)
73 | 
74 | plt.tight_layout()
75 | plt.show()
76 | 
77 | print("Finished.")
78 | 
--------------------------------------------------------------------------------
/数据清洗.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | 
4 | # Folder paths
5 | input_folder = '数据集'
6 | output_folder = '处理完毕的数据集'
7 | 
8 | # Create the output folder if it does not exist
9 | if not os.path.exists(output_folder):
10 | os.makedirs(output_folder)
11 | 
12 | # Iterate over all CSV files in the folder
13 | for filename in os.listdir(input_folder):
14 | if filename.endswith('.csv'):
15 | file_path = os.path.join(input_folder, filename)
16 | 
17 | # Load the dataset
18 | df = pd.read_csv(file_path)
19 | 
20 | # 1. Sort by MMSI and BaseDateTime
21 | df.sort_values(by=['MMSI', 'BaseDateTime'], inplace=True)
22 | 
23 | # 2. Remove records where MMSI is not 9 digits
24 | df = df[df['MMSI'].apply(lambda x: len(str(x)) == 9)]
25 | 
26 | # 3. Filter for 'normal' navigation statuses
27 | """normal_statuses = ['under engine', 'at anchor', 'not under command', 'restricted manoeuvrability',
28 | 'moored', 'aground', 'fishing', 'under way']
29 | df = df[df['Status'].isin(normal_statuses)]"""
30 | 
31 | # 4. Remove records where Length < 3 or Width < 2
32 | df = df[(df['Length'] >= 3) & (df['Width'] >= 2)]
33 | 
34 | # 5. Remove records with out-of-range values
35 | df = df[(df['LON'] >= -180.0) & (df['LON'] <= 180.0)]
36 | df = df[(df['LAT'] >= -90.0) & (df['LAT'] <= 90.0)]
37 | df = df[(df['SOG'] >= 0) & (df['SOG'] <= 51.2)]
38 | df = df[(df['COG'] >= -204.7) & (df['COG'] <= 204.8)]
39 | 
40 | # 6. Remove records where SOG is zero for five consecutive points
41 | df['zero_sog'] = (df['SOG'] == 0).astype(int)
42 | df['rolling_sum'] = df.groupby('MMSI')['zero_sog'].rolling(window=5, min_periods=1).sum().reset_index(level=0, drop=True) # min_periods=1 so the first rows of each group are not NaN (NaN < 5 is False and would drop them)
43 | df = df[df['rolling_sum'] < 5]
44 | df.drop(columns=['zero_sog', 'rolling_sum'], inplace=True)
45 | 
46 | # Save the cleaned dataset
47 | new_filename = 'new_' + filename
48 | output_path = os.path.join(output_folder, new_filename)
49 | df.to_csv(output_path, index=False)
50 | 
51 | print(f"Cleaning finished; saved to '{output_path}'.")
--------------------------------------------------------------------------------
/箱形图.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 | 
5 | input_folder = 'D:\\AA_work\\数据预处理\\处理完毕的数据集' # Path to the folder containing your CSV files
6 | 
7 | # Initialize an empty list of DataFrames
8 | data_frames = []
9 | 
10 | # Iterate over all CSV files in the folder
11 | for filename in os.listdir(input_folder):
12 | if filename.endswith('.csv'):
13 | file_path = os.path.join(input_folder, filename)
14 | df = pd.read_csv(file_path)
15 | data_frames.append(df) # Append each file's DataFrame to the list
16 | 
17 | # Concatenate all DataFrames into one large DataFrame
18 | combined_df = pd.concat(data_frames, ignore_index=True)
19 | 
20 | # Create the box plots
21 | fig, axs = plt.subplots(3, 2, figsize=(12, 9)) # Enlarge the figure to avoid squeezing and overlapping
22 | 
23 | # Define a helper that draws one box plot
24 | def draw_boxplot(column, title, ax, ylim=None):
25 | ax.boxplot(combined_df[column].dropna(),
26 | patch_artist=True, # Enable patch-artist mode so the boxes can be filled
27 | boxprops={'color': 'blue', 'facecolor': 'lightblue'}, # Box outline and fill colors
28 | flierprops={'marker': 'o', 'markerfacecolor': 'red', 'markersize': 1, 'markeredgecolor': 'red'}, # Outlier marker style
29 | whiskerprops={'color': 'blue'}, # Whisker color
30 | capprops={'color': 'blue'}, # Cap color
31 | medianprops={'color': 'red'}) # Median line in red
32 | ax.set_title(title)
33 | ax.set_ylim(ylim) # Set the y-axis range if ylim is provided
34 | ax.grid(True) # Add grid lines
35 | 
36 | 
37 | # Draw the box plots
38 | draw_boxplot('LON', 'Longitude', axs[0, 0]) # ylim may be passed to fix the y-axis range
39 | draw_boxplot('LAT', 'Latitude', axs[0, 1])
40 | draw_boxplot('SOG', 'Speed over ground', axs[1, 0])
41 | draw_boxplot('COG', 'Course over ground', axs[1, 1])
42 | draw_boxplot('Heading', 'Heading', axs[2, 0])
43 | 
44 | # Remove the unused subplot
45 | fig.delaxes(axs[2, 1])
46 | 
47 | fig.tight_layout() # Automatically adjust subplot params to fill the figure area
48 | plt.show()
--------------------------------------------------------------------------------
/轨迹平稳性检验.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import pandas as pd
3 | 
4 | plt.rcParams['font.sans-serif'] = ['SimHei'] # Set the font to SimHei (for CJK glyphs)
5 | plt.rcParams['axes.unicode_minus'] = False
6 | # Read the data
7 | df = pd.read_csv('轨迹提取文件夹/205609000/subtrack_12.csv')
8 | 
9 | # Plot LAT, LON, SOG and COG over time
10 | plt.figure(figsize=(15, 10)) # Set the figure size
11 | 
12 | # Plot LAT
13 | plt.subplot(2, 2, 1) # Subplot 1 in a 2x2 grid
14 | plt.plot(df.index, df['LAT'], label='LAT')
15 | plt.xlabel('Sequence index')
16 | plt.ylabel('LAT')
17 | plt.title('(a) Latitude sequence distribution')
18 | 
19 | # Plot LON
20 | plt.subplot(2, 2, 2) # Subplot 2 in a 2x2 grid
21 | plt.plot(df.index, df['LON'], label='LON', color='orange')
22 | plt.xlabel('Sequence index')
23 | plt.ylabel('LON')
24 | plt.title('(b) Longitude sequence distribution')
25 | 
26 | # Plot SOG
27 | plt.subplot(2, 2, 3) # Subplot 3 in a 2x2 grid
28 | plt.plot(df.index, df['SOG'], label='SOG', color='green')
29 | plt.xlabel('Sequence index')
30 | plt.ylabel('SOG')
31 | plt.title('(c) SOG sequence distribution')
32 | 
33 | # Plot COG
34 | plt.subplot(2, 2, 4) # Subplot 4 in a 2x2 grid
35 | plt.plot(df.index, df['COG'], label='COG', color='red')
36 | plt.xlabel('Sequence index')
37 | plt.ylabel('COG')
38 | plt.title('(d) COG sequence distribution')
39 | 
40 | # Adjust the spacing between subplots
41 | plt.tight_layout()
42 | plt.show()
43 | 
--------------------------------------------------------------------------------
/轨迹提取.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | from datetime import datetime
4 | 
5 | # Path to the source data folder
6 | data_folder_path = '处理完毕的数据集'
7 | # Base path for saving the results
8 | output_base_path = '轨迹提取文件夹'
9 | 
10 | # Read all CSV files
11 | all_files = [os.path.join(data_folder_path, f) for f in os.listdir(data_folder_path) if f.endswith('.csv')]
12 | all_data = [pd.read_csv(file) for file in all_files]
13 | 
14 | # Concatenate the DataFrames
15 | df = pd.concat(all_data, ignore_index=True)
16 | 
17 | # Make sure BaseDateTime is a datetime
18 | df['BaseDateTime'] = pd.to_datetime(df['BaseDateTime'])
19 | 
20 | # Sort by time and group by MMSI
21 | grouped = df.sort_values('BaseDateTime').groupby('MMSI')
22 | 
23 | # Process each MMSI's data
24 | for mmsi, group in grouped:
25 | # Create a directory for this MMSI under the output path
26 | mmsi_folder = os.path.join(output_base_path, str(mmsi))
27 | if not os.path.exists(mmsi_folder):
28 | os.makedirs(mmsi_folder)
29 | 
30 | start_index = 0
31 | subtrack_number = 0
32 | for i in range(1, len(group)):
33 | # Time difference in minutes
34 | time_diff = (group.iloc[i]['BaseDateTime'] - group.iloc[i - 1]['BaseDateTime']).total_seconds() / 60
35 | if time_diff > 30:
36 | # Save the sub-track
37 | sub_df = group.iloc[start_index:i]
38 | sub_df.to_csv(os.path.join(mmsi_folder, f'subtrack_{subtrack_number}.csv'), index=False)
39 | subtrack_number += 1
40 | start_index = i
41 | # Save the last sub-track
42 | sub_df = group.iloc[start_index:]
43 | sub_df.to_csv(os.path.join(mmsi_folder, f'subtrack_{subtrack_number}.csv'), index=False)
44 | 
45 | 
--------------------------------------------------------------------------------
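
A quick way to sanity-check the 30-minute gap rule used by 轨迹提取.py is to run the same split on a toy track. The sketch below is a minimal illustration, not part of the repository: it uses a vectorised `cumsum` split instead of the row-by-row loop above, and the MMSI and timestamps are made-up values.

```python
import pandas as pd

# Toy track: three points 10 min apart, then a 45-min gap before the last two.
df = pd.DataFrame({
    "MMSI": [205609000] * 5,
    "BaseDateTime": pd.to_datetime([
        "2021-01-01 00:00", "2021-01-01 00:10", "2021-01-01 00:20",
        "2021-01-01 01:05", "2021-01-01 01:15",
    ]),
})

# A gap longer than 30 minutes opens a new sub-track; cumsum turns the
# gap flags into consecutive segment ids (0, 0, 0, 1, 1 here).
gap = df["BaseDateTime"].diff().dt.total_seconds() / 60 > 30
df["subtrack"] = gap.cumsum()

print(df.groupby("subtrack").size())  # expect segment sizes 3 and 2
```

With several vessels, wrapping the same two lines in a `groupby('MMSI')` reproduces the per-vessel grouping that 轨迹提取.py performs before writing each sub-track to its own CSV.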