├── README.md
├── README图片
│   ├── 关系矩阵图.png
│   ├── 序列分布情况.png
│   ├── 插值结果.png
│   ├── 热力图.png
│   ├── 箱状图.png
│   ├── 船舶子轨迹.png
│   └── 轨迹提取MMSI.png
├── S-Transformer
│   ├── LICENSE
│   ├── README.md
│   ├── Untitled.ipynb
│   ├── config_trAISformer.py
│   ├── datasets.py
│   ├── figures
│   │   └── t18_3.png
│   ├── models.py
│   ├── requirements.txt
│   ├── requirements.yml
│   ├── trAISformer.py
│   ├── trainers.py
│   └── utils.py
├── 关系矩阵_热力图.py
├── 插值处理.py
├── 数据清洗.py
├── 箱形图.py
├── 轨迹平稳性检验.py
└── 轨迹提取.py
/README.md:
--------------------------------------------------------------------------------
## Ship Trajectory Data Preprocessing and Analysis

#### Introduction

This project preprocesses and analyzes AIS data and consists of three parts. (Note: put all the CSV files in a single folder before running, or adjust the paths in the code yourself. The dataset used in this project is provided by the U.S. MarineCadastre.gov website.)

- Part 1: data cleaning and trajectory extraction
- Part 2: feature correlation and trajectory stationarity analysis
- Part 3: trajectory repair

#### 1. Data Cleaning and Trajectory Extraction

##### 1. Data cleaning (数据清洗.py, 箱状图.py)

In the sending, transmission, and reception of AIS signals, interruptions and gaps are unavoidable, and real-world AIS data also contains redundant and anomalous records, so the data is cleaned first.

The cleaning steps and their rationale are as follows (a code sketch of these rules follows the list):

(1) Sort the AIS data by MMSI and then by time in ascending order. Ordering the raw data by MMSI and timestamp simplifies the subsequent cleaning steps.

(2) Remove records whose MMSI is not 9 digits long. After sorting, some MMSIs turn out not to have 9 digits; since a vessel's MMSI is a unique 9-digit identifier, non-conforming records are removed.

(3) Keep only records whose navigation status indicates normal sailing. The navigation status field takes values such as under way using engine, at anchor, not under command, constrained by draught, moored, aground, engaged in fishing, and under way sailing; the states treated here as normal sailing are retained.

(4) Remove vessels whose length is less than 3 or whose width is less than 2. The tracks of smaller vessels are more easily disturbed by currents, wind, waves, and other environmental factors, and keeping them may also cause the model to overfit, so overly small vessels are removed.

(5) Remove records whose longitude, latitude, speed over ground, or course over ground fall outside the valid ranges. Values outside the valid decimal ranges of the AIS fields are meaningless and are discarded:

| Field | Valid range |
| --- | --- |
| Longitude (LON) | -180.00000 ~ 180.00000 |
| Latitude (LAT) | -90.00000 ~ 90.00000 |
| Speed over ground (SOG) | 0 ~ 51.2 |
| Course over ground (COG) | -204.7 ~ 204.8 |

(6) Remove points where the SOG is 0 for 5 or more consecutive records. As a first approximation, a point at which the SOG has been 0 for 5 consecutive time steps is treated as a berthing point; leaving such points in would distort the subsequent trajectory analysis.

The cleaned data is written to a new folder. Afterwards, **箱状图.py** can be used to draw box plots of the AIS data and inspect its distribution.

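A minimal pandas sketch of these cleaning rules, assuming the MarineCadastre CSV column names (`MMSI`, `BaseDateTime`, `LAT`, `LON`, `SOG`, `COG`, `Status`, `Length`, `Width`); the file paths and the `NORMAL_STATUS` list are placeholders, and 数据清洗.py remains the authoritative implementation:

```python
import pandas as pd

df = pd.read_csv("ais/AIS_2020_01_01.csv")  # placeholder input file

# (1) sort by MMSI, then by timestamp
df = df.sort_values(["MMSI", "BaseDateTime"])

# (2) keep only 9-digit MMSIs
df = df[df["MMSI"].astype(str).str.len() == 9]

# (3) keep only the statuses treated as normal sailing
NORMAL_STATUS = ["under way using engine"]  # placeholder; extend per the list above
df = df[df["Status"].isin(NORMAL_STATUS)]

# (4) drop overly small vessels
df = df[(df["Length"] >= 3) & (df["Width"] >= 2)]

# (5) keep only values inside the valid ranges
df = df[df["LON"].between(-180, 180) & df["LAT"].between(-90, 90)
        & df["SOG"].between(0, 51.2) & df["COG"].between(-204.7, 204.8)]

# (6) drop runs of 5 or more consecutive SOG == 0 points per vessel
is_zero = df["SOG"].eq(0)
run_id = (is_zero != is_zero.shift()).cumsum()
run_len = is_zero.groupby([df["MMSI"], run_id]).transform("size")
df = df[~(is_zero & (run_len >= 5))]

df.to_csv("cleaned/AIS_2020_01_01.csv", index=False)
```

The box plots can then be drawn from the cleaned frame, e.g. `df[["LAT", "LON", "SOG", "COG"]].plot(kind="box", subplots=True, layout=(2, 2))`.
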
##### 2. Trajectory extraction (轨迹提取.py)

Group all AIS records by MMSI and sort the points within each group by BaseDateTime in ascending order; this yields, for each MMSI, that vessel's trajectory data over the observed period.

Whenever the time gap between two consecutive points of the same vessel exceeds 30 min, the two points are treated as the end of one track and the start of the next, which splits a vessel's trajectory into multiple sub-trajectories.

Running the script produces an output folder containing one folder per MMSI, each holding that vessel's extracted sub-trajectories.

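A minimal sketch of the grouping and 30-minute splitting, assuming the cleaned CSV layout from the sketch above (paths are placeholders; 轨迹提取.py is the authoritative implementation):

```python
import os

import pandas as pd

df = pd.read_csv("cleaned/AIS_2020_01_01.csv", parse_dates=["BaseDateTime"])

for mmsi, g in df.groupby("MMSI"):
    g = g.sort_values("BaseDateTime")
    # start a new sub-trajectory whenever the gap to the previous point exceeds 30 min
    new_track = g["BaseDateTime"].diff() > pd.Timedelta(minutes=30)
    os.makedirs(f"tracks/{mmsi}", exist_ok=True)
    for track_id, sub in g.groupby(new_track.cumsum()):
        sub.to_csv(f"tracks/{mmsi}/{track_id}.csv", index=False)
```
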
#### 2. Feature Correlation and Trajectory Stationarity Analysis

##### 1. Feature correlation (关系矩阵_热力图.py)

Reads the ship sub-trajectories and outputs a relationship-matrix (pair) plot of the features and a heatmap of the correlation coefficients between the feature values, along the lines of the sketch below.

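A minimal sketch of both plots with pandas and seaborn (the sub-trajectory path is a placeholder; 关系矩阵_热力图.py is the authoritative implementation):

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

traj = pd.read_csv("tracks/123456789/0.csv")  # placeholder sub-trajectory
feats = traj[["LAT", "LON", "SOG", "COG"]]

sns.pairplot(feats)                                     # relationship-matrix plot
plt.figure()
sns.heatmap(feats.corr(), annot=True, cmap="coolwarm")  # correlation heatmap
plt.show()
```
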
##### 2. Trajectory stationarity (轨迹平稳性检验.py)

A ship trajectory is a typical multivariate time series. Before applying time-series analysis tools to it, one usually needs to examine whether the series is stationary in order to choose a suitable trajectory prediction method. This script gives a direct visualization of the sequence distribution of the four trajectory features.

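As a complementary quantitative check (not part of 轨迹平稳性检验.py, which only plots the series), an augmented Dickey-Fuller test from statsmodels could be applied per feature; a small p-value (e.g. below 0.05) suggests the series is stationary:

```python
import pandas as pd
from statsmodels.tsa.stattools import adfuller

traj = pd.read_csv("tracks/123456789/0.csv")  # placeholder sub-trajectory
for col in ["LAT", "LON", "SOG", "COG"]:
    # adfuller returns (test statistic, p-value, ...); keep the first two
    stat, pvalue = adfuller(traj[col].dropna())[:2]
    print(f"{col}: ADF statistic = {stat:.3f}, p-value = {pvalue:.3f}")
```
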
#### 3. Trajectory Repair (插值处理.py)

The cleaning and trajectory-extraction stages may discard some points, which can leave uneven time intervals within a track. In addition, because AIS transceivers transmit at different rates, the sampling intervals of the received trajectories also differ noticeably. For a given track, overly long time intervals severely affect the training of downstream models, so to smooth the trajectory points a cubic-spline interpolation function is used, with a weighted-moving-average interpolation function as a reference method.

插值处理.py writes the interpolation results of both methods to new files and then draws scatter plots for a direct visual comparison, roughly as sketched below.

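A minimal sketch of the cubic-spline half of this step, resampling one sub-trajectory onto a uniform time grid with SciPy (the paths and the 60 s grid step are assumptions; timestamps must be strictly increasing, so deduplicate them first if needed):

```python
import numpy as np
import pandas as pd
from scipy.interpolate import CubicSpline

traj = pd.read_csv("tracks/123456789/0.csv", parse_dates=["BaseDateTime"])
t = traj["BaseDateTime"].astype("int64").to_numpy() / 1e9  # seconds since epoch
t_new = np.arange(t[0], t[-1], 60)                         # uniform 60 s grid

# fit one spline per feature and evaluate it on the new grid
interp = {col: CubicSpline(t, traj[col].to_numpy())(t_new)
          for col in ["LAT", "LON", "SOG", "COG"]}
out = pd.DataFrame(interp, index=pd.to_datetime(t_new, unit="s"))
out.to_csv("interp/123456789_0.csv")
```
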
If you need AIS trajectory plots or time vs. latitude/longitude plots, head over to my other repository ([click here](https://github.com/axyqdm/Track-visualization)).
--------------------------------------------------------------------------------
/README图片/关系矩阵图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/关系矩阵图.png
--------------------------------------------------------------------------------
/README图片/序列分布情况.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/序列分布情况.png
--------------------------------------------------------------------------------
/README图片/插值结果.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/插值结果.png
--------------------------------------------------------------------------------
/README图片/热力图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/热力图.png
--------------------------------------------------------------------------------
/README图片/箱状图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/箱状图.png
--------------------------------------------------------------------------------
/README图片/船舶子轨迹.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/船舶子轨迹.png
--------------------------------------------------------------------------------
/README图片/轨迹提取MMSI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/README图片/轨迹提取MMSI.png
--------------------------------------------------------------------------------
/S-Transformer/LICENSE:
--------------------------------------------------------------------------------
1 | CeCILL-C FREE SOFTWARE LICENSE AGREEMENT
2 |
3 |
4 | Notice
5 |
6 | This Agreement is a Free Software license agreement that is the result
7 | of discussions between its authors in order to ensure compliance with
8 | the two main principles guiding its drafting:
9 |
10 | * firstly, compliance with the principles governing the distribution
11 | of Free Software: access to source code, broad rights granted to
12 | users,
13 | * secondly, the election of a governing law, French law, with which
14 | it is conformant, both as regards the law of torts and
15 | intellectual property law, and the protection that it offers to
16 | both authors and holders of the economic rights over software.
17 |
18 | The authors of the CeCILL-C (for Ce[a] C[nrs] I[nria] L[ogiciel] L[ibre])
19 | license are:
20 |
21 | Commissariat à l'Energie Atomique - CEA, a public scientific, technical
22 | and industrial research establishment, having its principal place of
23 | business at 25 rue Leblanc, immeuble Le Ponant D, 75015 Paris, France.
24 |
25 | Centre National de la Recherche Scientifique - CNRS, a public scientific
26 | and technological establishment, having its principal place of business
27 | at 3 rue Michel-Ange, 75794 Paris cedex 16, France.
28 |
29 | Institut National de Recherche en Informatique et en Automatique -
30 | INRIA, a public scientific and technological establishment, having its
31 | principal place of business at Domaine de Voluceau, Rocquencourt, BP
32 | 105, 78153 Le Chesnay cedex, France.
33 |
34 |
35 | Preamble
36 |
37 | The purpose of this Free Software license agreement is to grant users
38 | the right to modify and re-use the software governed by this license.
39 |
40 | The exercising of this right is conditional upon the obligation to make
41 | available to the community the modifications made to the source code of
42 | the software so as to contribute to its evolution.
43 |
44 | In consideration of access to the source code and the rights to copy,
45 | modify and redistribute granted by the license, users are provided only
46 | with a limited warranty and the software's author, the holder of the
47 | economic rights, and the successive licensors only have limited liability.
48 |
49 | In this respect, the risks associated with loading, using, modifying
50 | and/or developing or reproducing the software by the user are brought to
51 | the user's attention, given its Free Software status, which may make it
52 | complicated to use, with the result that its use is reserved for
53 | developers and experienced professionals having in-depth computer
54 | knowledge. Users are therefore encouraged to load and test the
55 | suitability of the software as regards their requirements in conditions
56 | enabling the security of their systems and/or data to be ensured and,
57 | more generally, to use and operate it in the same conditions of
58 | security. This Agreement may be freely reproduced and published,
59 | provided it is not altered, and that no provisions are either added or
60 | removed herefrom.
61 |
62 | This Agreement may apply to any or all software for which the holder of
63 | the economic rights decides to submit the use thereof to its provisions.
64 |
65 |
66 | Article 1 - DEFINITIONS
67 |
68 | For the purpose of this Agreement, when the following expressions
69 | commence with a capital letter, they shall have the following meaning:
70 |
71 | Agreement: means this license agreement, and its possible subsequent
72 | versions and annexes.
73 |
74 | Software: means the software in its Object Code and/or Source Code form
75 | and, where applicable, its documentation, "as is" when the Licensee
76 | accepts the Agreement.
77 |
78 | Initial Software: means the Software in its Source Code and possibly its
79 | Object Code form and, where applicable, its documentation, "as is" when
80 | it is first distributed under the terms and conditions of the Agreement.
81 |
82 | Modified Software: means the Software modified by at least one
83 | Integrated Contribution.
84 |
85 | Source Code: means all the Software's instructions and program lines to
86 | which access is required so as to modify the Software.
87 |
88 | Object Code: means the binary files originating from the compilation of
89 | the Source Code.
90 |
91 | Holder: means the holder(s) of the economic rights over the Initial
92 | Software.
93 |
94 | Licensee: means the Software user(s) having accepted the Agreement.
95 |
96 | Contributor: means a Licensee having made at least one Integrated
97 | Contribution.
98 |
99 | Licensor: means the Holder, or any other individual or legal entity, who
100 | distributes the Software under the Agreement.
101 |
102 | Integrated Contribution: means any or all modifications, corrections,
103 | translations, adaptations and/or new functions integrated into the
104 | Source Code by any or all Contributors.
105 |
106 | Related Module: means a set of sources files including their
107 | documentation that, without modification to the Source Code, enables
108 | supplementary functions or services in addition to those offered by the
109 | Software.
110 |
111 | Derivative Software: means any combination of the Software, modified or
112 | not, and of a Related Module.
113 |
114 | Parties: mean both the Licensee and the Licensor.
115 |
116 | These expressions may be used both in singular and plural form.
117 |
118 |
119 | Article 2 - PURPOSE
120 |
121 | The purpose of the Agreement is the grant by the Licensor to the
122 | Licensee of a non-exclusive, transferable and worldwide license for the
123 | Software as set forth in Article 5 hereinafter for the whole term of the
124 | protection granted by the rights over said Software.
125 |
126 |
127 | Article 3 - ACCEPTANCE
128 |
129 | 3.1 The Licensee shall be deemed as having accepted the terms and
130 | conditions of this Agreement upon the occurrence of the first of the
131 | following events:
132 |
133 | * (i) loading the Software by any or all means, notably, by
134 | downloading from a remote server, or by loading from a physical
135 | medium;
136 | * (ii) the first time the Licensee exercises any of the rights
137 | granted hereunder.
138 |
139 | 3.2 One copy of the Agreement, containing a notice relating to the
140 | characteristics of the Software, to the limited warranty, and to the
141 | fact that its use is restricted to experienced users has been provided
142 | to the Licensee prior to its acceptance as set forth in Article 3.1
143 | hereinabove, and the Licensee hereby acknowledges that it has read and
144 | understood it.
145 |
146 |
147 | Article 4 - EFFECTIVE DATE AND TERM
148 |
149 |
150 | 4.1 EFFECTIVE DATE
151 |
152 | The Agreement shall become effective on the date when it is accepted by
153 | the Licensee as set forth in Article 3.1.
154 |
155 |
156 | 4.2 TERM
157 |
158 | The Agreement shall remain in force for the entire legal term of
159 | protection of the economic rights over the Software.
160 |
161 |
162 | Article 5 - SCOPE OF RIGHTS GRANTED
163 |
164 | The Licensor hereby grants to the Licensee, who accepts, the following
165 | rights over the Software for any or all use, and for the term of the
166 | Agreement, on the basis of the terms and conditions set forth hereinafter.
167 |
168 | Besides, if the Licensor owns or comes to own one or more patents
169 | protecting all or part of the functions of the Software or of its
170 | components, the Licensor undertakes not to enforce the rights granted by
171 | these patents against successive Licensees using, exploiting or
172 | modifying the Software. If these patents are transferred, the Licensor
173 | undertakes to have the transferees subscribe to the obligations set
174 | forth in this paragraph.
175 |
176 |
177 | 5.1 RIGHT OF USE
178 |
179 | The Licensee is authorized to use the Software, without any limitation
180 | as to its fields of application, with it being hereinafter specified
181 | that this comprises:
182 |
183 | 1. permanent or temporary reproduction of all or part of the Software
184 | by any or all means and in any or all form.
185 |
186 | 2. loading, displaying, running, or storing the Software on any or
187 | all medium.
188 |
189 | 3. entitlement to observe, study or test its operation so as to
190 | determine the ideas and principles behind any or all constituent
191 | elements of said Software. This shall apply when the Licensee
192 | carries out any or all loading, displaying, running, transmission
193 | or storage operation as regards the Software, that it is entitled
194 | to carry out hereunder.
195 |
196 |
197 | 5.2 RIGHT OF MODIFICATION
198 |
199 | The right of modification includes the right to translate, adapt,
200 | arrange, or make any or all modifications to the Software, and the right
201 | to reproduce the resulting software. It includes, in particular, the
202 | right to create a Derivative Software.
203 |
204 | The Licensee is authorized to make any or all modification to the
205 | Software provided that it includes an explicit notice that it is the
206 | author of said modification and indicates the date of the creation thereof.
207 |
208 |
209 | 5.3 RIGHT OF DISTRIBUTION
210 |
211 | In particular, the right of distribution includes the right to publish,
212 | transmit and communicate the Software to the general public on any or
213 | all medium, and by any or all means, and the right to market, either in
214 | consideration of a fee, or free of charge, one or more copies of the
215 | Software by any means.
216 |
217 | The Licensee is further authorized to distribute copies of the modified
218 | or unmodified Software to third parties according to the terms and
219 | conditions set forth hereinafter.
220 |
221 |
222 | 5.3.1 DISTRIBUTION OF SOFTWARE WITHOUT MODIFICATION
223 |
224 | The Licensee is authorized to distribute true copies of the Software in
225 | Source Code or Object Code form, provided that said distribution
226 | complies with all the provisions of the Agreement and is accompanied by:
227 |
228 | 1. a copy of the Agreement,
229 |
230 | 2. a notice relating to the limitation of both the Licensor's
231 | warranty and liability as set forth in Articles 8 and 9,
232 |
233 | and that, in the event that only the Object Code of the Software is
234 | redistributed, the Licensee allows effective access to the full Source
235 | Code of the Software at a minimum during the entire period of its
236 | distribution of the Software, it being understood that the additional
237 | cost of acquiring the Source Code shall not exceed the cost of
238 | transferring the data.
239 |
240 |
241 | 5.3.2 DISTRIBUTION OF MODIFIED SOFTWARE
242 |
243 | When the Licensee makes an Integrated Contribution to the Software, the
244 | terms and conditions for the distribution of the resulting Modified
245 | Software become subject to all the provisions of this Agreement.
246 |
247 | The Licensee is authorized to distribute the Modified Software, in
248 | source code or object code form, provided that said distribution
249 | complies with all the provisions of the Agreement and is accompanied by:
250 |
251 | 1. a copy of the Agreement,
252 |
253 | 2. a notice relating to the limitation of both the Licensor's
254 | warranty and liability as set forth in Articles 8 and 9,
255 |
256 | and that, in the event that only the object code of the Modified
257 | Software is redistributed, the Licensee allows effective access to the
258 | full source code of the Modified Software at a minimum during the entire
259 | period of its distribution of the Modified Software, it being understood
260 | that the additional cost of acquiring the source code shall not exceed
261 | the cost of transferring the data.
262 |
263 |
264 | 5.3.3 DISTRIBUTION OF DERIVATIVE SOFTWARE
265 |
266 | When the Licensee creates Derivative Software, this Derivative Software
267 | may be distributed under a license agreement other than this Agreement,
268 | subject to compliance with the requirement to include a notice
269 | concerning the rights over the Software as defined in Article 6.4.
270 | In the event the creation of the Derivative Software required modification
271 | of the Source Code, the Licensee undertakes that:
272 |
273 | 1. the resulting Modified Software will be governed by this Agreement,
274 | 2. the Integrated Contributions in the resulting Modified Software
275 | will be clearly identified and documented,
276 | 3. the Licensee will allow effective access to the source code of the
277 | Modified Software, at a minimum during the entire period of
278 | distribution of the Derivative Software, such that such
279 | modifications may be carried over in a subsequent version of the
280 | Software; it being understood that the additional cost of
281 | purchasing the source code of the Modified Software shall not
282 | exceed the cost of transferring the data.
283 |
284 |
285 | 5.3.4 COMPATIBILITY WITH THE CeCILL LICENSE
286 |
287 | When a Modified Software contains an Integrated Contribution subject to
288 | the CeCILL license agreement, or when a Derivative Software contains a
289 | Related Module subject to the CeCILL license agreement, the provisions
290 | set forth in the third item of Article 6.4 are optional.
291 |
292 |
293 | Article 6 - INTELLECTUAL PROPERTY
294 |
295 |
296 | 6.1 OVER THE INITIAL SOFTWARE
297 |
298 | The Holder owns the economic rights over the Initial Software. Any or
299 | all use of the Initial Software is subject to compliance with the terms
300 | and conditions under which the Holder has elected to distribute its work
301 | and no one shall be entitled to modify the terms and conditions for the
302 | distribution of said Initial Software.
303 |
304 | The Holder undertakes that the Initial Software will remain ruled at
305 | least by this Agreement, for the duration set forth in Article 4.2.
306 |
307 |
308 | 6.2 OVER THE INTEGRATED CONTRIBUTIONS
309 |
310 | The Licensee who develops an Integrated Contribution is the owner of the
311 | intellectual property rights over this Contribution as defined by
312 | applicable law.
313 |
314 |
315 | 6.3 OVER THE RELATED MODULES
316 |
317 | The Licensee who develops a Related Module is the owner of the
318 | intellectual property rights over this Related Module as defined by
319 | applicable law and is free to choose the type of agreement that shall
320 | govern its distribution under the conditions defined in Article 5.3.3.
321 |
322 |
323 | 6.4 NOTICE OF RIGHTS
324 |
325 | The Licensee expressly undertakes:
326 |
327 | 1. not to remove, or modify, in any manner, the intellectual property
328 | notices attached to the Software;
329 |
330 | 2. to reproduce said notices, in an identical manner, in the copies
331 | of the Software modified or not;
332 |
333 | 3. to ensure that use of the Software, its intellectual property
334 | notices and the fact that it is governed by the Agreement is
335 | indicated in a text that is easily accessible, specifically from
336 | the interface of any Derivative Software.
337 |
338 | The Licensee undertakes not to directly or indirectly infringe the
339 | intellectual property rights of the Holder and/or Contributors on the
340 | Software and to take, where applicable, vis-à-vis its staff, any and all
341 | measures required to ensure respect of said intellectual property rights
342 | of the Holder and/or Contributors.
343 |
344 |
345 | Article 7 - RELATED SERVICES
346 |
347 | 7.1 Under no circumstances shall the Agreement oblige the Licensor to
348 | provide technical assistance or maintenance services for the Software.
349 |
350 | However, the Licensor is entitled to offer this type of services. The
351 | terms and conditions of such technical assistance, and/or such
352 | maintenance, shall be set forth in a separate instrument. Only the
353 | Licensor offering said maintenance and/or technical assistance services
354 | shall incur liability therefor.
355 |
356 | 7.2 Similarly, any Licensor is entitled to offer to its licensees, under
357 | its sole responsibility, a warranty, that shall only be binding upon
358 | itself, for the redistribution of the Software and/or the Modified
359 | Software, under terms and conditions that it is free to decide. Said
360 | warranty, and the financial terms and conditions of its application,
361 | shall be subject of a separate instrument executed between the Licensor
362 | and the Licensee.
363 |
364 |
365 | Article 8 - LIABILITY
366 |
367 | 8.1 Subject to the provisions of Article 8.2, the Licensee shall be
368 | entitled to claim compensation for any direct loss it may have suffered
369 | from the Software as a result of a fault on the part of the relevant
370 | Licensor, subject to providing evidence thereof.
371 |
372 | 8.2 The Licensor's liability is limited to the commitments made under
373 | this Agreement and shall not be incurred as a result of in particular:
374 | (i) loss due the Licensee's total or partial failure to fulfill its
375 | obligations, (ii) direct or consequential loss that is suffered by the
376 | Licensee due to the use or performance of the Software, and (iii) more
377 | generally, any consequential loss. In particular the Parties expressly
378 | agree that any or all pecuniary or business loss (i.e. loss of data,
379 | loss of profits, operating loss, loss of customers or orders,
380 | opportunity cost, any disturbance to business activities) or any or all
381 | legal proceedings instituted against the Licensee by a third party,
382 | shall constitute consequential loss and shall not provide entitlement to
383 | any or all compensation from the Licensor.
384 |
385 |
386 | Article 9 - WARRANTY
387 |
388 | 9.1 The Licensee acknowledges that the scientific and technical
389 | state-of-the-art when the Software was distributed did not enable all
390 | possible uses to be tested and verified, nor for the presence of
391 | possible defects to be detected. In this respect, the Licensee's
392 | attention has been drawn to the risks associated with loading, using,
393 | modifying and/or developing and reproducing the Software which are
394 | reserved for experienced users.
395 |
396 | The Licensee shall be responsible for verifying, by any or all means,
397 | the suitability of the product for its requirements, its good working
398 | order, and for ensuring that it shall not cause damage to either persons
399 | or properties.
400 |
401 | 9.2 The Licensor hereby represents, in good faith, that it is entitled
402 | to grant all the rights over the Software (including in particular the
403 | rights set forth in Article 5).
404 |
405 | 9.3 The Licensee acknowledges that the Software is supplied "as is" by
406 | the Licensor without any other express or tacit warranty, other than
407 | that provided for in Article 9.2 and, in particular, without any warranty
408 | as to its commercial value, its secured, safe, innovative or relevant
409 | nature.
410 |
411 | Specifically, the Licensor does not warrant that the Software is free
412 | from any error, that it will operate without interruption, that it will
413 | be compatible with the Licensee's own equipment and software
414 | configuration, nor that it will meet the Licensee's requirements.
415 |
416 | 9.4 The Licensor does not either expressly or tacitly warrant that the
417 | Software does not infringe any third party intellectual property right
418 | relating to a patent, software or any other property right. Therefore,
419 | the Licensor disclaims any and all liability towards the Licensee
420 | arising out of any or all proceedings for infringement that may be
421 | instituted in respect of the use, modification and redistribution of the
422 | Software. Nevertheless, should such proceedings be instituted against
423 | the Licensee, the Licensor shall provide it with technical and legal
424 | assistance for its defense. Such technical and legal assistance shall be
425 | decided on a case-by-case basis between the relevant Licensor and the
426 | Licensee pursuant to a memorandum of understanding. The Licensor
427 | disclaims any and all liability as regards the Licensee's use of the
428 | name of the Software. No warranty is given as regards the existence of
429 | prior rights over the name of the Software or as regards the existence
430 | of a trademark.
431 |
432 |
433 | Article 10 - TERMINATION
434 |
435 | 10.1 In the event of a breach by the Licensee of its obligations
436 | hereunder, the Licensor may automatically terminate this Agreement
437 | thirty (30) days after notice has been sent to the Licensee and has
438 | remained ineffective.
439 |
440 | 10.2 A Licensee whose Agreement is terminated shall no longer be
441 | authorized to use, modify or distribute the Software. However, any
442 | licenses that it may have granted prior to termination of the Agreement
443 | shall remain valid subject to their having been granted in compliance
444 | with the terms and conditions hereof.
445 |
446 |
447 | Article 11 - MISCELLANEOUS
448 |
449 |
450 | 11.1 EXCUSABLE EVENTS
451 |
452 | Neither Party shall be liable for any or all delay, or failure to
453 | perform the Agreement, that may be attributable to an event of force
454 | majeure, an act of God or an outside cause, such as defective
455 | functioning or interruptions of the electricity or telecommunications
456 | networks, network paralysis following a virus attack, intervention by
457 | government authorities, natural disasters, water damage, earthquakes,
458 | fire, explosions, strikes and labor unrest, war, etc.
459 |
460 | 11.2 Any failure by either Party, on one or more occasions, to invoke
461 | one or more of the provisions hereof, shall under no circumstances be
462 | interpreted as being a waiver by the interested Party of its right to
463 | invoke said provision(s) subsequently.
464 |
465 | 11.3 The Agreement cancels and replaces any or all previous agreements,
466 | whether written or oral, between the Parties and having the same
467 | purpose, and constitutes the entirety of the agreement between said
468 | Parties concerning said purpose. No supplement or modification to the
469 | terms and conditions hereof shall be effective as between the Parties
470 | unless it is made in writing and signed by their duly authorized
471 | representatives.
472 |
473 | 11.4 In the event that one or more of the provisions hereof were to
474 | conflict with a current or future applicable act or legislative text,
475 | said act or legislative text shall prevail, and the Parties shall make
476 | the necessary amendments so as to comply with said act or legislative
477 | text. All other provisions shall remain effective. Similarly, invalidity
478 | of a provision of the Agreement, for any reason whatsoever, shall not
479 | cause the Agreement as a whole to be invalid.
480 |
481 |
482 | 11.5 LANGUAGE
483 |
484 | The Agreement is drafted in both French and English and both versions
485 | are deemed authentic.
486 |
487 |
488 | Article 12 - NEW VERSIONS OF THE AGREEMENT
489 |
490 | 12.1 Any person is authorized to duplicate and distribute copies of this
491 | Agreement.
492 |
493 | 12.2 So as to ensure coherence, the wording of this Agreement is
494 | protected and may only be modified by the authors of the License, who
495 | reserve the right to periodically publish updates or new versions of the
496 | Agreement, each with a separate number. These subsequent versions may
497 | address new issues encountered by Free Software.
498 |
499 | 12.3 Any Software distributed under a given version of the Agreement may
500 | only be subsequently distributed under the same version of the Agreement
501 | or a subsequent version.
502 |
503 |
504 | Article 13 - GOVERNING LAW AND JURISDICTION
505 |
506 | 13.1 The Agreement is governed by French law. The Parties agree to
507 | endeavor to seek an amicable solution to any disagreements or disputes
508 | that may arise during the performance of the Agreement.
509 |
510 | 13.2 Failing an amicable solution within two (2) months as from their
511 | occurrence, and unless emergency proceedings are necessary, the
512 | disagreements or disputes shall be referred to the Paris Courts having
513 | jurisdiction, by the more diligent Party.
514 |
--------------------------------------------------------------------------------
/S-Transformer/README.md:
--------------------------------------------------------------------------------
1 | # TrAISformer
2 |
3 | PyTorch implementation of TrAISformer, a generative transformer for AIS trajectory prediction (https://arxiv.org/abs/2109.03958).
4 |
5 | The transformer part is adapted from: https://github.com/karpathy/minGPT
6 |
7 | ---
8 |
9 |
10 |
11 |
12 |
13 | ### Requirements:
14 | See requirements.yml
15 |
16 | ### Datasets:
17 |
18 | The data used in this paper are provided by the [Danish Maritime Authority (DMA)](https://dma.dk/safety-at-sea/navigational-information/ais-data).
19 | Please refer to [the paper](https://arxiv.org/abs/2109.03958) for the details of the pre-processing step. The code is available here: https://github.com/CIA-Oceanix/GeoTrackNet/blob/master/data/csv2pkl.py
20 |
21 | A processed dataset can be found in `./data/ct_dma/`
22 | (the format is `[lat, lon, sog, cog, unix_timestamp, mmsi]`).
23 |
24 | ### Run
25 |
26 | Run `trAISformer.py` to train and evaluate the model.
27 | (Please note that the values given by the code are in km, while the values presented in the paper were converted to nautical miles.)
28 |
29 |
30 | ### License
31 |
32 | See `LICENSE`
33 |
34 | ### Contact
35 | For any questions, please open an issue and assign it to @dnguyengithub.
36 |
37 |
--------------------------------------------------------------------------------
/S-Transformer/Untitled.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 6,
6 | "id": "6a7477f1-1db4-46cf-8e22-7894b8bd4eb5",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import matplotlib.pyplot as plt\n",
11 | "import cartopy.crs as ccrs\n",
12 | "import cartopy.feature as cfeature\n",
13 | "import pandas as pd\n",
14 | "import xlrd"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 4,
20 | "id": "9a5ffab3-fbc5-454e-b6bc-773ccceae510",
21 | "metadata": {},
22 | "outputs": [
23 | {
24 | "name": "stdout",
25 | "output_type": "stream",
26 | "text": [
27 | "Collecting xlrd\n",
28 | " Downloading xlrd-2.0.1-py2.py3-none-any.whl.metadata (3.4 kB)\n",
29 | "Downloading xlrd-2.0.1-py2.py3-none-any.whl (96 kB)\n",
30 | " ---------------------------------------- 0.0/96.5 kB ? eta -:--:--\n",
31 | " ---- ----------------------------------- 10.2/96.5 kB ? eta -:--:--\n",
32 | " -------- ------------------------------- 20.5/96.5 kB 330.3 kB/s eta 0:00:01\n",
33 | " ---------------- ----------------------- 41.0/96.5 kB 279.3 kB/s eta 0:00:01\n",
34 | " ------------------------- -------------- 61.4/96.5 kB 328.2 kB/s eta 0:00:01\n",
35 | " ----------------------------- ---------- 71.7/96.5 kB 326.8 kB/s eta 0:00:01\n",
36 | " ---------------------------------------- 96.5/96.5 kB 425.1 kB/s eta 0:00:00\n",
37 | "Installing collected packages: xlrd\n",
38 | "Successfully installed xlrd-2.0.1\n"
39 | ]
40 | }
41 | ],
42 | "source": [
43 | "! pip install xlrd"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 9,
49 | "id": "7d99710b-9312-4ba1-8211-a62a6dd6d1fe",
50 | "metadata": {},
51 | "outputs": [
52 | {
53 | "ename": "XLRDError",
54 | "evalue": "Unsupported format, or corrupt file: Expected BOF record; found b'SeNo\\tBat'",
55 | "output_type": "error",
56 | "traceback": [
57 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
58 | "\u001b[1;31mXLRDError\u001b[0m Traceback (most recent call last)",
59 | "Cell \u001b[1;32mIn[9], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# 打开 Excel 文件\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m workbook \u001b[38;5;241m=\u001b[39m \u001b[43mxlrd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen_workbook\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mD:/AA_work/AIS数据/18日.xls\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4\u001b[0m \u001b[38;5;66;03m# 获取第一个工作表\u001b[39;00m\n\u001b[0;32m 5\u001b[0m sheet \u001b[38;5;241m=\u001b[39m workbook\u001b[38;5;241m.\u001b[39msheet_by_index(\u001b[38;5;241m0\u001b[39m)\n",
60 | "File \u001b[1;32mD:\\anaconda3\\envs\\Deep_learn_GPU\\Lib\\site-packages\\xlrd\\__init__.py:172\u001b[0m, in \u001b[0;36mopen_workbook\u001b[1;34m(filename, logfile, verbosity, use_mmap, file_contents, encoding_override, formatting_info, on_demand, ragged_rows, ignore_workbook_corruption)\u001b[0m\n\u001b[0;32m 169\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file_format \u001b[38;5;129;01mand\u001b[39;00m file_format \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mxls\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 170\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m XLRDError(FILE_FORMAT_DESCRIPTIONS[file_format]\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m; not supported\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m--> 172\u001b[0m bk \u001b[38;5;241m=\u001b[39m \u001b[43mopen_workbook_xls\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 173\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 174\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogfile\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogfile\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbosity\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbosity\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_mmap\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_mmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 177\u001b[0m \u001b[43m \u001b[49m\u001b[43mfile_contents\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfile_contents\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 178\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding_override\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoding_override\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 179\u001b[0m \u001b[43m \u001b[49m\u001b[43mformatting_info\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mformatting_info\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 180\u001b[0m \u001b[43m \u001b[49m\u001b[43mon_demand\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mon_demand\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 181\u001b[0m \u001b[43m \u001b[49m\u001b[43mragged_rows\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mragged_rows\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 182\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_workbook_corruption\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_workbook_corruption\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 183\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 185\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m bk\n",
61 | "File \u001b[1;32mD:\\anaconda3\\envs\\Deep_learn_GPU\\Lib\\site-packages\\xlrd\\book.py:79\u001b[0m, in \u001b[0;36mopen_workbook_xls\u001b[1;34m(filename, logfile, verbosity, use_mmap, file_contents, encoding_override, formatting_info, on_demand, ragged_rows, ignore_workbook_corruption)\u001b[0m\n\u001b[0;32m 77\u001b[0m t1 \u001b[38;5;241m=\u001b[39m perf_counter()\n\u001b[0;32m 78\u001b[0m bk\u001b[38;5;241m.\u001b[39mload_time_stage_1 \u001b[38;5;241m=\u001b[39m t1 \u001b[38;5;241m-\u001b[39m t0\n\u001b[1;32m---> 79\u001b[0m biff_version \u001b[38;5;241m=\u001b[39m \u001b[43mbk\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgetbof\u001b[49m\u001b[43m(\u001b[49m\u001b[43mXL_WORKBOOK_GLOBALS\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 80\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m biff_version:\n\u001b[0;32m 81\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m XLRDError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt determine file\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms BIFF version\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
62 | "File \u001b[1;32mD:\\anaconda3\\envs\\Deep_learn_GPU\\Lib\\site-packages\\xlrd\\book.py:1284\u001b[0m, in \u001b[0;36mBook.getbof\u001b[1;34m(self, rqd_stream)\u001b[0m\n\u001b[0;32m 1282\u001b[0m bof_error(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mExpected BOF record; met end of file\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 1283\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m opcode \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m bofcodes:\n\u001b[1;32m-> 1284\u001b[0m \u001b[43mbof_error\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mExpected BOF record; found \u001b[39;49m\u001b[38;5;132;43;01m%r\u001b[39;49;00m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m%\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmem\u001b[49m\u001b[43m[\u001b[49m\u001b[43msavpos\u001b[49m\u001b[43m:\u001b[49m\u001b[43msavpos\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1285\u001b[0m length \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget2bytes()\n\u001b[0;32m 1286\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m length \u001b[38;5;241m==\u001b[39m MY_EOF:\n",
63 | "File \u001b[1;32mD:\\anaconda3\\envs\\Deep_learn_GPU\\Lib\\site-packages\\xlrd\\book.py:1278\u001b[0m, in \u001b[0;36mBook.getbof..bof_error\u001b[1;34m(msg)\u001b[0m\n\u001b[0;32m 1277\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbof_error\u001b[39m(msg):\n\u001b[1;32m-> 1278\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m XLRDError(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mUnsupported format, or corrupt file: \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m msg)\n",
64 | "\u001b[1;31mXLRDError\u001b[0m: Unsupported format, or corrupt file: Expected BOF record; found b'SeNo\\tBat'"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "\n",
70 | "# 打开 Excel 文件\n",
71 | "workbook = xlrd.open_workbook('D:/AA_work/AIS数据/18日.xls')\n",
72 | "\n",
73 | "# 获取第一个工作表\n",
74 | "sheet = workbook.sheet_by_index(0)\n",
75 | "\n",
76 | "# 获取行数和列数\n",
77 | "num_rows = sheet.nrows\n",
78 | "num_cols = sheet.ncols\n",
79 | "num_cols"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "id": "d47e2e20-c3ea-418a-b7d5-d2e6ad900687",
86 | "metadata": {},
87 | "outputs": [],
88 | "source": []
89 | }
90 | ],
91 | "metadata": {
92 | "kernelspec": {
93 | "display_name": "Python 3 (ipykernel)",
94 | "language": "python",
95 | "name": "python3"
96 | },
97 | "language_info": {
98 | "codemirror_mode": {
99 | "name": "ipython",
100 | "version": 3
101 | },
102 | "file_extension": ".py",
103 | "mimetype": "text/x-python",
104 | "name": "python",
105 | "nbconvert_exporter": "python",
106 | "pygments_lexer": "ipython3",
107 | "version": "3.11.7"
108 | }
109 | },
110 | "nbformat": 4,
111 | "nbformat_minor": 5
112 | }
113 |
--------------------------------------------------------------------------------
/S-Transformer/config_trAISformer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021, Duong Nguyen
3 | #
4 | # Licensed under the CECILL-C License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.cecill.info
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Configuration flags to run the main script.
17 | """
18 |
19 | import os
20 | import pickle
21 | import torch
22 |
23 |
24 | class Config():
25 | retrain = True
26 | tb_log = False
27 | device = torch.device("cuda:0")
28 | # device = torch.device("cpu")
29 |
30 | max_epochs = 50
31 | batch_size = 32
32 | n_samples = 16
33 |
34 | init_seqlen = 18
35 | max_seqlen = 120
36 | min_seqlen = 36
37 |
38 | dataset_name = "ct_dma"
39 |
40 | if dataset_name == "ct_dma": #==============================
41 |
42 | # When mode == "grad" or "pos_grad", sog and cog are actually dlat and
43 | # dlon
44 | lat_size = 250
45 | lon_size = 270
46 | sog_size = 30
47 | cog_size = 72
48 |
49 |
50 | n_lat_embd = 256
51 | n_lon_embd = 256
52 | n_sog_embd = 128
53 | n_cog_embd = 128
54 |
55 | lat_min = 55.5
56 | lat_max = 58.0
57 | lon_min = 10.3
58 | lon_max = 13
59 |
60 |
61 | #===========================================================================
62 | # Model and sampling flags
63 | mode = "pos" #"pos", "pos_grad", "mlp_pos", "mlpgrid_pos", "velo", "grid_l2", "grid_l1",
64 | # "ce_vicinity", "gridcont_grid", "gridcont_real", "gridcont_gridsin", "gridcont_gridsigmoid"
65 | sample_mode = "pos_vicinity" # "pos", "pos_vicinity" or "velo"
66 | top_k = 10 # int or None
67 | r_vicinity = 40 # int
68 |
69 | # Blur flags
70 | #===================================================
71 | blur = True
72 | blur_learnable = False
73 | blur_loss_w = 1.0
74 | blur_n = 2
75 | if not blur:
76 | blur_n = 0
77 | blur_loss_w = 0
78 |
79 | # Data flags
80 | #===================================================
81 | datadir = f"./data/{dataset_name}/"
82 | trainset_name = f"{dataset_name}_train.pkl"
83 | validset_name = f"{dataset_name}_valid.pkl"
84 | testset_name = f"{dataset_name}_test.pkl"
85 |
86 |
87 | # model parameters
88 | #===================================================
89 | n_head = 8
90 | n_layer = 8
91 | full_size = lat_size + lon_size + sog_size + cog_size
92 | n_embd = n_lat_embd + n_lon_embd + n_sog_embd + n_cog_embd
93 | # base GPT config, params common to all GPT versions
94 | embd_pdrop = 0.1
95 | resid_pdrop = 0.1
96 | attn_pdrop = 0.1
97 |
98 | # optimization parameters
99 | #===================================================
100 | learning_rate = 6e-4 # 6e-4
101 | betas = (0.9, 0.95)
102 | grad_norm_clip = 1.0
103 | weight_decay = 0.1 # only applied on matmul weights
104 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original
105 | lr_decay = True
106 | warmup_tokens = 512*20 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
107 | final_tokens = 260e9 # (at what point we reach 10% of original LR)
108 | num_workers = 4 # for DataLoader
109 |
110 | filename = f"{dataset_name}"\
111 | + f"-{mode}-{sample_mode}-{top_k}-{r_vicinity}"\
112 | + f"-blur-{blur}-{blur_learnable}-{blur_n}-{blur_loss_w}"\
113 | + f"-data_size-{lat_size}-{lon_size}-{sog_size}-{cog_size}"\
114 | + f"-embd_size-{n_lat_embd}-{n_lon_embd}-{n_sog_embd}-{n_cog_embd}"\
115 | + f"-head-{n_head}-{n_layer}"\
116 | + f"-bs-{batch_size}"\
117 | + f"-lr-{learning_rate}"\
118 | + f"-seqlen-{init_seqlen}-{max_seqlen}"
119 | savedir = "./results/"+filename+"/"
120 |
121 | ckpt_path = os.path.join(savedir,"model.pt")
--------------------------------------------------------------------------------
/S-Transformer/datasets.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021, Duong Nguyen
3 | #
4 | # Licensed under the CECILL-C License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.cecill.info
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # These two dataset classes load and preprocess AIS ship-trajectory data and convert it into the tensor format required by PyTorch models.
16 | """Customized Pytorch Dataset.
17 | """
18 |
19 | import numpy as np
20 | import os
21 | import pickle
22 |
23 | import torch
24 | from torch.utils.data import Dataset, DataLoader
25 |
26 | class AISDataset(Dataset):
27 | """Customized Pytorch dataset.
28 | """
29 | def __init__(self,
30 | l_data,
31 | max_seqlen=96,
32 | dtype=torch.float32,
33 | device=torch.device("cpu")):
34 | """
35 | Args
36 | l_data: list of dictionaries, each element is an AIS trajectory.
37 | l_data[idx]["mmsi"]: vessel's MMSI.
38 | l_data[idx]["traj"]: a matrix whose columns are
39 | [LAT, LON, SOG, COG, TIMESTAMP]
40 | lat, lon, sog, and cog have been standardized, i.e. range = [0,1).
41 | max_seqlen: (optional) max sequence length. Default is 96.
42 | """
43 |
44 | self.max_seqlen = max_seqlen
45 | self.device = device
46 |
47 | self.l_data = l_data
48 |
49 | def __len__(self):
50 | return len(self.l_data)
51 |
52 | def __getitem__(self, idx):
53 | """Gets items.
54 |
55 | Returns:
56 | seq: Tensor of (max_seqlen, [lat,lon,sog,cog]).
57 | mask: Tensor of (max_seqlen, 1). mask[i] = 0.0 if x[i] is a
58 | padding.
59 | seqlen: sequence length.
60 | mmsi: vessel's MMSI.
61 | time_start: timestamp of the starting time of the trajectory.
62 | """
63 | V = self.l_data[idx]
64 | m_v = V["traj"][:,:4] # lat, lon, sog, cog
65 | # m_v[m_v==1] = 0.9999
66 | m_v[m_v>0.9999] = 0.9999
67 | seqlen = min(len(m_v), self.max_seqlen)
68 | seq = np.zeros((self.max_seqlen,4))
69 | seq[:seqlen,:] = m_v[:seqlen,:]
70 | seq = torch.tensor(seq, dtype=torch.float32)
71 |
72 | mask = torch.zeros(self.max_seqlen)
73 | mask[:seqlen] = 1.
74 |
75 | seqlen = torch.tensor(seqlen, dtype=torch.int)
76 | mmsi = torch.tensor(V["mmsi"], dtype=torch.int)
77 | time_start = torch.tensor(V["traj"][0,4], dtype=torch.int)
78 |
79 | return seq, mask, seqlen, mmsi, time_start
80 |
81 | class AISDataset_grad(Dataset):
82 | """Customized Pytorch dataset.
83 | Return the positions and the gradient of the positions.
84 | """
85 | def __init__(self,
86 | l_data,
87 | dlat_max=0.04,
88 | dlon_max=0.04,
89 | max_seqlen=96,
90 | dtype=torch.float32,
91 | device=torch.device("cpu")):
92 | """
93 | Args
94 | l_data: list of dictionaries, each element is an AIS trajectory.
95 | l_data[idx]["mmsi"]: vessel's MMSI.
96 | l_data[idx]["traj"]: a matrix whose columns are
97 | [LAT, LON, SOG, COG, TIMESTAMP]
98 | lat, lon, sog, and cog have been standardized, i.e. range = [0,1).
99 | dlat_max, dlon_max: the maximum value of the gradient of the positions.
100 | dlat_max = max(lat[idx+1]-lat[idx]) for all idx.
101 | max_seqlen: (optional) max sequence length. Default is 96.
102 | """
103 |
104 | self.dlat_max = dlat_max
105 | self.dlon_max = dlon_max
106 | self.dpos_max = np.array([dlat_max, dlon_max])
107 | self.max_seqlen = max_seqlen
108 | self.device = device
109 |
110 | self.l_data = l_data
111 |
112 | def __len__(self):
113 | return len(self.l_data)
114 |
115 | def __getitem__(self, idx):
116 | """Gets items.
117 |
118 | Returns:
119 | seq: Tensor of (max_seqlen, [lat,lon,sog,cog]).
120 | mask: Tensor of (max_seqlen, 1). mask[i] = 0.0 if x[i] is a
121 | padding.
122 | seqlen: sequence length.
123 | mmsi: vessel's MMSI.
124 | time_start: timestamp of the starting time of the trajectory.
125 | """
126 | V = self.l_data[idx]
127 | m_v = V["traj"][:,:4] # lat, lon, sog, cog
128 | m_v[m_v==1] = 0.9999
129 | seqlen = min(len(m_v), self.max_seqlen)
130 | seq = np.zeros((self.max_seqlen,4))
131 | # lat and lon
132 | seq[:seqlen,:2] = m_v[:seqlen,:2]
133 | # dlat and dlon
134 | dpos = (m_v[1:,:2]-m_v[:-1,:2]+self.dpos_max )/(2*self.dpos_max )
135 | dpos = np.concatenate((dpos[:1,:],dpos),axis=0)
136 | dpos[dpos>=1] = 0.9999
137 | dpos[dpos<=0] = 0.0
138 | seq[:seqlen,2:] = dpos[:seqlen,:2]
139 |
140 | # convert to Tensor
141 | seq = torch.tensor(seq, dtype=torch.float32)
142 |
143 | mask = torch.zeros(self.max_seqlen)
144 | mask[:seqlen] = 1.
145 |
146 | seqlen = torch.tensor(seqlen, dtype=torch.int)
147 | mmsi = torch.tensor(V["mmsi"], dtype=torch.int)
148 | time_start = torch.tensor(V["traj"][0,4], dtype=torch.int)
149 |
150 | return seq, mask, seqlen, mmsi, time_start
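151 |
152 | # Example usage (a sketch, not part of the original pipeline): wrap a pickled
153 | # trajectory list in a DataLoader; the path and sizes follow config_trAISformer.py.
154 | # with open("./data/ct_dma/ct_dma_train.pkl", "rb") as f:
155 | #     l_data = pickle.load(f)
156 | # loader = DataLoader(AISDataset(l_data, max_seqlen=120), batch_size=32, shuffle=True)
157 | # seq, mask, seqlen, mmsi, time_start = next(iter(loader))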
--------------------------------------------------------------------------------
/S-Transformer/figures/t18_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/axyqdm/Ship-trajectory-data-preprocessing-and-analysis/fa2c61a3d177b13f60e2e615fecf4b80f7db6c7c/S-Transformer/figures/t18_3.png
--------------------------------------------------------------------------------
/S-Transformer/models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021, Duong Nguyen
3 | #
4 | # Licensed under the CECILL-C License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.cecill.info
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Models for TrAISformer.
17 | https://arxiv.org/abs/2109.03958
18 |
19 | The code is built upon:
20 | https://github.com/karpathy/minGPT
21 | """
22 |
23 | import math
24 | import logging
25 | import pdb
26 |
27 |
28 | import torch
29 | import torch.nn as nn
30 | from torch.nn import functional as F
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 |
35 | class CausalSelfAttention(nn.Module):
36 | """
37 | A vanilla multi-head masked self-attention layer with a projection at the end.
38 | It is possible to use torch.nn.MultiheadAttention here but I am including an
39 | explicit implementation here to show that there is nothing too scary here.
40 | """
41 |
42 | def __init__(self, config):
43 | super().__init__()
44 | assert config.n_embd % config.n_head == 0
45 | # key, query, value projections for all heads
46 | self.key = nn.Linear(config.n_embd, config.n_embd)
47 | self.query = nn.Linear(config.n_embd, config.n_embd)
48 | self.value = nn.Linear(config.n_embd, config.n_embd)
49 | # regularization
50 | self.attn_drop = nn.Dropout(config.attn_pdrop)
51 | self.resid_drop = nn.Dropout(config.resid_pdrop)
52 | # output projection
53 | self.proj = nn.Linear(config.n_embd, config.n_embd)
54 | # causal mask to ensure that attention is only applied to the left in the input sequence
55 | self.register_buffer("mask", torch.tril(torch.ones(config.max_seqlen, config.max_seqlen))
56 | .view(1, 1, config.max_seqlen, config.max_seqlen))
57 | self.n_head = config.n_head
58 |
59 | def forward(self, x, layer_past=None):
60 | B, T, C = x.size()
61 |
62 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
63 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
64 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
65 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
66 |
67 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
68 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
69 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
70 | att = F.softmax(att, dim=-1)
71 | att = self.attn_drop(att)
72 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
73 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
74 |
75 | # output projection
76 | y = self.resid_drop(self.proj(y))
77 | return y
78 |
79 | class Block(nn.Module):
80 | """ an unassuming Transformer block """
81 |
82 | def __init__(self, config):
83 | super().__init__()
84 | self.ln1 = nn.LayerNorm(config.n_embd)
85 | self.ln2 = nn.LayerNorm(config.n_embd)
86 | self.attn = CausalSelfAttention(config)
87 | self.mlp = nn.Sequential(
88 | nn.Linear(config.n_embd, 4 * config.n_embd),
89 | nn.GELU(),
90 | nn.Linear(4 * config.n_embd, config.n_embd),
91 | nn.Dropout(config.resid_pdrop),
92 | )
93 |
94 | def forward(self, x):
95 | x = x + self.attn(self.ln1(x))
96 | x = x + self.mlp(self.ln2(x))
97 | return x
98 |
99 | class TrAISformer(nn.Module):
100 | """Transformer for AIS trajectories."""
101 |
102 | def __init__(self, config, partition_model = None):
103 | super().__init__()
104 |
105 | self.lat_size = config.lat_size
106 | self.lon_size = config.lon_size
107 | self.sog_size = config.sog_size
108 | self.cog_size = config.cog_size
109 | self.full_size = config.full_size
110 | self.n_lat_embd = config.n_lat_embd
111 | self.n_lon_embd = config.n_lon_embd
112 | self.n_sog_embd = config.n_sog_embd
113 | self.n_cog_embd = config.n_cog_embd
114 | self.register_buffer(
115 | "att_sizes",
116 | torch.tensor([config.lat_size, config.lon_size, config.sog_size, config.cog_size]))
117 | self.register_buffer(
118 | "emb_sizes",
119 | torch.tensor([config.n_lat_embd, config.n_lon_embd, config.n_sog_embd, config.n_cog_embd]))
120 |
121 | if hasattr(config,"partition_mode"):
122 | self.partition_mode = config.partition_mode
123 | else:
124 | self.partition_mode = "uniform"
125 | self.partition_model = partition_model
126 |
127 | if hasattr(config,"blur"):
128 | self.blur = config.blur
129 | self.blur_learnable = config.blur_learnable
130 | self.blur_loss_w = config.blur_loss_w
131 | self.blur_n = config.blur_n
132 | if self.blur:
133 | self.blur_module = nn.Conv1d(1, 1, 3, padding = 1, padding_mode = 'replicate', groups=1, bias=False)
134 | if not self.blur_learnable:
135 | for params in self.blur_module.parameters():
136 | params.requires_grad = False
137 | params.fill_(1/3)
138 | else:
139 | self.blur_module = None
140 |
141 |
142 | if hasattr(config,"lat_min"): # the ROI is provided.
143 | self.lat_min = config.lat_min
144 | self.lat_max = config.lat_max
145 | self.lon_min = config.lon_min
146 | self.lon_max = config.lon_max
147 | self.lat_range = config.lat_max-config.lat_min
148 | self.lon_range = config.lon_max-config.lon_min
149 | self.sog_range = 30.
150 |
151 | if hasattr(config,"mode"): # mode: "pos" or "velo".
152 | # "pos": predict directly the next positions.
153 | # "velo": predict the velocities, use them to
154 | # calculate the next positions.
155 | self.mode = config.mode
156 | else:
157 | self.mode = "pos"
158 |
159 |
160 | # Passing from the 4-D space to a high-dimensional space
161 | self.lat_emb = nn.Embedding(self.lat_size, config.n_lat_embd)
162 | self.lon_emb = nn.Embedding(self.lon_size, config.n_lon_embd)
163 | self.sog_emb = nn.Embedding(self.sog_size, config.n_sog_embd)
164 | self.cog_emb = nn.Embedding(self.cog_size, config.n_cog_embd)
165 |
166 |
167 | self.pos_emb = nn.Parameter(torch.zeros(1, config.max_seqlen, config.n_embd))
168 | self.drop = nn.Dropout(config.embd_pdrop)
169 |
170 | # transformer
171 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
172 |
173 |
174 | # decoder head
175 | self.ln_f = nn.LayerNorm(config.n_embd)
176 | if self.mode in ("mlp_pos","mlp"):
177 | self.head = nn.Linear(config.n_embd, config.n_embd, bias=False)
178 | else:
179 | self.head = nn.Linear(config.n_embd, self.full_size, bias=False) # Classification head
180 |
181 | self.max_seqlen = config.max_seqlen
182 | self.apply(self._init_weights)
183 |
184 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
185 |
186 | def get_max_seqlen(self):
187 | return self.max_seqlen
188 |
189 | def _init_weights(self, module):
190 | if isinstance(module, (nn.Linear, nn.Embedding)):
191 | module.weight.data.normal_(mean=0.0, std=0.02)
192 | if isinstance(module, nn.Linear) and module.bias is not None:
193 | module.bias.data.zero_()
194 | elif isinstance(module, nn.LayerNorm):
195 | module.bias.data.zero_()
196 | module.weight.data.fill_(1.0)
197 |
198 | def configure_optimizers(self, train_config):
199 | """
200 | This long function is unfortunately doing something very simple and is being very defensive:
201 | We are separating out all parameters of the model into two buckets: those that will experience
202 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
203 | We are then returning the PyTorch optimizer object.
204 | """
205 |
206 | # separate out all parameters to those that will and won't experience regularizing weight decay
207 | decay = set()
208 | no_decay = set()
209 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d)
210 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
211 | for mn, m in self.named_modules():
212 | for pn, p in m.named_parameters():
213 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
214 |
215 | if pn.endswith('bias'):
216 | # all biases will not be decayed
217 | no_decay.add(fpn)
218 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
219 | # weights of whitelist modules will be weight decayed
220 | decay.add(fpn)
221 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
222 | # weights of blacklist modules will NOT be weight decayed
223 | no_decay.add(fpn)
224 |
225 | # special case the position embedding parameter in the root GPT module as not decayed
226 | no_decay.add('pos_emb')
227 |
228 | # validate that we considered every parameter
229 | param_dict = {pn: p for pn, p in self.named_parameters()}
230 | inter_params = decay & no_decay
231 | union_params = decay | no_decay
232 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
233 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
234 | % (str(param_dict.keys() - union_params), )
235 |
236 | # create the pytorch optimizer object
237 | optim_groups = [
238 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
239 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
240 | ]
241 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
242 | return optimizer
243 |
244 |
245 | def to_indexes(self, x, mode="uniform"):
246 | """Convert tokens to indexes.
247 |
248 | Args:
249 | x: a Tensor of size (batchsize, seqlen, 4). x has been truncated
250 | to [0,1).
251 |             mode: "uniform" (default), "freq" or "freq_uniform".
252 |
253 | Returns:
254 | idxs: a Tensor (dtype: Long) of indexes.
255 | """
256 | bs, seqlen, data_dim = x.shape
257 | if mode == "uniform":
258 | idxs = (x*self.att_sizes).long()
259 | return idxs, idxs
260 | elif mode in ("freq", "freq_uniform"):
261 |
262 | idxs = (x*self.att_sizes).long()
263 | idxs_uniform = idxs.clone()
264 | discrete_lats, discrete_lons, lat_ids, lon_ids = self.partition_model(x[:,:,:2])
265 | 
266 | idxs[:,:,0] = torch.round(lat_ids.reshape((bs,seqlen))).long()
267 | idxs[:,:,1] = torch.round(lon_ids.reshape((bs,seqlen))).long()
268 | return idxs, idxs_uniform
269 |
270 |
271 | def forward(self, x, masks = None, with_targets=False, return_loss_tuple=False):
272 | """
273 | Args:
274 | x: a Tensor of size (batchsize, seqlen, 4). x has been truncated
275 | to [0,1).
276 |             masks: a Tensor of the same size as x. masks[idx] = 0. if
277 |                 x[idx] is padding.
278 | with_targets: if True, inputs = x[:,:-1,:], targets = x[:,1:,:],
279 | otherwise inputs = x.
280 | Returns:
281 | logits, loss
282 | """
283 |
284 | if self.mode in ("mlp_pos","mlp",):
285 | idxs, idxs_uniform = x, x # use the real-values of x.
286 | else:
287 | # Convert to indexes
288 | idxs, idxs_uniform = self.to_indexes(x, mode=self.partition_mode)
289 |
290 | if with_targets:
291 | inputs = idxs[:,:-1,:].contiguous()
292 | targets = idxs[:,1:,:].contiguous()
293 | targets_uniform = idxs_uniform[:,1:,:].contiguous()
294 | inputs_real = x[:,:-1,:].contiguous()
295 | targets_real = x[:,1:,:].contiguous()
296 | else:
297 | inputs_real = x
298 | inputs = idxs
299 | targets = None
300 |
301 | batchsize, seqlen, _ = inputs.size()
302 | assert seqlen <= self.max_seqlen, "Cannot forward, model block size is exhausted."
303 |
304 | # forward the GPT model
305 |         lat_embeddings = self.lat_emb(inputs[:,:,0]) # (bs, seqlen, n_lat_embd)
306 | lon_embeddings = self.lon_emb(inputs[:,:,1])
307 | sog_embeddings = self.sog_emb(inputs[:,:,2])
308 | cog_embeddings = self.cog_emb(inputs[:,:,3])
309 | token_embeddings = torch.cat((lat_embeddings, lon_embeddings, sog_embeddings, cog_embeddings),dim=-1)
310 |
311 | position_embeddings = self.pos_emb[:, :seqlen, :] # each position maps to a (learnable) vector (1, seqlen, n_embd)
312 | fea = self.drop(token_embeddings + position_embeddings)
313 | fea = self.blocks(fea)
314 | fea = self.ln_f(fea) # (bs, seqlen, n_embd)
315 | logits = self.head(fea) # (bs, seqlen, full_size) or (bs, seqlen, n_embd)
316 |
317 | lat_logits, lon_logits, sog_logits, cog_logits =\
318 | torch.split(logits, (self.lat_size, self.lon_size, self.sog_size, self.cog_size), dim=-1)
319 |
320 | # Calculate the loss
321 | loss = None
322 | loss_tuple = None
323 | if targets is not None:
324 |
325 | sog_loss = F.cross_entropy(sog_logits.view(-1, self.sog_size),
326 | targets[:,:,2].view(-1),
327 | reduction="none").view(batchsize,seqlen)
328 | cog_loss = F.cross_entropy(cog_logits.view(-1, self.cog_size),
329 | targets[:,:,3].view(-1),
330 | reduction="none").view(batchsize,seqlen)
331 | lat_loss = F.cross_entropy(lat_logits.view(-1, self.lat_size),
332 | targets[:,:,0].view(-1),
333 | reduction="none").view(batchsize,seqlen)
334 | lon_loss = F.cross_entropy(lon_logits.view(-1, self.lon_size),
335 | targets[:,:,1].view(-1),
336 | reduction="none").view(batchsize,seqlen)
337 |
338 | if self.blur:
339 | lat_probs = F.softmax(lat_logits, dim=-1)
340 | lon_probs = F.softmax(lon_logits, dim=-1)
341 | sog_probs = F.softmax(sog_logits, dim=-1)
342 | cog_probs = F.softmax(cog_logits, dim=-1)
343 |
344 | for _ in range(self.blur_n):
345 | blurred_lat_probs = self.blur_module(lat_probs.reshape(-1,1,self.lat_size)).reshape(lat_probs.shape)
346 | blurred_lon_probs = self.blur_module(lon_probs.reshape(-1,1,self.lon_size)).reshape(lon_probs.shape)
347 | blurred_sog_probs = self.blur_module(sog_probs.reshape(-1,1,self.sog_size)).reshape(sog_probs.shape)
348 | blurred_cog_probs = self.blur_module(cog_probs.reshape(-1,1,self.cog_size)).reshape(cog_probs.shape)
349 |
350 | blurred_lat_loss = F.nll_loss(blurred_lat_probs.view(-1, self.lat_size),
351 | targets[:,:,0].view(-1),
352 | reduction="none").view(batchsize,seqlen)
353 | blurred_lon_loss = F.nll_loss(blurred_lon_probs.view(-1, self.lon_size),
354 | targets[:,:,1].view(-1),
355 | reduction="none").view(batchsize,seqlen)
356 | blurred_sog_loss = F.nll_loss(blurred_sog_probs.view(-1, self.sog_size),
357 | targets[:,:,2].view(-1),
358 | reduction="none").view(batchsize,seqlen)
359 | blurred_cog_loss = F.nll_loss(blurred_cog_probs.view(-1, self.cog_size),
360 | targets[:,:,3].view(-1),
361 | reduction="none").view(batchsize,seqlen)
362 |
363 | lat_loss += self.blur_loss_w*blurred_lat_loss
364 | lon_loss += self.blur_loss_w*blurred_lon_loss
365 | sog_loss += self.blur_loss_w*blurred_sog_loss
366 | cog_loss += self.blur_loss_w*blurred_cog_loss
367 |
368 | lat_probs = blurred_lat_probs
369 | lon_probs = blurred_lon_probs
370 | sog_probs = blurred_sog_probs
371 | cog_probs = blurred_cog_probs
372 |
373 |
374 | loss_tuple = (lat_loss, lon_loss, sog_loss, cog_loss)
375 | loss = sum(loss_tuple)
376 |
377 | if masks is not None:
378 | loss = (loss*masks).sum(dim=1)/masks.sum(dim=1)
379 |
380 | loss = loss.mean()
381 |
382 | if return_loss_tuple:
383 | return logits, loss, loss_tuple
384 | else:
385 | return logits, loss
386 |
387 |
--------------------------------------------------------------------------------
/S-Transformer/requirements.txt:
--------------------------------------------------------------------------------
1 | # NOTE: conda-style pins (same list as requirements.yml); not installable with pip as-is
2 | - _libgcc_mutex=0.1
3 | - _py-xgboost-mutex=2.0
4 | - aiohttp=3.7.4.post0
5 | - argon2-cffi=20.1.0
6 | - async-timeout=3.0.1
7 | - async_generator=1.10
8 | - attrs=20.3.0
9 | - blas=1.0
10 | - bleach=3.3.0
11 | - blinker=1.4
12 | - bokeh=2.3.3
13 | - bottleneck=1.3.2
14 | - brotlipy=0.7.0
15 | - c-ares=1.17.1
16 | - ca-certificates=2021.5.30
17 | - catalogue=1.0.0
18 | - catboost=0.26
19 | - certifi=2021.5.30
20 | - cffi=1.14.5
21 | - cftime=1.5.0
22 | - chardet=4.0.0
23 | - cloudpickle=1.6.0
24 | - confuse=1.4.0
25 | - cryptography=3.4.7
26 | - cudatoolkit=9.2
27 | - curl=7.71.1
28 | - cycler=0.10.0
29 | - cymem=2.0.5
30 | - cython-blis=0.7.4
31 | - cytoolz=0.9.0.1
32 | - dask=2021.7.2
33 | - dask-core=2021.7.2
34 | - dataclasses=0.8
35 | - decorator=5.0.5
36 | - defusedxml=0.7.1
37 | - dill=0.2.9
38 | - distributed=2021.7.2
39 | - entrypoints=0.3
40 | - freetype=2.10.4
41 | - fsspec=2021.7.0
42 | - future=0.18.2
43 | - hdf4=4.2.13
44 | - hdf5=1.10.6
45 | - heapdict=1.0.1
46 | - htmlmin=0.1.12
47 | - idna=2.10
48 | - imagehash=4.2.0
49 | - imbalanced-learn=0.8.0
50 | - importlib-metadata=3.7.3
51 | - importlib_metadata=3.7.3
52 | - intel-openmp=2020.2
53 | - ipykernel=5.3.4
54 | - ipython=5.8.0
55 | - ipython_genutils=0.2.0
56 | - ipywidgets=7.5.1
57 | - jinja2=2.11.3
58 | - jpeg=9b
59 | - json5=0.9.5
60 | - jsonschema=3.0.2
61 | - jupyter_client=6.1.12
62 | - jupyter_core=4.7.1
63 | - jupyterlab=2.2.6
64 | - jupyterlab_pygments=0.1.2
65 | - jupyterlab_server=1.2.0
66 | - kiwisolver=1.3.1
67 | - krb5=1.18.2
68 | - lcms2=2.12
69 | - libcurl=7.71.1
70 | - libffi=3.3
71 | - libllvm10=10.0.1
72 | - libnetcdf=4.6.1
73 | - libpng=1.6.37
74 | - libprotobuf=3.17.2
75 | - libsodium=1.0.18
76 | - libssh2=1.9.0
77 | - libtiff=4.1.0
78 | - libxgboost=1.3.3
79 | - lightgbm=3.1.1
80 | - llvmlite=0.36.0
81 | - locket=0.2.0
82 | - lz4-c=1.9.3
83 | - markupsafe=1.1.1
84 | - matplotlib=3.3.2
85 | - matplotlib-base=3.3.2
86 | - missingno=0.4.2
87 | - mistune=0.8.4
88 | - mkl=2020.2
89 | - mkl-service=2.3.0
90 | - mkl_fft=1.3.0
91 | - mkl_random=1.1.1
92 | - msgpack-numpy=0.4.7.1
93 | - msgpack-python=1.0.2
94 | - multidict=5.1.0
95 | - nb_conda=2.2.1
96 | - nb_conda_kernels=2.3.1
97 | - nbclient=0.5.3
98 | - nbconvert=6.0.7
99 | - nbformat=5.1.3
100 | - nest-asyncio=1.5.1
101 | - netcdf4=1.5.7
102 | - networkx=2.5
103 | - ninja=1.10.2
104 | - notebook=6.3.0
105 | - numba=0.53.1
106 | - numpy=1.19.2
107 | - numpy-base=1.19.2
108 | - olefile=0.46
109 | - openssl=1.1.1k
110 | - packaging=20.9
111 | - pandas=1.2.3
112 | - pandas-profiling=2.9.0
113 | - pandoc=2.12
114 | - pandocfilters=1.4.3
115 | - partd=1.2.0
116 | - pexpect=4.8.0
117 | - phik=0.11.2
118 | - pickleshare=0.7.5
119 | - pillow=8.2.0
120 | - pip=21.0.1
121 | - plac=1.1.0
122 | - preshed=3.0.2
123 | - proj=7.0.1
124 | - prometheus_client=0.10.0
125 | - prompt_toolkit=1.0.15
126 | - psutil=5.8.0
127 | - ptyprocess=0.7.0
128 | - pyasn1=0.4.8
129 | - pycparser=2.20
130 | - pydeprecate=0.3.1
131 | - pygments=2.8.1
132 | - pyjwt=2.1.0
133 | - pyopenssl=20.0.1
134 | - pyparsing=2.4.7
135 | - pyproj=2.6.1.post1
136 | - pyrsistent=0.17.3
137 | - pysocks=1.7.1
138 | - python=3.7.9
139 | - python-dateutil=2.8.1
140 | - python_abi=3.7
141 | - pytorch=1.6.0
142 | - pytorch-lightning=1.4.4
143 | - pytz=2021.1
144 | - pyu2f=0.1.5
145 | - pywavelets=1.1.1
146 | - pyzmq=20.0.0
147 | - requests=2.25.1
148 | - requests-oauthlib=1.3.0
149 | - seaborn=0.11.1
150 | - send2trash=1.5.0
151 | - setuptools=52.0.0
152 | - simplegeneric=0.8.1
153 | - six=1.15.0
154 | - sortedcontainers=2.4.0
155 | - spacy=2.3.5
156 | - sqlite=3.35.4
157 | - srsly=1.0.5
158 | - tangled-up-in-unicode=0.1.0
159 | - tbb=2020.3
160 | - tblib=1.7.0
161 | - tensorboard-data-server=0.6.0
162 | - terminado=0.9.4
163 | - testpath=0.4.4
164 | - thinc=7.4.5
165 | - threadpoolctl=2.1.0
166 | - tk=8.6.10
167 | - toolz=0.11.1
168 | - torchmetrics=0.5.0
169 | - torchtext=0.7.0
170 | - torchvision=0.7.0
171 | - tornado=6.1
172 | - traitlets=5.0.5
173 | - typing-extensions=3.10.0.0
174 | - typing_extensions=3.10.0.0
175 | - ujson=4.0.2
176 | - visions=0.5.0
177 | - wasabi=0.8.2
178 | - wcwidth=0.2.5
179 | - webencodings=0.5.1
180 | - wheel=0.36.2
181 | - widgetsnbextension=3.5.1
182 | - wrapt=1.12.1
183 | - xarray=0.19.0
184 | - xgboost=1.3.3
185 | - xz=5.2.5
186 | - yaml=0.2.5
187 | - yarl=1.6.3
188 | - zeromq=4.3.4
189 | - zict=2.0.0
190 | - zlib=1.2.11
191 | - zstd=1.4.9
192 | - pip:
193 | - readline==8.1
194 | - ncurses==6.2
195 | - libgomp==9.3.0
196 | - libstdcxx-ng==9.1.0
197 | - libgcc-ng==9.3.0
198 | - libgfortran-ng==7.3.0
199 | - py-xgboost==1.3.30
200 | - ld_impl_linux-64==2.33.1
201 | - murmurhash==1.0.50
202 | - libedit==3.1.20210216
203 | - _openmp_mutex==4.5
204 | - absl-py==0.10.0
205 | - addict==2.3.0
206 | - aiohttp-cors==0.7.0
207 | - aioredis==1.3.1
208 | - blessings==1.7
209 | - boltons==20.2.1
210 | - cachetools==4.1.1
211 | - click==7.1.2
212 | - colorama==0.4.4
213 | - colorful==0.5.4
214 | - de-core-news-sm==2.0.0
215 | - dm-tree==0.1.5
216 | - einops==0.3.0
217 | - en-core-web-sm==2.0.0
218 | - filelock==3.0.12
219 | - fire==0.4.0
220 | - google-api-core==1.26.3
221 | - google-auth==1.22.1
222 | - google-auth-oauthlib==0.4.1
223 | - googleapis-common-protos==1.53.0
224 | - gpustat==0.6.0
225 | - grpcio==1.32.0
226 | - hiredis==2.0.0
227 | - joblib==0.17.0
228 | - lambda-networks==0.4.0
229 | - markdown==3.3
230 | - nltk==3.5
231 | - nvidia-ml-py3==7.352.0
232 | - oauthlib==3.1.0
233 | - opencensus==0.7.12
234 | - opencensus-context==0.1.2
235 | - opencv-python==4.4.0.44
236 | - pathspec==0.8.0
237 | - protobuf==3.13.0
238 | - py-spy==0.3.5
239 | - pyasn1-modules==0.2.8
240 | - pyyaml==5.4.1
241 | - ray==1.2.0
242 | - redis==3.5.3
243 | - regex==2020.11.13
244 | - rsa==4.6
245 | - scikit-learn==0.23.2
246 | - scipy==1.5.2
247 | - sdeint==0.2.1
248 | - sklearn==0.0
249 | - tabulate==0.8.9
250 | - tensorboard==2.3.0
251 | - tensorboard-plugin-wit==1.7.0
252 | - termcolor==1.1.0
253 | - torchsde==0.2.1
254 | - tqdm==4.50.2
255 | - trampoline==0.1.2
256 | - urllib3==1.25.10
257 | - werkzeug==1.0.1
258 | - zipp==3.3.0
259 |
260 |
--------------------------------------------------------------------------------
/S-Transformer/requirements.yml:
--------------------------------------------------------------------------------
1 | name: TrAISS
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1
8 | - _py-xgboost-mutex=2.0
9 | - aiohttp=3.7.4.post0
10 | - argon2-cffi=20.1.0
11 | - async-timeout=3.0.1
12 | - async_generator=1.10
13 | - attrs=20.3.0
14 | - blas=1.0
15 | - bleach=3.3.0
16 | - blinker=1.4
17 | - bokeh=2.3.3
18 | - bottleneck=1.3.2
19 | - brotlipy=0.7.0
20 | - c-ares=1.17.1
21 | - ca-certificates=2021.5.30
22 | - catalogue=1.0.0
23 | - catboost=0.26
24 | - certifi=2021.5.30
25 | - cffi=1.14.5
26 | - cftime=1.5.0
27 | - chardet=4.0.0
28 | - cloudpickle=1.6.0
29 | - confuse=1.4.0
30 | - cryptography=3.4.7
31 | - cudatoolkit=9.2
32 | - curl=7.71.1
33 | - cycler=0.10.0
34 | - cymem=2.0.5
35 | - cython-blis=0.7.4
36 | - cytoolz=0.9.0.1
37 | - dask=2021.7.2
38 | - dask-core=2021.7.2
39 | - dataclasses=0.8
40 | - decorator=5.0.5
41 | - defusedxml=0.7.1
42 | - dill=0.2.9
43 | - distributed=2021.7.2
44 | - entrypoints=0.3
45 | - freetype=2.10.4
46 | - fsspec=2021.7.0
47 | - future=0.18.2
48 | - hdf4=4.2.13
49 | - hdf5=1.10.6
50 | - heapdict=1.0.1
51 | - htmlmin=0.1.12
52 | - idna=2.10
53 | - imagehash=4.2.0
54 | - imbalanced-learn=0.8.0
55 | - importlib-metadata=3.7.3
56 | - importlib_metadata=3.7.3
57 | - intel-openmp=2020.2
58 | - ipykernel=5.3.4
59 | - ipython=5.8.0
60 | - ipython_genutils=0.2.0
61 | - ipywidgets=7.5.1
62 | - jinja2=2.11.3
63 | - jpeg=9b
64 | - json5=0.9.5
65 | - jsonschema=3.0.2
66 | - jupyter_client=6.1.12
67 | - jupyter_core=4.7.1
68 | - jupyterlab=2.2.6
69 | - jupyterlab_pygments=0.1.2
70 | - jupyterlab_server=1.2.0
71 | - kiwisolver=1.3.1
72 | - krb5=1.18.2
73 | - lcms2=2.12
74 | - libcurl=7.71.1
75 | - libffi=3.3
76 | - libllvm10=10.0.1
77 | - libnetcdf=4.6.1
78 | - libpng=1.6.37
79 | - libprotobuf=3.17.2
80 | - libsodium=1.0.18
81 | - libssh2=1.9.0
82 | - libtiff=4.1.0
83 | - libxgboost=1.3.3
84 | - lightgbm=3.1.1
85 | - llvmlite=0.36.0
86 | - locket=0.2.0
87 | - lz4-c=1.9.3
88 | - markupsafe=1.1.1
89 | - matplotlib=3.3.2
90 | - matplotlib-base=3.3.2
91 | - missingno=0.4.2
92 | - mistune=0.8.4
93 | - mkl=2020.2
94 | - mkl-service=2.3.0
95 | - mkl_fft=1.3.0
96 | - mkl_random=1.1.1
97 | - msgpack-numpy=0.4.7.1
98 | - msgpack-python=1.0.2
99 | - multidict=5.1.0
100 | - nb_conda=2.2.1
101 | - nb_conda_kernels=2.3.1
102 | - nbclient=0.5.3
103 | - nbconvert=6.0.7
104 | - nbformat=5.1.3
105 | - nest-asyncio=1.5.1
106 | - netcdf4=1.5.7
107 | - networkx=2.5
108 | - ninja=1.10.2
109 | - notebook=6.3.0
110 | - numba=0.53.1
111 | - numpy=1.19.2
112 | - numpy-base=1.19.2
113 | - olefile=0.46
114 | - openssl=1.1.1k
115 | - packaging=20.9
116 | - pandas=1.2.3
117 | - pandas-profiling=2.9.0
118 | - pandoc=2.12
119 | - pandocfilters=1.4.3
120 | - partd=1.2.0
121 | - pexpect=4.8.0
122 | - phik=0.11.2
123 | - pickleshare=0.7.5
124 | - pillow=8.2.0
125 | - pip=21.0.1
126 | - plac=1.1.0
127 | - preshed=3.0.2
128 | - proj=7.0.1
129 | - prometheus_client=0.10.0
130 | - prompt_toolkit=1.0.15
131 | - psutil=5.8.0
132 | - ptyprocess=0.7.0
133 | - pyasn1=0.4.8
134 | - pycparser=2.20
135 | - pydeprecate=0.3.1
136 | - pygments=2.8.1
137 | - pyjwt=2.1.0
138 | - pyopenssl=20.0.1
139 | - pyparsing=2.4.7
140 | - pyproj=2.6.1.post1
141 | - pyrsistent=0.17.3
142 | - pysocks=1.7.1
143 | - python=3.7.9
144 | - python-dateutil=2.8.1
145 | - python_abi=3.7
146 | - pytorch=1.6.0
147 | - pytorch-lightning=1.4.4
148 | - pytz=2021.1
149 | - pyu2f=0.1.5
150 | - pywavelets=1.1.1
151 | - pyzmq=20.0.0
152 | - requests=2.25.1
153 | - requests-oauthlib=1.3.0
154 | - seaborn=0.11.1
155 | - send2trash=1.5.0
156 | - setuptools=52.0.0
157 | - simplegeneric=0.8.1
158 | - six=1.15.0
159 | - sortedcontainers=2.4.0
160 | - spacy=2.3.5
161 | - sqlite=3.35.4
162 | - srsly=1.0.5
163 | - tangled-up-in-unicode=0.1.0
164 | - tbb=2020.3
165 | - tblib=1.7.0
166 | - tensorboard-data-server=0.6.0
167 | - terminado=0.9.4
168 | - testpath=0.4.4
169 | - thinc=7.4.5
170 | - threadpoolctl=2.1.0
171 | - tk=8.6.10
172 | - toolz=0.11.1
173 | - torchmetrics=0.5.0
174 | - torchtext=0.7.0
175 | - torchvision=0.7.0
176 | - tornado=6.1
177 | - traitlets=5.0.5
178 | - typing-extensions=3.10.0.0
179 | - typing_extensions=3.10.0.0
180 | - ujson=4.0.2
181 | - visions=0.5.0
182 | - wasabi=0.8.2
183 | - wcwidth=0.2.5
184 | - webencodings=0.5.1
185 | - wheel=0.36.2
186 | - widgetsnbextension=3.5.1
187 | - wrapt=1.12.1
188 | - xarray=0.19.0
189 | - xgboost=1.3.3
190 | - xz=5.2.5
191 | - yaml=0.2.5
192 | - yarl=1.6.3
193 | - zeromq=4.3.4
194 | - zict=2.0.0
195 | - zlib=1.2.11
196 | - zstd=1.4.9
197 | - pip:
198 | - readline==8.1
199 | - ncurses==6.2
200 | - libgomp==9.3.0
201 | - libstdcxx-ng==9.1.0
202 | - libgcc-ng==9.3.0
203 | - libgfortran-ng==7.3.0
204 | - py-xgboost==1.3.30
205 | - ld_impl_linux-64==2.33.1
206 | - murmurhash==1.0.50
207 | - libedit==3.1.20210216
208 | - _openmp_mutex==4.5
209 | - absl-py==0.10.0
210 | - addict==2.3.0
211 | - aiohttp-cors==0.7.0
212 | - aioredis==1.3.1
213 | - blessings==1.7
214 | - boltons==20.2.1
215 | - cachetools==4.1.1
216 | - click==7.1.2
217 | - colorama==0.4.4
218 | - colorful==0.5.4
219 | - de-core-news-sm==2.0.0
220 | - dm-tree==0.1.5
221 | - einops==0.3.0
222 | - en-core-web-sm==2.0.0
223 | - filelock==3.0.12
224 | - fire==0.4.0
225 | - google-api-core==1.26.3
226 | - google-auth==1.22.1
227 | - google-auth-oauthlib==0.4.1
228 | - googleapis-common-protos==1.53.0
229 | - gpustat==0.6.0
230 | - grpcio==1.32.0
231 | - hiredis==2.0.0
232 | - joblib==0.17.0
233 | - lambda-networks==0.4.0
234 | - markdown==3.3
235 | - nltk==3.5
236 | - nvidia-ml-py3==7.352.0
237 | - oauthlib==3.1.0
238 | - opencensus==0.7.12
239 | - opencensus-context==0.1.2
240 | - opencv-python==4.4.0.44
241 | - pathspec==0.8.0
242 | - protobuf==3.13.0
243 | - py-spy==0.3.5
244 | - pyasn1-modules==0.2.8
245 | - pyyaml==5.4.1
246 | - ray==1.2.0
247 | - redis==3.5.3
248 | - regex==2020.11.13
249 | - rsa==4.6
250 | - scikit-learn==0.23.2
251 | - scipy==1.5.2
252 | - sdeint==0.2.1
253 | - sklearn==0.0
254 | - tabulate==0.8.9
255 | - tensorboard==2.3.0
256 | - tensorboard-plugin-wit==1.7.0
257 | - termcolor==1.1.0
258 | - torchsde==0.2.1
259 | - tqdm==4.50.2
260 | - trampoline==0.1.2
261 | - urllib3==1.25.10
262 | - werkzeug==1.0.1
263 | - zipp==3.3.0
264 |
265 |
--------------------------------------------------------------------------------
/S-Transformer/trAISformer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | 
4 | # Copyright 2021, Duong Nguyen
5 | #
6 | # Licensed under the CECILL-C License;
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.cecill.info
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | """Pytorch implementation of TrAISformer---A generative transformer for
19 | AIS trajectory prediction
20 |
21 | https://arxiv.org/abs/2109.03958
22 |
23 | """
24 | import numpy as np
25 | from numpy import linalg
26 | import matplotlib.pyplot as plt
27 | import os
28 | import sys
29 | import pickle
30 | from tqdm import tqdm
31 | import math
32 | import logging
33 | import pdb
34 |
35 | import torch
36 | import torch.nn as nn
37 | from torch.nn import functional as F
38 | import torch.optim as optim
39 | from torch.optim.lr_scheduler import LambdaLR
40 | from torch.utils.data import Dataset, DataLoader
41 |
42 | import models, trainers, datasets, utils
43 | from config_trAISformer import Config
44 |
45 | cf = Config()
46 | TB_LOG = cf.tb_log
47 | if TB_LOG:
48 | from torch.utils.tensorboard import SummaryWriter
49 |
50 | tb = SummaryWriter()
51 |
52 | # make deterministic
53 | utils.set_seed(42)
54 | torch.pi = torch.acos(torch.zeros(1)).item() * 2
55 |
56 | if __name__ == "__main__":
57 |
58 | device = cf.device
59 | init_seqlen = cf.init_seqlen
60 |
61 | ## Logging
62 | # ===============================
63 | if not os.path.isdir(cf.savedir):
64 | os.makedirs(cf.savedir)
65 | print('======= Create directory to store trained models: ' + cf.savedir)
66 | else:
67 | print('======= Directory to store trained models: ' + cf.savedir)
68 | utils.new_log(cf.savedir, "log")
69 |
70 | ## Data
71 | # ===============================
72 | moving_threshold = 0.05
73 | l_pkl_filenames = [cf.trainset_name, cf.validset_name, cf.testset_name]
74 | Data, aisdatasets, aisdls = {}, {}, {}
75 | for phase, filename in zip(("train", "valid", "test"), l_pkl_filenames):
76 | datapath = os.path.join(cf.datadir, filename)
77 | print(f"Loading {datapath}...")
78 | with open(datapath, "rb") as f:
79 | l_pred_errors = pickle.load(f)
80 | for V in l_pred_errors:
81 | try:
82 | moving_idx = np.where(V["traj"][:, 2] > moving_threshold)[0][0]
83 |                 except IndexError:  # the vessel never exceeds moving_threshold
84 | moving_idx = len(V["traj"]) - 1 # This track will be removed
85 | V["traj"] = V["traj"][moving_idx:, :]
86 | Data[phase] = [x for x in l_pred_errors if not np.isnan(x["traj"]).any() and len(x["traj"]) > cf.min_seqlen]
87 |         print(f"Kept {len(Data[phase])} of {len(l_pred_errors)} tracks "
88 |               f"(dropped NaN tracks and those shorter than {cf.min_seqlen}).")
89 | print("Creating pytorch dataset...")
90 |         # Later in this script, we will use inputs = x[:-1], targets = x[1:], hence
91 | # max_seqlen = cf.max_seqlen + 1.
92 | if cf.mode in ("pos_grad", "grad"):
93 | aisdatasets[phase] = datasets.AISDataset_grad(Data[phase],
94 | max_seqlen=cf.max_seqlen + 1,
95 | device=cf.device)
96 | else:
97 | aisdatasets[phase] = datasets.AISDataset(Data[phase],
98 | max_seqlen=cf.max_seqlen + 1,
99 | device=cf.device)
100 | if phase == "test":
101 | shuffle = False
102 | else:
103 | shuffle = True
104 | aisdls[phase] = DataLoader(aisdatasets[phase],
105 | batch_size=cf.batch_size,
106 | shuffle=shuffle)
107 | cf.final_tokens = 2 * len(aisdatasets["train"]) * cf.max_seqlen
108 |
109 | ## Model
110 | # ===============================
111 | model = models.TrAISformer(cf, partition_model=None)
112 |
113 | ## Trainer
114 | # ===============================
115 | trainer = trainers.Trainer(
116 | model, aisdatasets["train"], aisdatasets["valid"], cf, savedir=cf.savedir, device=cf.device, aisdls=aisdls, INIT_SEQLEN=init_seqlen)
117 |
118 | ## Training
119 | # ===============================
120 | if cf.retrain:
121 | trainer.train()
122 |
123 | ## Evaluation
124 | # ===============================
125 | # Load the best model
126 | model.load_state_dict(torch.load(cf.ckpt_path))
127 |
128 |     v_ranges = torch.tensor([2, 3, 0, 0]).to(cf.device)  # lat/lon spans of the ROI in degrees (sog/cog unused here)
129 |     v_roi_min = torch.tensor([model.lat_min, -7, 0, 0]).to(cf.device)  # (lat_min, lon_min) of the ROI in degrees
130 |     max_seqlen = init_seqlen + 6 * 4  # predict 4 more hours at 6 steps per hour
131 |
132 | model.eval()
133 | l_min_errors, l_mean_errors, l_masks = [], [], []
134 | pbar = tqdm(enumerate(aisdls["test"]), total=len(aisdls["test"]))
135 | with torch.no_grad():
136 | for it, (seqs, masks, seqlens, mmsis, time_starts) in pbar:
137 | seqs_init = seqs[:, :init_seqlen, :].to(cf.device)
138 | masks = masks[:, :max_seqlen].to(cf.device)
139 | batchsize = seqs.shape[0]
140 | error_ens = torch.zeros((batchsize, max_seqlen - cf.init_seqlen, cf.n_samples)).to(cf.device)
141 | for i_sample in range(cf.n_samples):
142 | preds = trainers.sample(model,
143 | seqs_init,
144 | max_seqlen - init_seqlen,
145 | temperature=1.0,
146 | sample=True,
147 | sample_mode=cf.sample_mode,
148 | r_vicinity=cf.r_vicinity,
149 | top_k=cf.top_k)
150 | inputs = seqs[:, :max_seqlen, :].to(cf.device)
151 | input_coords = (inputs * v_ranges + v_roi_min) * torch.pi / 180
152 | pred_coords = (preds * v_ranges + v_roi_min) * torch.pi / 180
153 | d = utils.haversine(input_coords, pred_coords) * masks
154 | error_ens[:, :, i_sample] = d[:, cf.init_seqlen:]
155 | # Accumulation through batches
156 | l_min_errors.append(error_ens.min(dim=-1))
157 | l_mean_errors.append(error_ens.mean(dim=-1))
158 | l_masks.append(masks[:, cf.init_seqlen:])
159 |
160 | l_min = [x.values for x in l_min_errors]
161 | m_masks = torch.cat(l_masks, dim=0)
162 | min_errors = torch.cat(l_min, dim=0) * m_masks
163 | pred_errors = min_errors.sum(dim=0) / m_masks.sum(dim=0)
164 | pred_errors = pred_errors.detach().cpu().numpy()
165 |
166 | ## Plot
167 | # ===============================
168 | plt.figure(figsize=(9, 6), dpi=150)
169 | v_times = np.arange(len(pred_errors)) / 6
170 | plt.plot(v_times, pred_errors)
171 |
172 |     # Mark the 1 h / 2 h / 3 h horizons (6 steps per hour => timesteps 6, 12, 18)
173 |     for hours, timestep in ((1, 6), (2, 12), (3, 18)):
174 |         plt.plot(hours, pred_errors[timestep], "o")
175 |         plt.plot([hours, hours], [0, pred_errors[timestep]], "r")
176 |         plt.plot([0, hours], [pred_errors[timestep], pred_errors[timestep]], "r")
177 |         plt.text(hours + 0.12, pred_errors[timestep] - 0.5,
178 |                  "{:.4f}".format(pred_errors[timestep]), fontsize=10)
179 | 
180 |     plt.xlabel("Time (hours)")
181 |     plt.ylabel("Prediction errors (km)")
182 |     plt.xlim([0, 12])
183 |     plt.ylim([0, 20])
184 |     # plt.ylim([0,pred_errors.max()+0.5])
185 |     plt.savefig(cf.savedir + "prediction_error.png")
186 | 
187 |     # Yeah, done!!!
--------------------------------------------------------------------------------
/S-Transformer/trainers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021, Duong Nguyen
3 | #
4 | # Licensed under the CECILL-C License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.cecill.info
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Boilerplate for training a neural network.
17 |
18 | References:
19 | https://github.com/karpathy/minGPT
20 | """
21 |
22 | import os
23 | import math
24 | import logging
25 |
26 | from tqdm import tqdm
27 | import numpy as np
28 | import matplotlib.pyplot as plt
29 |
30 | import torch
31 | import torch.optim as optim
32 | from torch.optim.lr_scheduler import LambdaLR
33 | from torch.utils.data.dataloader import DataLoader
34 | from torch.nn import functional as F
35 | import utils
36 |
37 | from trAISformer import TB_LOG
38 |
39 | logger = logging.getLogger(__name__)
40 |
41 |
42 | @torch.no_grad()
43 | def sample(model,
44 | seqs,
45 | steps,
46 | temperature=1.0,
47 | sample=False,
48 | sample_mode="pos_vicinity",
49 | r_vicinity=20,
50 | top_k=None):
51 | """
52 |     Take a conditioning sequence of AIS observations seqs and predict the next observation,
53 | feed the predictions back into the model each time.
54 | """
55 | max_seqlen = model.get_max_seqlen()
56 | model.eval()
57 | for k in range(steps):
58 | seqs_cond = seqs if seqs.size(1) <= max_seqlen else seqs[:, -max_seqlen:] # crop context if needed
59 |
60 | # logits.shape: (batch_size, seq_len, data_size)
61 | logits, _ = model(seqs_cond)
62 |         d2inf_pred = torch.zeros((logits.shape[0], 4)).to(seqs.device) + 0.5  # offset that maps indexes to bin centers
63 |
64 | # pluck the logits at the final step and scale by temperature
65 | logits = logits[:, -1, :] / temperature # (batch_size, data_size)
66 |
67 | lat_logits, lon_logits, sog_logits, cog_logits = \
68 | torch.split(logits, (model.lat_size, model.lon_size, model.sog_size, model.cog_size), dim=-1)
69 |
70 | # optionally crop probabilities to only the top k options
71 | if sample_mode in ("pos_vicinity",):
72 | idxs, idxs_uniform = model.to_indexes(seqs_cond[:, -1:, :])
73 | lat_idxs, lon_idxs = idxs_uniform[:, 0, 0:1], idxs_uniform[:, 0, 1:2]
74 | lat_logits = utils.top_k_nearest_idx(lat_logits, lat_idxs, r_vicinity)
75 | lon_logits = utils.top_k_nearest_idx(lon_logits, lon_idxs, r_vicinity)
76 |
77 | if top_k is not None:
78 | lat_logits = utils.top_k_logits(lat_logits, top_k)
79 | lon_logits = utils.top_k_logits(lon_logits, top_k)
80 | sog_logits = utils.top_k_logits(sog_logits, top_k)
81 | cog_logits = utils.top_k_logits(cog_logits, top_k)
82 |
83 | # apply softmax to convert to probabilities
84 | lat_probs = F.softmax(lat_logits, dim=-1)
85 | lon_probs = F.softmax(lon_logits, dim=-1)
86 | sog_probs = F.softmax(sog_logits, dim=-1)
87 | cog_probs = F.softmax(cog_logits, dim=-1)
88 |
89 | # sample from the distribution or take the most likely
90 | if sample:
91 | lat_ix = torch.multinomial(lat_probs, num_samples=1) # (batch_size, 1)
92 | lon_ix = torch.multinomial(lon_probs, num_samples=1)
93 | sog_ix = torch.multinomial(sog_probs, num_samples=1)
94 | cog_ix = torch.multinomial(cog_probs, num_samples=1)
95 | else:
96 | _, lat_ix = torch.topk(lat_probs, k=1, dim=-1)
97 | _, lon_ix = torch.topk(lon_probs, k=1, dim=-1)
98 | _, sog_ix = torch.topk(sog_probs, k=1, dim=-1)
99 | _, cog_ix = torch.topk(cog_probs, k=1, dim=-1)
100 |
101 | ix = torch.cat((lat_ix, lon_ix, sog_ix, cog_ix), dim=-1)
102 | # convert to x (range: [0,1))
103 | x_sample = (ix.float() + d2inf_pred) / model.att_sizes
104 |
105 | # append to the sequence and continue
106 | seqs = torch.cat((seqs, x_sample.unsqueeze(1)), dim=1)
107 |
108 | return seqs
109 |
110 |
111 | class TrainerConfig:
112 | # optimization parameters
113 | max_epochs = 10
114 | batch_size = 64
115 | learning_rate = 3e-4
116 | betas = (0.9, 0.95)
117 | grad_norm_clip = 1.0
118 | weight_decay = 0.1 # only applied on matmul weights
119 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original
120 | lr_decay = False
121 | warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
122 | final_tokens = 260e9 # (at what point we reach 10% of original LR)
123 | # checkpoint settings
124 | ckpt_path = None
125 | num_workers = 0 # for DataLoader
126 |
127 | def __init__(self, **kwargs):
128 | for k, v in kwargs.items():
129 | setattr(self, k, v)
130 |
131 |
132 | class Trainer:
133 |
134 | def __init__(self, model, train_dataset, test_dataset, config, savedir=None, device=torch.device("cpu"), aisdls={},
135 | INIT_SEQLEN=0):
136 | self.train_dataset = train_dataset
137 | self.test_dataset = test_dataset
138 | self.config = config
139 | self.savedir = savedir
140 |
141 | self.device = device
142 | self.model = model.to(device)
143 | self.aisdls = aisdls
144 | self.INIT_SEQLEN = INIT_SEQLEN
145 |
146 | def save_checkpoint(self, best_epoch):
147 | # DataParallel wrappers keep raw model object in .module attribute
148 | raw_model = self.model.module if hasattr(self.model, "module") else self.model
149 | # logging.info("saving %s", self.config.ckpt_path)
150 | logging.info(f"Best epoch: {best_epoch:03d}, saving model to {self.config.ckpt_path}")
151 | torch.save(raw_model.state_dict(), self.config.ckpt_path)
152 |
153 | def train(self):
154 |         model, config, aisdls, INIT_SEQLEN = self.model, self.config, self.aisdls, self.INIT_SEQLEN
155 | raw_model = model.module if hasattr(self.model, "module") else model
156 | optimizer = raw_model.configure_optimizers(config)
157 | if model.mode in ("gridcont_gridsin", "gridcont_gridsigmoid", "gridcont2_gridsigmoid",):
158 | return_loss_tuple = True
159 | else:
160 | return_loss_tuple = False
161 |
162 | def run_epoch(split, epoch=0):
163 | is_train = split == 'Training'
164 | model.train(is_train)
165 | data = self.train_dataset if is_train else self.test_dataset
166 | loader = DataLoader(data, shuffle=True, pin_memory=True,
167 | batch_size=config.batch_size,
168 | num_workers=config.num_workers)
169 |
170 | losses = []
171 | n_batches = len(loader)
172 | pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
173 | d_loss, d_reg_loss, d_n = 0, 0, 0
174 | for it, (seqs, masks, seqlens, mmsis, time_starts) in pbar:
175 |
176 | # place data on the correct device
177 | seqs = seqs.to(self.device)
178 | masks = masks[:, :-1].to(self.device)
179 |
180 | # forward the model
181 | with torch.set_grad_enabled(is_train):
182 | if return_loss_tuple:
183 | logits, loss, loss_tuple = model(seqs,
184 | masks=masks,
185 | with_targets=True,
186 | return_loss_tuple=return_loss_tuple)
187 | else:
188 | logits, loss = model(seqs, masks=masks, with_targets=True)
189 | loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
190 | losses.append(loss.item())
191 |
192 | d_loss += loss.item() * seqs.shape[0]
193 | if return_loss_tuple:
194 | reg_loss = loss_tuple[-1]
195 | reg_loss = reg_loss.mean()
196 | d_reg_loss += reg_loss.item() * seqs.shape[0]
197 | d_n += seqs.shape[0]
198 | if is_train:
199 |
200 | # backprop and update the parameters
201 | model.zero_grad()
202 | loss.backward()
203 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
204 | optimizer.step()
205 |
206 | # decay the learning rate based on our progress
207 | if config.lr_decay:
208 |                         self.tokens += (
209 |                             seqs >= 0).sum()  # number of values processed this step (all entries are >= 0 here)
210 | if self.tokens < config.warmup_tokens:
211 | # linear warmup
212 | lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
213 | else:
214 | # cosine learning rate decay
215 | progress = float(self.tokens - config.warmup_tokens) / float(
216 | max(1, config.final_tokens - config.warmup_tokens))
217 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
218 | lr = config.learning_rate * lr_mult
219 | for param_group in optimizer.param_groups:
220 | param_group['lr'] = lr
221 | else:
222 | lr = config.learning_rate
223 |
224 | # report progress
225 | pbar.set_description(f"epoch {epoch + 1} iter {it}: loss {loss.item():.5f}. lr {lr:e}")
226 |
227 | # tb logging
228 | if TB_LOG:
229 | tb.add_scalar("loss",
230 | loss.item(),
231 | epoch * n_batches + it)
232 | tb.add_scalar("lr",
233 | lr,
234 | epoch * n_batches + it)
235 |
236 | for name, params in model.head.named_parameters():
237 | tb.add_histogram(f"head.{name}", params, epoch * n_batches + it)
238 | tb.add_histogram(f"head.{name}.grad", params.grad, epoch * n_batches + it)
239 | if model.mode in ("gridcont_real",):
240 | for name, params in model.res_pred.named_parameters():
241 | tb.add_histogram(f"res_pred.{name}", params, epoch * n_batches + it)
242 | tb.add_histogram(f"res_pred.{name}.grad", params.grad, epoch * n_batches + it)
243 |
244 | if is_train:
245 | if return_loss_tuple:
246 | logging.info(
247 | f"{split}, epoch {epoch + 1}, loss {d_loss / d_n:.5f}, {d_reg_loss / d_n:.5f}, lr {lr:e}.")
248 | else:
249 | logging.info(f"{split}, epoch {epoch + 1}, loss {d_loss / d_n:.5f}, lr {lr:e}.")
250 | else:
251 | if return_loss_tuple:
252 | logging.info(f"{split}, epoch {epoch + 1}, loss {d_loss / d_n:.5f}.")
253 | else:
254 | logging.info(f"{split}, epoch {epoch + 1}, loss {d_loss / d_n:.5f}.")
255 |
256 | if not is_train:
257 | test_loss = float(np.mean(losses))
258 | # logging.info("test loss: %f", test_loss)
259 | return test_loss
260 |
261 | best_loss = float('inf')
262 | self.tokens = 0 # counter used for learning rate decay
263 | best_epoch = 0
264 |
265 | for epoch in range(config.max_epochs):
266 |
267 | run_epoch('Training', epoch=epoch)
268 | if self.test_dataset is not None:
269 | test_loss = run_epoch('Valid', epoch=epoch)
270 |
271 | # supports early stopping based on the test loss, or just save always if no test set is provided
272 | good_model = self.test_dataset is None or test_loss < best_loss
273 | if self.config.ckpt_path is not None and good_model:
274 | best_loss = test_loss
275 | best_epoch = epoch
276 | self.save_checkpoint(best_epoch + 1)
277 |
278 | ## SAMPLE AND PLOT
279 | # ==========================================================================================
280 | # ==========================================================================================
281 | raw_model = model.module if hasattr(self.model, "module") else model
282 |             seqs, masks, seqlens, mmsis, time_starts = next(iter(aisdls["test"]))
283 | n_plots = 7
284 | init_seqlen = INIT_SEQLEN
285 | seqs_init = seqs[:n_plots, :init_seqlen, :].to(self.device)
286 | preds = sample(raw_model,
287 | seqs_init,
288 | 96 - init_seqlen,
289 | temperature=1.0,
290 | sample=True,
291 | sample_mode=self.config.sample_mode,
292 | r_vicinity=self.config.r_vicinity,
293 | top_k=self.config.top_k)
294 |
295 | img_path = os.path.join(self.savedir, f'epoch_{epoch + 1:03d}.jpg')
296 | plt.figure(figsize=(9, 6), dpi=150)
297 | cmap = plt.cm.get_cmap("jet")
298 | preds_np = preds.detach().cpu().numpy()
299 | inputs_np = seqs.detach().cpu().numpy()
300 | for idx in range(n_plots):
301 | c = cmap(float(idx) / (n_plots))
302 | try:
303 | seqlen = seqlens[idx].item()
304 |             except IndexError:  # batch has fewer than n_plots sequences
305 | continue
306 | plt.plot(inputs_np[idx][:init_seqlen, 1], inputs_np[idx][:init_seqlen, 0], color=c)
307 | plt.plot(inputs_np[idx][:init_seqlen, 1], inputs_np[idx][:init_seqlen, 0], "o", markersize=3, color=c)
308 | plt.plot(inputs_np[idx][:seqlen, 1], inputs_np[idx][:seqlen, 0], linestyle="-.", color=c)
309 | plt.plot(preds_np[idx][init_seqlen:, 1], preds_np[idx][init_seqlen:, 0], "x", markersize=4, color=c)
310 | plt.xlim([-0.05, 1.05])
311 | plt.ylim([-0.05, 1.05])
312 | plt.savefig(img_path, dpi=150)
313 | plt.close()
314 |
315 | # Final state
316 | raw_model = self.model.module if hasattr(self.model, "module") else self.model
317 | # logging.info("saving %s", self.config.ckpt_path)
318 | logging.info(f"Last epoch: {epoch:03d}, saving model to {self.config.ckpt_path}")
319 | save_path = self.config.ckpt_path.replace("model.pt", f"model_{epoch + 1:03d}.pt")
320 | torch.save(raw_model.state_dict(), save_path)
321 |
--------------------------------------------------------------------------------
/S-Transformer/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021, Duong Nguyen
3 | #
4 | # Licensed under the CECILL-C License;
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.cecill.info
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utility functions for GPTrajectory.
17 |
18 | References:
19 | https://github.com/karpathy/minGPT
20 | """
21 | import numpy as np
22 | import os
23 | import math
24 | import logging
25 | import random
26 | import datetime
27 | import socket
28 |
29 |
30 | import torch
31 | import torch.nn as nn
32 | from torch.nn import functional as F
33 | torch.pi = torch.acos(torch.zeros(1)).item()*2
34 |
35 |
36 | def set_seed(seed):
37 | random.seed(seed)
38 | np.random.seed(seed)
39 | torch.manual_seed(seed)
40 | torch.cuda.manual_seed_all(seed)
41 | torch.backends.cudnn.deterministic = True
42 |
43 |
44 | def new_log(logdir,filename):
45 | """Defines logging format.
46 | """
47 | filename = os.path.join(logdir,
48 | datetime.datetime.now().strftime("log_%Y-%m-%d-%H-%M-%S_"+socket.gethostname()+"_"+filename+".log"))
49 | logging.basicConfig(level=logging.INFO,
50 | filename=filename,
51 | format="%(asctime)s - %(name)s - %(message)s",
52 | filemode="w")
53 | console = logging.StreamHandler()
54 | console.setLevel(logging.INFO)
55 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(message)s")
56 | console.setFormatter(formatter)
57 | logging.getLogger('').addHandler(console)
58 |
59 | def haversine(input_coords,
60 | pred_coords):
61 | """ Calculate the haversine distances between input_coords and pred_coords.
62 |
63 | Args:
64 | input_coords, pred_coords: Tensors of size (...,N), with (...,0) and (...,1) are
65 | the latitude and longitude in radians.
66 |
67 | Returns:
68 |         The haversine distances between input_coords and pred_coords, in km.
69 | """
70 | R = 6371
71 | lat_errors = pred_coords[...,0] - input_coords[...,0]
72 | lon_errors = pred_coords[...,1] - input_coords[...,1]
73 |     a = torch.sin(lat_errors/2)**2\
74 |         +torch.cos(input_coords[...,0])*torch.cos(pred_coords[...,0])*torch.sin(lon_errors/2)**2
75 | c = 2*torch.atan2(torch.sqrt(a),torch.sqrt(1-a))
76 | d = R*c
77 | return d
78 |
79 | def top_k_logits(logits, k):
80 | v, ix = torch.topk(logits, k)
81 | out = logits.clone()
82 | out[out < v[:, [-1]]] = -float('Inf')
83 | return out
84 |
85 | def top_k_nearest_idx(att_logits, att_idxs, r_vicinity):
86 | """Keep only k values nearest the current idx.
87 |
88 | Args:
89 |         att_logits: a Tensor of shape (batchsize, data_size).
90 |         att_idxs: a Tensor of shape (batchsize, 1), indicating
91 |             the current idxs.
92 | r_vicinity: number of values to be kept.
93 | """
94 | device = att_logits.device
95 | idx_range = torch.arange(att_logits.shape[-1]).to(device).repeat(att_logits.shape[0],1)
96 | idx_dists = torch.abs(idx_range - att_idxs)
97 | out = att_logits.clone()
98 | out[idx_dists >= r_vicinity/2] = -float('Inf')
99 | return out
--------------------------------------------------------------------------------
/关系矩阵_热力图.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | 
5 | # Read one ship sub-track
6 | df = pd.read_csv('轨迹提取文件夹/205609000/subtrack_12.csv')
7 | 
8 | # Features selected for the analysis
9 | features = ['LAT', 'LON', 'SOG', 'COG', 'Heading']
10 | 
11 | # Pairwise relation plot with Seaborn's pairplot
12 | sns.pairplot(df[features])
13 | plt.show()  # show the pair plot
14 | 
15 | # Compute the Spearman correlation coefficients
16 | spearman_corr = df[features].corr(method='spearman')
17 | 
18 | # Heatmap with the 'GnBu' colormap
19 | plt.figure(figsize=(10, 8))  # adjust the figure size as needed
20 | sns.heatmap(spearman_corr, annot=True, cmap='GnBu', square=True, linewidths=.5)
21 | plt.title('Spearman Correlation Heatmap')
22 | plt.show()  # show the heatmap
23 | """
24 | Available heatmap colormaps:
25 | 'Blues', 'BuGn', 'BuPu'
26 | 'GnBu', 'Greens', 'Greys'
27 | 'Oranges', 'OrRd', 'PuBu'
28 | 'PuBuGn', 'PuRd', 'Purples'
29 | 'RdPu', 'Reds', 'YlGn'
30 | 'YlGnBu', 'YlOrBr', 'YlOrRd'
31 | """
--------------------------------------------------------------------------------
/插值处理.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | # NOTE: scipy must be installed; DataFrame.interpolate('cubic') uses it internally
5 | plt.rcParams['font.sans-serif'] = ['SimHei']  # use SimHei so the Chinese labels render
6 | plt.rcParams['axes.unicode_minus'] = False
7 | 
8 | # Read the raw data
9 | df = pd.read_csv('轨迹提取文件夹/210959000/subtrack_5.csv', parse_dates=['BaseDateTime'])
10 | df.sort_values('BaseDateTime', inplace=True)
11 | df.set_index('BaseDateTime', inplace=True)
12 | 
13 | # Cubic-spline interpolation on a regular 2-minute time grid
14 | def cubic_spline_interpolation(df):
15 |     df_resampled = df.resample('2T').mean().interpolate('cubic')
16 |     return df_resampled
17 | 
18 | # Apply the cubic-spline interpolation
19 | df_cubic = cubic_spline_interpolation(df[['LAT', 'LON']])
20 | 
21 | # Weighted moving-average interpolation
22 | def weighted_moving_average_interpolation(df, k, weights):
23 |     if len(weights) != 2 * k:
24 |         raise ValueError("The number of weights must be 2k.")
25 | 
26 |     df_wma = pd.DataFrame(index=df.index, columns=df.columns)
27 |     for col in df.columns:
28 |         # One rolling pass per column: each centered window of 2k points is
29 |         # combined with the given weights (normalized by their sum); equivalent
30 |         # to looping over every timestamp, but in a single pass.
31 |         df_wma[col] = df[col].rolling(window=2*k, center=True).apply(
32 |             lambda x: np.dot(x, weights) / sum(weights), raw=True)
33 |     return df_wma
34 | 
35 | # Set k and the weights
36 | k = 2
37 | weights = [1, 2, 2, 1]
38 | 
39 | # Apply the weighted moving-average interpolation
40 | df_weighted = weighted_moving_average_interpolation(df[['LAT', 'LON']], k, weights)
41 | 
42 | # Save the interpolation results
43 | df_weighted.to_csv('插值文件夹/均值插值.csv')
44 | df_cubic.to_csv('插值文件夹/三次样条.csv')
45 | 
46 | # Scatter plots for visual comparison
47 | fig, axs = plt.subplots(2, 2, figsize=(10, 18))  # 2x2 grid, one subplot left unused
48 | # Adjust the subplot layout
49 | plt.subplots_adjust(top=0.95, bottom=0.1, left=0.1, right=0.9, hspace=0.5, wspace=0.5)
50 | # Raw data
51 | axs[0, 0].scatter(df['LON'], df['LAT'], alpha=0.5, color='blue')
52 | axs[0, 0].set_title('原始数据')
53 | axs[0, 0].set_xlabel('LON')
54 | axs[0, 0].set_ylabel('LAT')
55 | axs[0, 0].grid(True)
56 | 
57 | # Cubic-spline result
58 | axs[0, 1].scatter(df_cubic['LON'], df_cubic['LAT'], alpha=0.5, color='green')
59 | axs[0, 1].set_title('三次样条插值')
60 | axs[0, 1].set_xlabel('LON')
61 | axs[0, 1].set_ylabel('LAT')
62 | axs[0, 1].grid(True)
63 | 
64 | # Weighted moving-average result
65 | axs[1, 0].scatter(df_weighted['LON'], df_weighted['LAT'], alpha=0.5, color='red')
66 | axs[1, 0].set_title('加权移动平均插值')
67 | axs[1, 0].set_xlabel('LON')
68 | axs[1, 0].set_ylabel('LAT')
69 | axs[1, 0].grid(True)
70 | 
71 | plt.tight_layout()
72 | plt.show()
73 | 
74 | print("Done.")
75 | 
--------------------------------------------------------------------------------
/数据清洗.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | 
4 | # Folder paths
5 | input_folder = '数据集'
6 | output_folder = '处理完毕的数据集'
7 | 
8 | # Create the output folder if it does not exist
9 | if not os.path.exists(output_folder):
10 |     os.makedirs(output_folder)
11 | 
12 | # Iterate over all CSV files in the folder
13 | for filename in os.listdir(input_folder):
14 |     if filename.endswith('.csv'):
15 |         file_path = os.path.join(input_folder, filename)
16 | 
17 |         # Load the dataset
18 |         df = pd.read_csv(file_path)
19 | 
20 |         # 1. Sort by MMSI and BaseDateTime
21 |         df.sort_values(by=['MMSI', 'BaseDateTime'], inplace=True)
22 | 
23 |         # 2. Remove records where MMSI is not 9 digits
24 |         df = df[df['MMSI'].apply(lambda x: len(str(x)) == 9)]
25 | 
26 |         # 3. Filter for 'normal' navigation statuses (disabled in this version)
27 |         """normal_statuses = ['under engine', 'at anchor', 'not under command', 'restricted manoeuvrability',
28 |                             'moored', 'aground', 'fishing', 'under way']
29 |         df = df[df['Status'].isin(normal_statuses)]"""
30 | 
31 |         # 4. Remove records where Length < 3 or Width < 2
32 |         df = df[(df['Length'] >= 3) & (df['Width'] >= 2)]
33 | 
34 |         # 5. Remove records with out-of-range values
35 |         df = df[(df['LON'] >= -180.0) & (df['LON'] <= 180.0)]
36 |         df = df[(df['LAT'] >= -90.0) & (df['LAT'] <= 90.0)]
37 |         df = df[(df['SOG'] >= 0) & (df['SOG'] <= 51.2)]
38 |         df = df[(df['COG'] >= -204.7) & (df['COG'] <= 204.8)]
39 | 
40 |         # 6. Drop runs of five or more consecutive zero-SOG points (berthing stops)
41 |         df['zero_sog'] = (df['SOG'] == 0).astype(int)
42 |         run_id = (df['zero_sog'] != df.groupby('MMSI')['zero_sog'].shift()).cumsum()
43 |         run_len = df.groupby(run_id)['zero_sog'].transform('size')
44 |         df = df[~((df['zero_sog'] == 1) & (run_len >= 5))].drop(columns=['zero_sog'])
45 | 
46 |         # Save the cleaned dataset
47 |         new_filename = 'new_' + filename
48 |         output_path = os.path.join(output_folder, new_filename)
49 |         df.to_csv(output_path, index=False)
50 | 
51 | print(f"Cleaning finished; results saved to '{output_folder}'.")
--------------------------------------------------------------------------------
/箱形图.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 | plt.rcParams['font.sans-serif'] = ['SimHei']  # needed to render the Chinese titles below
5 | plt.rcParams['axes.unicode_minus'] = False
6 | 
7 | input_folder = 'D:\\AA_work\\数据预处理\\处理完毕的数据集'  # folder containing the cleaned CSV files
8 | 
9 | # Collect the per-file DataFrames
10 | data_frames = []
11 | 
12 | # Iterate over all CSV files in the folder
13 | for filename in os.listdir(input_folder):
14 |     if filename.endswith('.csv'):
15 |         file_path = os.path.join(input_folder, filename)
16 |         df = pd.read_csv(file_path)
17 |         data_frames.append(df)  # add each file's DataFrame to the list
18 | 
19 | # Concatenate everything into one big DataFrame
20 | combined_df = pd.concat(data_frames, ignore_index=True)
21 | 
22 | # Create the box plots
23 | fig, axs = plt.subplots(3, 2, figsize=(12, 9))  # sized to avoid squeezing and overlap
24 | 
25 | # Helper that draws a single box plot
26 | def draw_boxplot(column, title, ax, ylim=None):
27 |     ax.boxplot(combined_df[column].dropna(),
28 |                patch_artist=True,  # allow filling the box body
29 |                boxprops={'color': 'blue', 'facecolor': 'lightblue'},  # box outline and fill colors
30 |                flierprops={'marker': 'o', 'markerfacecolor': 'red', 'markersize': 1, 'markeredgecolor': 'red'},  # outlier style
31 |                whiskerprops={'color': 'blue'},  # whisker color
32 |                capprops={'color': 'blue'},  # cap color
33 |                medianprops={'color': 'red'})  # median line in red
34 |     ax.set_title(title)
35 |     ax.set_ylim(ylim)  # set the y-range if ylim is given
36 |     ax.grid(True)  # add grid lines
37 | 
38 | 
39 | # Draw the box plots
40 | draw_boxplot('LON', '经度', axs[0, 0])
41 | draw_boxplot('LAT', '纬度', axs[0, 1])
42 | draw_boxplot('SOG', '对地航速', axs[1, 0])
43 | draw_boxplot('COG', '航向', axs[1, 1])
44 | draw_boxplot('Heading', '航向角', axs[2, 0])
45 | 
46 | # Remove the unused subplot
47 | fig.delaxes(axs[2, 1])
48 | 
49 | fig.tight_layout()  # auto-adjust the subplot spacing
50 | plt.show()
--------------------------------------------------------------------------------
/轨迹平稳性检验.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import pandas as pd
3 | 
4 | plt.rcParams['font.sans-serif'] = ['SimHei']  # use SimHei so the Chinese labels render
5 | plt.rcParams['axes.unicode_minus'] = False
6 | # Read the data
7 | df = pd.read_csv('轨迹提取文件夹/205609000/subtrack_12.csv')
8 | 
9 | # Plot LAT, LON, SOG and COG against the sequence index
10 | plt.figure(figsize=(15, 10))  # figure size
11 | 
12 | # LAT
13 | plt.subplot(2, 2, 1)  # subplot 1 of the 2x2 grid
14 | plt.plot(df.index, df['LAT'], label='LAT')
15 | plt.xlabel('序列数')
16 | plt.ylabel('LAT')
17 | plt.title('(a) 纬度序列分布')
18 | 
19 | # LON
20 | plt.subplot(2, 2, 2)  # subplot 2 of the 2x2 grid
21 | plt.plot(df.index, df['LON'], label='LON', color='orange')
22 | plt.xlabel('序列数')
23 | plt.ylabel('LON')
24 | plt.title('(b) 经度序列分布')
25 | 
26 | # SOG
27 | plt.subplot(2, 2, 3)  # subplot 3 of the 2x2 grid
28 | plt.plot(df.index, df['SOG'], label='SOG', color='green')
29 | plt.xlabel('序列数')
30 | plt.ylabel('SOG')
31 | plt.title('(c) 对地航速序列分布')
32 | 
33 | # COG
34 | plt.subplot(2, 2, 4)  # subplot 4 of the 2x2 grid
35 | plt.plot(df.index, df['COG'], label='COG', color='red')
36 | plt.xlabel('序列数')
37 | plt.ylabel('COG')
38 | plt.title('(d) 对地航向序列分布')
39 | 
40 | # Tighten the spacing between subplots
41 | plt.tight_layout()
42 | plt.show()
43 |
--------------------------------------------------------------------------------
/轨迹提取.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | 
4 | 
5 | # Folder containing the cleaned data
6 | data_folder_path = '处理完毕的数据集'
7 | # Base path where the extracted trajectories are saved
8 | output_base_path = '轨迹提取文件夹'
9 | 
10 | # Read all CSV files
11 | all_files = [os.path.join(data_folder_path, f) for f in os.listdir(data_folder_path) if f.endswith('.csv')]
12 | all_data = [pd.read_csv(file) for file in all_files]
13 | 
14 | # Merge the DataFrames
15 | df = pd.concat(all_data, ignore_index=True)
16 | 
17 | # Make sure BaseDateTime has datetime type
18 | df['BaseDateTime'] = pd.to_datetime(df['BaseDateTime'])
19 | 
20 | # Sort by time, then group by MMSI
21 | grouped = df.sort_values('BaseDateTime').groupby('MMSI')
22 | 
23 | # Process the data of each MMSI
24 | for mmsi, group in grouped:
25 |     # Create a directory for this MMSI under the output path
26 |     mmsi_folder = os.path.join(output_base_path, str(mmsi))
27 |     if not os.path.exists(mmsi_folder):
28 |         os.makedirs(mmsi_folder)
29 | 
30 |     start_index = 0
31 |     subtrack_number = 0
32 |     for i in range(1, len(group)):
33 |         # Time gap to the previous point, in minutes
34 |         time_diff = (group.iloc[i]['BaseDateTime'] - group.iloc[i - 1]['BaseDateTime']).total_seconds() / 60
35 |         if time_diff > 30:
36 |             # A gap longer than 30 min closes the current sub-track
37 |             sub_df = group.iloc[start_index:i]
38 |             sub_df.to_csv(os.path.join(mmsi_folder, f'subtrack_{subtrack_number}.csv'), index=False)
39 |             subtrack_number += 1
40 |             start_index = i
41 |     # Save the last sub-track
42 |     sub_df = group.iloc[start_index:]
43 |     sub_df.to_csv(os.path.join(mmsi_folder, f'subtrack_{subtrack_number}.csv'), index=False)
44 |
45 |
--------------------------------------------------------------------------------