├── train
    ├── __init__.py
    ├── deploy_v.yaml
    ├── deploy_cloudml.yaml
    ├── deploy_cloudml.py
    └── training_log.yaml
├── classification
    ├── metrics
    │   ├── __init__.py
    │   ├── ydump.py
    │   └── compute_fishing_metrics.py
    ├── models
    │   ├── README.md
    │   ├── __init__.py
    │   ├── models_test.py
    │   ├── shake_shake.py
    │   ├── model.py
    │   ├── fishing_detection.py
    │   ├── vessel_characterization.py
    │   ├── vessel_characterization_depth.py
    │   ├── vessel_characterization_shakex2.py
    │   ├── layers.py
    │   └── objectives.py
    ├── README.md
    ├── data
    │   ├── README.md
    │   └── __init__.py
    ├── feature_generation
    │   ├── file_iterator_test.py
    │   ├── __init__.py
    │   ├── file_iterator.py
    │   ├── feature_generation_test.py
    │   ├── vessel_feature_generation_test.py
    │   ├── feature_generation.py
    │   ├── fishing_feature_generation_test.py
    │   ├── vessel_feature_generation.py
    │   └── fishing_feature_generation.py
    ├── __init__.py
    ├── run_inference.py
    ├── run_training.py
    ├── metadata_test.py
    └── metadata.py
├── test_all
├── .gitignore
├── .travis.yml
├── common
    ├── __init__.py
    └── gcp_config.py
├── CONTRIBUTING
├── Dockerfile
├── setup.py
├── notebooks
    └── AveragingLengthsAcrossTime.py
├── CHANGES.md
├── README.md
└── LICENSE
/train/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/classification/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/classification/models/README.md:
--------------------------------------------------------------------------------
1 | # Production
2 | 
3 | Neural net classification models used in production.
--------------------------------------------------------------------------------
/classification/README.md:
--------------------------------------------------------------------------------
1 | # Classification
2 | 
3 | This directory contains the core files for neural net models.
--------------------------------------------------------------------------------
/test_all:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | set -e
4 | 
5 | # Run python TF tests.
6 | export TF_CPP_MIN_LOG_LEVEL=2
7 | python -m classification.metadata_test
8 | python -m classification.models.models_test
9 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | temp
2 | untracked
3 | target
4 | .cache
5 | bigquery_credentials.dat
6 | *.swp
7 | *.pyc
8 | .DS_Store
9 | dist/
10 | *.egg-info
11 | *.json.gz
12 | classification/classification/models/prod/TEMP
13 | .ipynb_checkpoints
14 | train/ssvid_to_vessel_id.csv
15 | 
--------------------------------------------------------------------------------
/classification/data/README.md:
--------------------------------------------------------------------------------
1 | # Data
2 | 
3 | Data files that tensorflow packages up for training.
4 | 
5 | * `combined_fishing_ranges.csv` - Fishing range data for fishing localization training.
6 | 
7 | * `training_classes.csv` - Map of MMSI to classes for vessel classification training.
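
For a quick sanity check, these CSVs can be loaded with pandas. This is only an illustrative sketch — the exact column names vary by file version, so inspect the headers rather than assuming any particular schema:

    import pandas as pd

    # Hypothetical quick look at the packaged data files; the paths are the
    # in-repo locations, but check the actual columns for your file version.
    classes = pd.read_csv("classification/data/training_classes.csv")
    ranges = pd.read_csv("classification/data/combined_fishing_ranges.csv")
    print(classes.columns.tolist(), len(classes))
    print(ranges.columns.tolist(), len(ranges))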
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 | dist: trusty
3 | 
4 | before_script:
5 |   - export PATH=$HOME/.local/bin:$PATH
6 |   - sudo pip install tensorflow
7 |   - sudo pip install google-api-python-client pyyaml python-dateutil NewlineJSON pytz yattag
8 | 
9 | script:
10 |   - ./test_all
--------------------------------------------------------------------------------
/classification/feature_generation/file_iterator_test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from . import file_iterator
3 | import posixpath as pp
4 | import pytest
5 | 
6 | id_path = "gs://machine-learning-dev-ttl-120d/features/mmsi_features_v20191126/ids/part-00000-of-00001.txt"
7 | 
8 | def test_GCSFile():
9 |     path = pp.join(id_path)
10 |     with file_iterator.GCSFile(path) as fp:
11 |         lines = fp.read().strip().split()
12 |     assert lines[:2] == [b'900410135', b'413222478']
13 | 
14 | 
15 | if __name__ == '__main__':
16 |     tf.test.main()
17 | 
18 | 
19 | 
20 | 
21 | 
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. and Skytruth Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #   http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
--------------------------------------------------------------------------------
/classification/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. and Skytruth Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #   http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
--------------------------------------------------------------------------------
/classification/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. and Skytruth Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #   http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /train/deploy_v.yaml: -------------------------------------------------------------------------------- 1 | region: us-central1 2 | staging_bucket: gs://world-fishing-827-ml 3 | tensor_flow_config_template: | # This gets interpolated and then passed onto TF 4 | trainingInput: 5 | args: [ 6 | "{model_name}", 7 | "--feature_dimensions", "14", 8 | "--root_feature_path", "{feature_path}", 9 | "--training_output_path", "{output_path}/{model_name}", 10 | "--metadata_file", "{vessel_info}", 11 | "--fishing_ranges_file", "{fishing_ranges}", 12 | "--metrics", "minimal", 13 | "--split", "{split}" 14 | ] 15 | scaleTier: CUSTOM 16 | masterType: large_model_v100 17 | runtimeVersion: "1.13" 18 | 19 | 20 | -------------------------------------------------------------------------------- /classification/feature_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /train/deploy_cloudml.yaml: -------------------------------------------------------------------------------- 1 | region: us-central1 2 | staging_bucket: gs://world-fishing-827-ml 3 | tensor_flow_config_template: | # This gets interpolated and then passed onto TF 4 | trainingInput: 5 | args: [ 6 | "{model_name}", 7 | "--feature_dimensions", "14", 8 | "--root_feature_path", "{feature_path}", 9 | "--training_output_path", "{output_path}/{model_name}", 10 | "--metadata_file", "{vessel_info}", 11 | "--fishing_ranges_file", "{fishing_ranges}", 12 | "--metrics", "minimal", 13 | "--split", "{split}" 14 | ] 15 | scaleTier: CUSTOM 16 | masterType: large_model_v100 17 | runtimeVersion: "1.15" 18 | pythonVersion: "3.7" 19 | 20 | 21 | -------------------------------------------------------------------------------- /classification/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vessel classification: feature generation and model training/inference. 3 | """ 4 | 5 | __version__ = '3.0.3' 6 | __author__ = 'Tim Hochberg' 7 | __email__ = 'tim@globalfishingwatch.com' 8 | __source__ = 'https://github.com/GlobalFishingWatch/vessel-classification' 9 | __license__ = """ 10 | Copyright 2023 Global Fishing Watch Inc. 11 | Authors: 12 | 13 | Tim Hochberg 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 
17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software 22 | distributed under the License is distributed on an "AS IS" BASIS, 23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | See the License for the specific language governing permissions and 25 | limitations under the License. 26 | """ 27 | -------------------------------------------------------------------------------- /classification/feature_generation/file_iterator.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import numpy as np 4 | import tempfile 5 | import subprocess 6 | import os 7 | import resource 8 | import shutil 9 | import time 10 | 11 | import tensorflow as tf 12 | 13 | from .feature_utilities import np_array_extract_all_fixed_slices 14 | from .feature_utilities import np_array_extract_slices_for_time_ranges 15 | from .feature_utilities import np_pad_repeat_slice 16 | 17 | 18 | class GCSFile(object): 19 | 20 | def __init__(self, path): 21 | self.gcs_path = path 22 | 23 | def __enter__(self): 24 | self.temp_dir = tempfile.mkdtemp() 25 | local_path = os.path.join(self.temp_dir, os.path.basename(self.gcs_path)) 26 | subprocess.check_call(['gsutil', 'cp', self.gcs_path, local_path]) 27 | return self._process(local_path) 28 | 29 | def _process(self, path): 30 | return open(path, 'rb') 31 | 32 | def __exit__(self, *args): 33 | shutil.rmtree(self.temp_dir) 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /CONTRIBUTING: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution, 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult [GitHub Help] for more 22 | information on using pull requests. 
23 | 24 | [GitHub Help]: https://help.github.com/articles/about-pull-requests/ 25 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | RUN mkdir -p /opt/project 4 | WORKDIR /opt/project 5 | 6 | # Prepare dependencies 7 | RUN apt-get update && \ 8 | apt-get install -y apt-transport-https ca-certificates unzip curl libcurl3 wget 9 | 10 | # Install APT dependencies 11 | RUN apt-get -y update && \ 12 | apt-get -y install python python-setuptools python-dev build-essential git 13 | 14 | # Install google cloud 15 | RUN curl -sSL https://sdk.cloud.google.com | bash && \ 16 | /root/google-cloud-sdk/bin/gcloud config set --installation component_manager/disable_update_check true && \ 17 | /root/google-cloud-sdk/bin/gcloud components install beta 18 | ENV PATH $PATH:/root/google-cloud-sdk/bin 19 | 20 | # Install python dependencies 21 | RUN easy_install pip && \ 22 | pip install --upgrade pip 23 | RUN pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.2.0-cp27-none-linux_x86_64.whl && \ 24 | pip install google-api-python-client pyyaml pytz newlinejson python-dateutil yattag pandas-gbq && \ 25 | pip install git+https://github.com/GlobalFishingWatch/bqtools.git 26 | 27 | COPY . /opt/project 28 | RUN pip install . 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import setuptools 16 | import glob 17 | import os 18 | 19 | package = __import__('classification') 20 | 21 | DEPENDENCIES = [ 22 | "google-api-python-client", 23 | "six>=1.13.0" 24 | ] 25 | 26 | 27 | data_files = [os.path.basename(x) 28 | for x in glob.glob("classification/data/*.csv")] 29 | 30 | setuptools.setup( 31 | name='vessel_inference', 32 | version=package.__version__, 33 | author=package.__author__, 34 | author_email=package.__email__, 35 | description=package.__doc__.strip(), 36 | package_data={ 37 | 'classification.data': data_files 38 | }, 39 | packages=[ 40 | 'common', 41 | 'classification', 42 | 'classification.data', 43 | 'classification.models', 44 | 'classification.feature_generation' 45 | ], 46 | install_requires=DEPENDENCIES 47 | ) 48 | 49 | -------------------------------------------------------------------------------- /classification/models/models_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | from classification import metadata 19 | from . import vessel_characterization, fishing_detection 20 | 21 | 22 | class ModelsTest(tf.test.TestCase): 23 | num_feature_dimensions = 11 24 | model_classes = [vessel_characterization.Model, fishing_detection.Model] 25 | 26 | def _build_estimator(self, model_class): 27 | vmd = metadata.VesselMetadata({}, {}) 28 | model = model_class(self.num_feature_dimensions, vmd, metrics='all') 29 | return model.make_estimator("dummy_directory") 30 | 31 | def test_estimator_contruction(self): 32 | for i, model_class in enumerate(self.model_classes): 33 | with self.test_session(): 34 | # This protects against multiple model using same variable names 35 | with tf.variable_scope("training-test-{}".format(i)): 36 | est = self._build_estimator(model_class) 37 | 38 | # TODO: test input_fn 39 | 40 | 41 | if __name__ == '__main__': 42 | tf.test.main() 43 | -------------------------------------------------------------------------------- /classification/models/shake_shake.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def shake_shake(x1, x2, is_training): 5 | is_training = tf.constant(is_training, dtype=tf.bool) 6 | # create alpha and beta 7 | batch_size = tf.shape(x1)[0] 8 | # TODO: modifed for 1d, make more general or rename 9 | alpha = tf.random_uniform((batch_size, 1, 1)) 10 | beta = tf.random_uniform((batch_size, 1, 1)) 11 | # shake-shake during training phase 12 | def x_shake(): 13 | return beta * x1 + (1 - beta) * x2 + tf.stop_gradient((alpha - beta) * x1 + (beta - alpha) * x2) 14 | # even-even during testing phase 15 | def x_even(): 16 | return 0.5 * x1 + 0.5 * x2 17 | return tf.cond(is_training, x_shake, x_even) 18 | 19 | 20 | def shake_out(x, is_training): 21 | is_training = tf.constant(is_training, dtype=tf.bool) 22 | # create alpha and beta 23 | batch_size = tf.shape(x)[0] 24 | feature_depth = tf.shape(x)[2] # TODO: bulletproof 25 | alpha = tf.random_uniform((batch_size, 1, feature_depth)) 26 | # shake-shake during training phase 27 | def x_shake(): 28 | return alpha * x, (1 - alpha * x) 29 | # even-even during testing phase 30 | def x_even(): 31 | return 0.5 * x, 0.5 * x 32 | return tf.cond(is_training, x_shake, x_even) 33 | 34 | 35 | def shake_out2(x1, x2, is_training): 36 | is_training = tf.constant(is_training, dtype=tf.bool) 37 | # create alpha and beta 38 | batch_size = tf.shape(x1)[0] 39 | feature_depth = tf.shape(x1)[2] # TODO: bulletproof 40 | # TODO: modifed for 1d, make more general or rename 41 | alpha = tf.random_uniform((batch_size, 1, feature_depth)) 42 | # shake-shake during training phase 43 | def x_shake(): 44 | return alpha * x1 + (1 - alpha) * x2 45 | # even-even during testing phase 46 | def x_even(): 47 | return 0.5 * x1 + 0.5 * x2 48 | return tf.cond(is_training, x_shake, x_even) -------------------------------------------------------------------------------- /common/gcp_config.py: 
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. and Skytruth Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #   http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import datetime
16 | import logging
17 | import os
18 | import sys
19 | 
20 | 
21 | class GcpConfig(object):
22 |     def __init__(self, start_time, project_id, root_path):
23 |         self.start_time = start_time
24 |         self.project_id = project_id
25 |         self.root_path = root_path
26 | 
27 |     def model_path(self):
28 |         return self.root_path + '/models'
29 | 
30 |     # TODO(alexwilson): This config is too hard-coded to our current setup. Move
31 |     # out to config files for greater flexibility. Note there is an equivalent to
32 |     # this in Common.scala which should remain in-sync.
33 |     @staticmethod
34 |     def make_from_env_name(environment, job_id):
35 |         now = datetime.datetime.utcnow()
36 |         project_id = "world-fishing-827"
37 |         if environment == 'prod':
38 |             root_path = 'gs://machine-learning/data-production/classification/%s' % job_id
39 |         elif environment == 'dev':
40 |             user_name = os.environ.get('USER', '')
41 |             if not user_name:
42 |                 logging.fatal(
43 |                     'USER environment variable cannot be empty for dev runs.')
44 |                 sys.exit(-1)
45 |             root_path = 'gs://machine-learning-dev-ttl-120d/data-production/classification/%s/%s' % (
46 |                 user_name, job_id)
47 |         else:
48 |             logging.fatal('Invalid environment: %s', environment)
49 |             sys.exit(-1)
50 | 
51 |         return GcpConfig(now, project_id, root_path)
52 | 
--------------------------------------------------------------------------------
/classification/feature_generation/feature_generation_test.py:
--------------------------------------------------------------------------------
1 | 
2 | import gc
3 | import posixpath as pp
4 | import tensorflow as tf
5 | import numpy as np
6 | from .
import feature_generation 7 | from ..models import vessel_characterization 8 | import pytest 9 | 10 | 11 | metadata = vessel_characterization.Model.read_metadata([b'416853000', b'100209703', b'204225000'], 12 | 'classification/data/char_info_mmsi_v20200114.csv', {}, '0') 13 | # TODO: copy the referenced file to somewhere permanent 14 | prefix = b"gs://machine-learning-dev-ttl-120d/features/mmsi_features_fishing_testpy3/" 15 | 16 | 17 | def test_read_input_fn_one_shot(): 18 | paths = ([prefix + b"features/416853000.tfrecord"]# + 19 | # [prefix + b"features/205285000.tfrecord"] + 20 | # [prefix + b"features/204225000.tfrecord"] 21 | ) 22 | dataset = feature_generation.read_input_fn_one_shot(paths, 15) 23 | iterator = dataset.make_one_shot_iterator() 24 | next_element = iterator.get_next() 25 | values = [] 26 | with tf.Session() as sess: 27 | while True: 28 | try: 29 | id_, data = sess.run(next_element) 30 | values.append((id_, data)) 31 | except tf.errors.OutOfRangeError: 32 | break 33 | assert len(values) == len(paths) 34 | [(id_, data)] = [x for x in values if x[0] == b'416853000'] 35 | assert id_ == b'416853000' 36 | assert data.shape == (11845, 15), data.shape 37 | 38 | 39 | 40 | def test_read_input_fn_infinite(): 41 | path = prefix + b"features/416853000.tfrecord" 42 | dataset = feature_generation.read_input_fn_infinite([path], 15) 43 | iterator = dataset.make_one_shot_iterator() 44 | next_element = iterator.get_next() 45 | values = [] 46 | with tf.Session() as sess: 47 | for _ in range(3): 48 | id_, data = sess.run(next_element) 49 | values.append((id_, data)) 50 | assert len(values) == 3 51 | id_, data = values[0] 52 | print(metadata.id_map_int2bytes) 53 | assert metadata.id_map_int2bytes[id_] == b'416853000' 54 | assert data.shape == (11845, 15), data.shape 55 | # assert np.allclose(data[0], [ 1.4905498e+09, 9.3955746e+00, 8.2576685e+00, 2.6900861e-01, 56 | # 2.7795976e-01, -6.9944441e-01, 9.0191650e-01, 4.3191043e-01, 57 | # 6.1232343e-17, 1.0000000e+00, -8.6529666e-01, 2.8903718e+00, 58 | # 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]) 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /classification/feature_generation/vessel_feature_generation_test.py: -------------------------------------------------------------------------------- 1 | 2 | import gc 3 | import posixpath as pp 4 | import tensorflow as tf 5 | import numpy as np 6 | from . import feature_generation 7 | from . import vessel_feature_generation 8 | from . 
import feature_utilities 9 | from ..models import vessel_characterization 10 | import logging 11 | import pytest 12 | 13 | # TODO: copy the referenced file to somewhere permanent 14 | prefix = b"gs://machine-learning-dev-ttl-120d/features/mmsi_features_fishing_testpy3/" 15 | metadata = vessel_characterization.Model.read_metadata([b'416853000'], 16 | 'classification/data/char_info_mmsi_v20200114.csv', {}, '0') 17 | mdl = vessel_characterization.Model(14, metadata, 'minimal') 18 | 19 | 20 | def test_input_fn(): 21 | path = prefix + b"features/416853000.tfrecord" 22 | input_fn = vessel_feature_generation.input_fn( 23 | metadata, 24 | [path], 25 | mdl.num_feature_dimensions + 1, 26 | mdl.max_window_duration_seconds, 27 | mdl.window_max_points, 28 | mdl.min_viable_timeslice_length, 29 | objectives=mdl.training_objectives, 30 | parallelism=1) 31 | iterator = input_fn.make_one_shot_iterator() 32 | next_element = iterator.get_next() 33 | vals = [] 34 | with tf.compat.v1.Session() as sess: 35 | for _ in range(3): 36 | x = sess.run(next_element) 37 | vals.append(x) 38 | assert len(vals) == 3 39 | (obj_0, obj_1) = vals[0] 40 | assert sorted(obj_0.keys()) == ['features', 'id', 'time_ranges', 'timestamps'] 41 | assert sorted(obj_1.keys()) == ['Vessel-Crew-Size', 'Vessel-class', 'Vessel-engine-Power', 42 | 'Vessel-length', 'Vessel-tonnage'] 43 | assert [np.argmax(obj_b['Vessel-class']) for (obj_a, obj_b) in vals] == [31] * 3 44 | 45 | 46 | 47 | 48 | 49 | def test_predict_input_fn(): 50 | path = prefix + b"features/416853000.tfrecord" 51 | input_fn = vessel_feature_generation.predict_input_fn( 52 | [path], 53 | mdl.num_feature_dimensions + 1, 54 | [(1190549800.0, 1567037800.0)], 55 | mdl.window_max_points, 56 | mdl.min_viable_timeslice_length, 57 | parallelism=1) 58 | iterator = input_fn.make_one_shot_iterator() 59 | next_element = iterator.get_next() 60 | values = [] 61 | with tf.compat.v1.Session() as sess: 62 | while True: 63 | try: 64 | x = sess.run(next_element) 65 | values.append(x) 66 | except tf.errors.OutOfRangeError: 67 | break 68 | [x] = values 69 | assert sorted(x.keys()) == ['features', 'id', 'time_ranges', 'timestamps'] 70 | 71 | 72 | if __name__ == '__main__': 73 | tf.test.main() 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /classification/feature_generation/feature_generation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import posixpath as pp 4 | import os 5 | 6 | def filename_generator(filenames, random_state, weights): 7 | if weights is not None: 8 | weights = np.array(weights) 9 | weights /= weights.sum() 10 | while True: 11 | yield random_state.choice(filenames, p=weights) 12 | 13 | 14 | def flatten_features(features, timestamps, time_ranges, id_): 15 | return tf.data.Dataset.from_tensor_slices((features, timestamps, time_ranges, id_)) 16 | 17 | 18 | def set_feature_shapes(all_features, num_features, window_size): 19 | features, timestamps, time_ranges, id_ = all_features 20 | features.set_shape([window_size, num_features - 1]) 21 | timestamps.set_shape([window_size]) 22 | time_ranges.set_shape([2]) 23 | id_.set_shape([]) 24 | 25 | 26 | def parse_function_core(example_proto, num_features): 27 | context_features, sequence_features = tf.io.parse_single_sequence_example( 28 | example_proto, 29 | context_features={ 30 | 'id': tf.io.FixedLenFeature([], tf.int64) 31 | }, 32 | sequence_features={ 33 | 'movement_features': 
tf.io.FixedLenSequenceFeature(shape=(num_features, ), 34 | dtype=tf.float32) 35 | } 36 | ) 37 | return context_features['id'], sequence_features['movement_features'] 38 | 39 | def path2id(path): 40 | return tf.compat.v1.py_func( 41 | lambda p: pp.splitext(pp.basename(p))[0], [path], tf.string) 42 | 43 | def read_input_fn_infinite(paths, num_features, num_parallel_reads=4, 44 | random_state=None, weights=None): 45 | """Read data for training. 46 | 47 | Because we are IO bound during training, we return the raw IDs. These 48 | are mapped real IDs using the vessel metadata. 49 | """ 50 | 51 | def parse_function(example_proto): 52 | return parse_function_core(example_proto, num_features) 53 | 54 | if random_state is None: 55 | random_state = np.random.RandomState() 56 | 57 | path_ds = tf.data.Dataset.from_generator(lambda:filename_generator(paths, random_state, weights), 58 | tf.string) 59 | 60 | return (tf.data.TFRecordDataset(path_ds, num_parallel_reads=num_parallel_reads) 61 | .map(parse_function, num_parallel_calls=num_parallel_reads)) 62 | 63 | 64 | def read_input_fn_one_shot(paths, num_features, num_parallel_reads=4): 65 | """Read data for training. 66 | 67 | Because we are less likely to be IO bound during inference 68 | we return the real IDs as derived from the filenames. 69 | """ 70 | 71 | def parse_function(example_proto): 72 | return parse_function_core(example_proto, num_features) 73 | 74 | path_ds_1 = tf.data.Dataset.from_tensor_slices(paths) 75 | path_ds_2 = tf.data.Dataset.from_tensor_slices(paths) 76 | 77 | return tf.data.Dataset.zip(( 78 | path_ds_1 79 | .map(path2id), 80 | tf.data.TFRecordDataset(path_ds_2) 81 | .map(parse_function) 82 | .map(lambda id_, features: features) 83 | )) 84 | 85 | 86 | -------------------------------------------------------------------------------- /notebooks/AveragingLengthsAcrossTime.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:light 5 | # text_representation: 6 | # extension: .py 7 | # format_name: light 8 | # format_version: '1.5' 9 | # jupytext_version: 1.6.0 10 | # kernelspec: 11 | # display_name: Python 3 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | import matplotlib.pyplot as plt 17 | import pandas as pd 18 | import numpy as np 19 | 20 | query = """ 21 | with 22 | 23 | labels as ( 24 | select cast(id as string) as ssvid, length as length_lbl 25 | from `machine_learning_dev_ttl_120d.char_info_mmsi_v20200124` 26 | where split = 'Test' and length is not null 27 | and cast(id as string) not in ('367661820') -- this has bogus length 28 | ), 29 | 30 | inferred as ( 31 | select ssvid, start_time, length 32 | from `world-fishing-827.gfw_research_precursors.vc_v20200124_results_*` 33 | ), 34 | 35 | monthly_activity as ( 36 | select ssvid, sum(positions) positions, sum(active_positions) active_positions, 37 | extract(month from date) month, 38 | extract(year from date) year 39 | from gfw_research.pipe_v20190502_segs_daily 40 | group by ssvid, month, year 41 | ), 42 | 43 | semiyearly_activity as ( 44 | select ssvid, sum(positions) positions, sum(active_positions) active_positions, 45 | timestamp(datetime(year, 1, 1, 0, 0, 0)) start_time, year 46 | from monthly_activity 47 | where month <= 6 48 | group by ssvid, year 49 | union all 50 | select ssvid, sum(positions) positions, sum(active_positions) active_positions, 51 | timestamp(datetime(year, 7, 1, 0, 0, 0)) start_time, year 52 | from monthly_activity 53 | where month > 6 54 
| group by ssvid, year 55 | ) 56 | 57 | 58 | select * 59 | from labels 60 | join inferred 61 | using (ssvid) 62 | join semiyearly_activity 63 | using (ssvid, start_time) 64 | """ 65 | length_df = pd.read_gbq(query, project_id='world-fishing-827', dialect='standard') 66 | 67 | # ## By SSVID only 68 | 69 | # + 70 | 71 | df = length_df.groupby(by = ['ssvid']).mean() 72 | plt.plot(df.length_lbl, df.length, '.') 73 | r2 = np.corrcoef(length_df.length_lbl, length_df.length)[0,1] ** 2 74 | r2avg = np.corrcoef(df.length_lbl, df.length)[0,1] ** 2 75 | 76 | # + 77 | lbls = [] 78 | lens = [] 79 | for key, group in length_df.groupby(by = ['ssvid']): 80 | lbls.append(group.length_lbl.mean()) 81 | scale = 10 * np.log(group.active_positions + 1) + np.log(group.positions + 1) 82 | l = (group.length * scale).sum() / scale.sum() 83 | lens.append(l) 84 | 85 | plt.plot(lbls, lens, '.') 86 | 87 | r2avg2 = np.corrcoef(lbls, lens)[0,1] ** 2 88 | print(f'{r2:.3f}, {r2avg:.3f}, {r2avg2:.3f}') 89 | # - 90 | 91 | # ## By SSVID and year 92 | 93 | # + 94 | 95 | df = length_df.groupby(by = ['ssvid', 'year']).mean() 96 | plt.plot(df.length_lbl, df.length, '.') 97 | r2 = np.corrcoef(length_df.length_lbl, length_df.length)[0,1] ** 2 98 | r2avg = np.corrcoef(df.length_lbl, df.length)[0,1] ** 2 99 | 100 | # + 101 | lbls = [] 102 | lens = [] 103 | for key, group in length_df.groupby(by = ['ssvid', 'year']): 104 | lbls.append(group.length_lbl.mean()) 105 | scale = 100 * np.log(group.active_positions + 1) + np.log(group.positions + 1) 106 | l = (group.length * scale).sum() / scale.sum() 107 | lens.append(l) 108 | 109 | plt.plot(lbls, lens, '.') 110 | 111 | r2avg2 = np.corrcoef(lbls, lens)[0,1] ** 2 112 | print(f'{r2:.3f}, {r2avg:.3f}, {r2avg2:.3f}') 113 | -------------------------------------------------------------------------------- /train/deploy_cloudml.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2017 Google Inc. and Skytruth Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | from __future__ import print_function 18 | from common.gcp_config import GcpConfig 19 | import yaml 20 | import json 21 | import time 22 | import subprocess 23 | import os 24 | import datetime 25 | import argparse 26 | from oauth2client.client import GoogleCredentials 27 | from googleapiclient import discovery 28 | import tempfile 29 | 30 | 31 | def launch(args): 32 | # Read the configuration file so that we 33 | # know the train path and don't need to 34 | # hardcode it here 35 | with open(args.config_file) as f: 36 | config = yaml.safe_load(f.read()) 37 | tf_config_template = config['tensor_flow_config_template'] 38 | 39 | gcp = GcpConfig.make_from_env_name(args.env, args.job_name) 40 | 41 | tf_config_txt = tf_config_template.format( 42 | output_path=gcp.model_path(), **args.__dict__) 43 | 44 | timestamp = gcp.start_time.strftime('%Y%m%dT%H%M%S') 45 | job_id = ('%s_%s_%s' % (args.model_name, args.job_name, timestamp)).replace( 46 | '.', '_').replace('-', '_') 47 | 48 | # Kick off the job on CloudML 49 | with tempfile.NamedTemporaryFile('w') as temp: 50 | temp.write(tf_config_txt) 51 | temp.flush() 52 | 53 | with open(temp.name) as f: 54 | tf_config = yaml.safe_load(f) 55 | 56 | # It seems that we currently need to pass args as both 'args' in the 57 | # config file and as args after the '--'?! 58 | args = [ 59 | 'gcloud', 'ai-platform', 60 | 'jobs', 'submit', 'training', job_id, 61 | '--config', temp.name, '--module-name', 62 | 'classification.run_training', '--staging-bucket', 63 | config['staging_bucket'], '--package-path', 'classification', 64 | '--region', config['region'], '--' 65 | ] + tf_config['trainingInput']['args'] 66 | 67 | print('Executing:\n', ' '.join(args)) 68 | print("Config:\n", tf_config_txt) 69 | 70 | subprocess.check_call(args) 71 | 72 | return job_id 73 | 74 | 75 | if __name__ == "__main__": 76 | import argparse 77 | parser = argparse.ArgumentParser(description='Deploy ML Training.') 78 | parser.add_argument('--env', required=True, 79 | help='environment for run: prod/dev.') 80 | parser.add_argument('--model_name', required=True, 81 | help='module name of model.') 82 | parser.add_argument('--job_name', required=True, 83 | help='unique name for this job.') 84 | parser.add_argument('--feature_path', required=True, 85 | help='gcs path to features.') 86 | parser.add_argument('--vessel_info', required=True, 87 | help='local path to vessel_info.') 88 | parser.add_argument('--fishing_ranges', default='', 89 | help='optional local path fishing ranges') 90 | parser.add_argument('--config_file', default='deploy_cloudml.yaml', 91 | help='configuration file path.') 92 | parser.add_argument('--split', default=0, type=int, 93 | help='Split to use (-1) for all') 94 | args = parser.parse_args() 95 | 96 | launch(args) 97 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a 6 | Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to 7 | [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 8 | 9 | ## [Unreleased] 10 | 11 | ## v3.0.3 - 2023-07-19 12 | 13 | ### Changed 14 | 15 | * [PIPELINE-1407](https://globalfishingwatch.atlassian.net/browse/PIPELINE-1407): Changes 16 | Removes the char `'` present in install_requirements. 
It fixes the issue at build: 17 | 'install_requires' must be a string or list of strings containing valid 18 | project/version requirement specifiers. 19 | 20 | ## v3.0.2 - 2020-03-29 21 | 22 | ### Added 23 | 24 | * [GlobalFishingWatch/gfw-eng-tasks#44](https://github.com/GlobalFishingWatch/gfw-eng-tasks/issues/44): Adds 25 | a fix bug when range didn't include features. 26 | Also adds missing test files 27 | 28 | ## v3.0.1 - 2020-03-20 29 | 30 | ### Added 31 | 32 | * [GlobalFishingWatch/gfw-eng-tasks#34](https://github.com/GlobalFishingWatch/gfw-eng-tasks/issues/34): Adds 33 | * Removes commented code 34 | * Pull out dependence on ujson and NewLineJson since no longer used 35 | * Reinstate padding; factor out new hash function so can be used from feature pipeline 36 | * Disable padding during fishing inference 37 | * Fix padding; remove approx_means 38 | * Fix pathname when generaring training paths 39 | * Switch to using blake2b for hashing 40 | * Fix padding bug that generated a lot of fishing regions with bad timestamps 41 | * Add missing input functions 42 | * Fix type causing fail when range was done 43 | * Remove debuggin logging 44 | * Bug fix 45 | * Fix ranges in vessel classification 46 | * More logging 47 | * More debugginf logging 48 | * Change logging to warning to make sure gets through 49 | * Add debugging logging for time range computation 50 | * Use research vessels again 51 | * Stop converting id to int when running inference 52 | * Force zip to return list under py3 53 | * Fix speed problems and memory leaks 54 | * Python 3 compatibility; much directed at working around change in builtin hash 55 | * Simple python 3 fixes; mostly in tests 56 | * Fix print statements 57 | * Separate cargo and tanker in coarse mapping 58 | * Improve metrics computation and add some support for auto generating docs from metadata 59 | * Add model that uses depths so we can test depth inference 60 | * Dont apply synonyms to top level classes 61 | * Automatically convert seismic_vessels to research 62 | * Tweak compute_metrics to keep fine classification table in defined order 63 | * Reinstate seismic vessel as unused class; improvements to metrics 64 | * Fix compute vessel metrics 65 | * Remove seismic vessel and update training and testing to use vessel database 66 | * Debug training data generation and update metric computation. 67 | * Silence yaml warning by switching to safe_load 68 | * Tweak training invokation to try to improve speed 69 | * Switch to computing vessel parameters in terms of MMSI 70 | 71 | ## v3.0.0 - 2019-06-12 72 | 73 | ### Added 74 | 75 | **BREAKING CHANGE, requires pipe-tools and pipe-features 3.0** 76 | * [#41](https://github.com/GlobalFishingWatch/pipe-features/pull/41) 77 | * Refactor to use new Tensorflow Dataset and Estimator APIs 78 | * Support more recent tensorflow versions. 79 | * Go back to original random rolling of data during training since tests 80 | showed slightly better accuracy. 81 | * Changes to support UVI and MMSI simultaneously. 82 | * Fix way vessel types are upsampled 83 | * Fix vessel metrics to work with vessel_id. 84 | * Correctly stratify data using new classes. 85 | * Generate training data directly from Vessel Database 86 | * Modify training invocation to make tracking runs easier. 87 | -------------------------------------------------------------------------------- /classification/models/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import abc 16 | import numpy as np 17 | import six 18 | from classification import metadata 19 | 20 | 21 | class ModelBase(object): 22 | __metaclass__ = abc.ABCMeta 23 | 24 | @property 25 | def number_of_steps(self): 26 | """Number of training examples to use""" 27 | return 500000 28 | 29 | @property 30 | def use_ranges_for_training(self): 31 | """Choose features overlapping with provided ranges during training""" 32 | return False 33 | 34 | @property 35 | def batch_size(self): 36 | return 64 37 | 38 | @property 39 | def max_window_duration_seconds(self): 40 | """ Window max duration in seconds. A value of zero indicates that 41 | we would instead like to choose a fixed-length window. """ 42 | return None 43 | 44 | # We often allocate a much smaller buffer than would fit the specified time 45 | # sampled at 5 mins intervals, on the basis that the sample is almost 46 | # always much more sparse. 47 | @property 48 | def window_max_points(self): 49 | return None 50 | 51 | @property 52 | def min_viable_timeslice_length(self): 53 | return 500 54 | 55 | @property 56 | def max_replication_factor(self): 57 | return 100.0 58 | 59 | def __init__(self, num_feature_dimensions, vessel_metadata): 60 | self.num_feature_dimensions = num_feature_dimensions 61 | if vessel_metadata: 62 | self.vessel_metadata = vessel_metadata 63 | self.fishing_ranges_map = vessel_metadata.fishing_ranges_map 64 | else: 65 | self.vessel_metadata = None 66 | self.fishing_ranges_map = None 67 | self.training_objectives = None 68 | 69 | def build_training_file_list(self, base_feature_path, split): 70 | boundary = 1 if (split == metadata.TRAINING_SPLIT) else self.batch_size 71 | random_state = np.random.RandomState() 72 | training_ids = self.vessel_metadata.weighted_training_list( 73 | random_state, 74 | split, 75 | self.max_replication_factor, 76 | boundary=boundary) 77 | return [ 78 | '%s/%s.tfrecord' % (base_feature_path, six.ensure_text(id_)) 79 | for id_ in training_ids 80 | ] 81 | 82 | @staticmethod 83 | def read_metadata(all_available_ids, 84 | metadata_file, 85 | fishing_ranges, 86 | split): 87 | # Ignore split for the time being 88 | return metadata.read_vessel_multiclass_metadata( 89 | all_available_ids, metadata_file, fishing_ranges) 90 | 91 | def zero_pad_features(self, features): 92 | """ Zero-pad features in the depth dimension to match requested feature depth. """ 93 | 94 | feature_pad_size = self.feature_depth - self.num_feature_dimensions 95 | assert (feature_pad_size >= 0) 96 | batch_size, _, _, _ = features.get_shape() 97 | zero_padding = tf.tile(features[:, :, :, :1] * 0, 98 | [1, 1, 1, feature_pad_size]) 99 | padded = tf.concat(3, [features, zero_padding]) 100 | 101 | return padded 102 | -------------------------------------------------------------------------------- /classification/run_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 
and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | 17 | import logging 18 | import numpy as np 19 | import os 20 | import pytz 21 | import subprocess 22 | import tempfile 23 | import tensorflow as tf 24 | import time 25 | import uuid 26 | from datetime import datetime 27 | from datetime import timedelta 28 | 29 | 30 | class Inferer(object): 31 | def __init__(self, model, model_checkpoint_path, root_feature_path, parallelism=4): 32 | 33 | self.model = model 34 | self.estimator = model.make_estimator(model_checkpoint_path) 35 | self.root_feature_path = root_feature_path 36 | logging.info('created Inferer with Model, %s, and dims %s', model, 37 | model.num_feature_dimensions) 38 | self.parallelism = 4 39 | 40 | def close(self): 41 | self.sess.close() 42 | 43 | 44 | def _feature_files(self, ids): 45 | return [ 46 | '%s/%s.tfrecord' % (self.root_feature_path, x) 47 | for x in ids 48 | ] 49 | 50 | def _build_time_ranges(self, interval_months, start_date, end_date): 51 | # TODO: should use min_window_duration here 52 | window_dur_seconds = self.model.max_window_duration_seconds 53 | last_viable_date = datetime.now( 54 | pytz.utc) - timedelta(seconds=window_dur_seconds) 55 | time_starts = [] 56 | start_year = start_date.year 57 | month_count = start_date.month - 1 58 | if start_date.day != 1: 59 | raise ValueError('start_date must fall on the 1st of the month') 60 | dt = start_date 61 | while True: 62 | year = start_year + month_count // 12 63 | month = month_count % 12 + 1 64 | month_count += interval_months 65 | dt = datetime(year, month, 1, tzinfo=pytz.utc) 66 | if dt >= end_date: 67 | break 68 | time_starts.append(dt) 69 | delta = timedelta(seconds=self.model.max_window_duration_seconds) 70 | time_ranges = [(int(time.mktime(dt.timetuple())), 71 | int(time.mktime((dt + delta).timetuple()))) 72 | for dt in time_starts] 73 | return time_ranges 74 | 75 | def run_inference(self, ids, interval_months, start_date, end_date): 76 | paths = self._feature_files(ids) 77 | 78 | if self.model.max_window_duration_seconds != 0: 79 | time_ranges = self._build_time_ranges(interval_months, start_date, end_date) 80 | input_fn = self.model.make_prediction_input_fn(paths, time_ranges, self.parallelism) 81 | else: 82 | input_fn = self.model.make_prediction_input_fn(paths, (start_date, end_date), self.parallelism) 83 | 84 | for result in self.estimator.predict(input_fn=input_fn): 85 | 86 | start_time, end_time = [datetime.utcfromtimestamp(x) for x in result['time_ranges']] 87 | output = { 88 | 'id': result['id'], 89 | 'start_time': start_time.isoformat(), 90 | 'end_time': end_time.isoformat() 91 | } 92 | for k, v in result.items(): 93 | if k in self.model.objective_map: 94 | o = self.model.objective_map[k] 95 | output[o.metadata_label] = o.build_json_results(v, result['timestamps']) 96 | 97 | yield output 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- 
/classification/feature_generation/fishing_feature_generation_test.py: -------------------------------------------------------------------------------- 1 | 2 | import datetime 3 | import gc 4 | import posixpath as pp 5 | import tensorflow as tf 6 | import numpy as np 7 | from . import feature_generation 8 | from . import fishing_feature_generation 9 | from . import feature_utilities 10 | from ..models import fishing_detection 11 | from .. import metadata as metedata_mod 12 | import logging 13 | import pytest 14 | 15 | MAX_ITERS = 100 16 | 17 | # TODO: copy the referenced file to somewhere permanent 18 | prefix = b"gs://machine-learning-dev-ttl-120d/features/mmsi_features_fishing_testpy3/" 19 | fishing_ranges = metedata_mod.read_fishing_ranges('classification/data/det_ranges_mmsi_v20200114.csv') 20 | metadata = fishing_detection.Model.read_metadata([b'416853000', b'204248000'], 21 | 'classification/data/det_info_mmsi_v20200114.csv', fishing_ranges, 'Training') 22 | mdl = fishing_detection.Model(14, metadata, 'minimal') 23 | 24 | 25 | 26 | 27 | def test_predict_input_fn(): 28 | path1 = prefix + b"features/416853000.tfrecord" 29 | input_fn = fishing_feature_generation.predict_input_fn( 30 | [path1], 31 | mdl.num_feature_dimensions + 1, 32 | mdl.window_max_points, 33 | datetime.datetime(2015,1,1), 34 | datetime.datetime(2015,12,31), 35 | mdl.window, 36 | parallelism=1) 37 | iterator = input_fn.make_one_shot_iterator() 38 | next_element = iterator.get_next() 39 | values = [] 40 | with tf.compat.v1.Session() as sess: 41 | for _ in range(MAX_ITERS): 42 | try: 43 | x = sess.run(next_element) 44 | assert x['id'] in (b'416853000') 45 | td = [datetime.datetime.utcfromtimestamp(y) for y in x['time_ranges']] 46 | values.append(x.copy()) 47 | except tf.errors.OutOfRangeError: 48 | break 49 | else: 50 | raise RuntimeError('too many elements retrieved') 51 | x = values[0] 52 | assert sorted(x.keys()) == ['features', 'id', 'time_ranges', 'timestamps'] 53 | assert x['id'] == b'416853000' 54 | for x in values[:10]: 55 | td = [datetime.datetime.utcfromtimestamp(y) for y in x['time_ranges']] 56 | print(x['id'], td) 57 | 58 | def test_predict_input_fn_out_of_range(): 59 | path1 = prefix + b"features/416853000.tfrecord" 60 | input_fn = fishing_feature_generation.predict_input_fn( 61 | [path1], 62 | mdl.num_feature_dimensions + 1, 63 | mdl.window_max_points, 64 | datetime.datetime(2010,1,1), 65 | datetime.datetime(2010,12,31), 66 | mdl.window, 67 | parallelism=1) 68 | iterator = input_fn.make_one_shot_iterator() 69 | next_element = iterator.get_next() 70 | values = [] 71 | with tf.compat.v1.Session() as sess: 72 | for _ in range(MAX_ITERS): 73 | try: 74 | x = sess.run(next_element) 75 | assert x['id'] in (b'416853000') 76 | td = [datetime.datetime.utcfromtimestamp(y) for y in x['time_ranges']] 77 | values.append(x.copy()) 78 | except tf.errors.OutOfRangeError: 79 | break 80 | else: 81 | raise RuntimeError('too many elements retrieved') 82 | assert len(values) == 0 83 | 84 | 85 | def test_input_fn(): 86 | path = prefix + b"features/204248000.tfrecord" 87 | input_fn = fishing_feature_generation.input_fn( 88 | metadata, 89 | [path], 90 | mdl.num_feature_dimensions + 1, 91 | mdl.max_window_duration_seconds, 92 | mdl.window_max_points, 93 | mdl.min_viable_timeslice_length, 94 | parallelism=1) 95 | iterator = input_fn.make_one_shot_iterator() 96 | next_element = iterator.get_next() 97 | vals = [] 98 | with tf.compat.v1.Session() as sess: 99 | for _ in range(3): 100 | x = sess.run(next_element) 101 | vals.append(x) 102 
| assert len(vals) == 3 103 | (obj_0, obj_1) = vals[0] 104 | assert sorted(obj_0.keys()) == ['features', 'id', 'time_ranges', 'timestamps'] 105 | assert obj_0['id'] == b'204248000' 106 | 107 | 108 | 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /classification/feature_generation/vessel_feature_generation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import logging 3 | import numpy as np 4 | from . import feature_generation 5 | from . import feature_utilities 6 | import six 7 | 8 | 9 | def input_fn( 10 | metadata, 11 | filenames, 12 | num_features, 13 | max_time_delta, 14 | window_size, 15 | min_timeslice_size, 16 | objectives, 17 | parallelism=4, 18 | num_slices_per_id=4): 19 | 20 | random_state = np.random.RandomState() 21 | 22 | def xform(id_, movement_features): 23 | 24 | def _xform(id_, features): 25 | id_ = metadata.id_map_int2bytes[id_] 26 | return feature_utilities.extract_n_random_fixed_times( 27 | random_state, features, num_slices_per_id, max_time_delta, 28 | window_size, id_, min_timeslice_size) 29 | 30 | features, timestamps, time_ranges, id_ = tf.py_func( 31 | _xform, 32 | [id_, movement_features], 33 | [tf.float32, tf.int32, tf.int32, tf.string]) 34 | return (features, timestamps, time_ranges, id_) 35 | 36 | def add_labels(features, timestamps, time_bounds, id_): 37 | 38 | def _add_labels(id_, timestamps): 39 | labels = [o.create_label(id_, timestamps) for o in objectives] 40 | return labels 41 | 42 | labels = tf.py_func( 43 | _add_labels, 44 | [id_, timestamps], 45 | [tf.float32] * len(objectives)) 46 | return ((features, timestamps, time_bounds, id_), tuple(labels)) 47 | 48 | def set_shapes(all_features, labels): 49 | feature_generation.set_feature_shapes(all_features, num_features, window_size) 50 | for i, obj in enumerate(objectives): 51 | t = labels[i] 52 | t.set_shape(obj.output_shape) 53 | return all_features, labels 54 | 55 | def lbls_as_dict(features, labels): 56 | d = {obj.name : labels[i] for (i, obj) in enumerate(objectives)} 57 | return features, d 58 | 59 | def features_as_dict(features, labels): 60 | features, timestamps, time_bounds, id_ = features 61 | d = {'features' : features, 'timestamps' : timestamps, 'time_ranges' : time_bounds, 'id' : id_} 62 | return d, labels 63 | 64 | raw_data = feature_generation.read_input_fn_infinite( 65 | filenames, 66 | num_features, 67 | num_parallel_reads=parallelism, 68 | random_state=random_state) 69 | 70 | return (raw_data 71 | .map(xform, num_parallel_calls=parallelism) 72 | .flat_map(feature_generation.flatten_features) 73 | .map(add_labels, num_parallel_calls=parallelism) 74 | .map(set_shapes, num_parallel_calls=parallelism) 75 | .map(lbls_as_dict, num_parallel_calls=parallelism) 76 | .map(features_as_dict, num_parallel_calls=parallelism) 77 | ) 78 | 79 | 80 | def predict_input_fn(paths, 81 | num_features, 82 | time_ranges, 83 | window_size, 84 | min_timeslice_size, 85 | parallelism=4): 86 | 87 | random_state = np.random.RandomState() 88 | 89 | def xform(id_, movement_features): 90 | 91 | def _xform(id_, features): 92 | return feature_utilities.np_array_extract_slices_for_time_ranges( 93 | random_state, features, id_, time_ranges, 94 | window_size, min_timeslice_size) 95 | 96 | raw_features = tf.cast(movement_features, tf.float32) 97 | features, timestamps, time_ranges_tensor, id_ = tf.py_func( 98 | _xform, 99 | [id_, 
raw_features],
100 |             [tf.float32, tf.int32, tf.int32, tf.string])
101 |         return (features, timestamps, time_ranges_tensor, id_)
102 | 
103 |     def set_shapes(features, timestamps, time_bounds, id_):
104 |         all_features = features, timestamps, time_bounds, id_
105 |         feature_generation.set_feature_shapes(all_features, num_features, window_size)
106 |         return all_features
107 | 
108 |     def features_as_dict(features, timestamps, time_bounds, id_):
109 |         d = {'features' : features, 'timestamps' : timestamps, 'time_ranges' : time_bounds, 'id' : id_}
110 |         return d
111 | 
112 |     raw_data = feature_generation.read_input_fn_one_shot(paths, num_features, num_parallel_reads=parallelism)
113 | 
114 |     return (raw_data
115 |             .map(xform, num_parallel_calls=parallelism)
116 |             .flat_map(feature_generation.flatten_features)
117 |             .map(set_shapes)
118 |             .map(features_as_dict)
119 |             )
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Global Fishing Watch Vessel Classification Pipeline.
2 | 
3 | [Global Fishing Watch](http://globalfishingwatch.org) is a partnership between [Skytruth](https://skytruth.org), [Google](https://environment.google/projects/fishing-watch/) and [Oceana](http://oceana.org) to map all of the trackable commercial fishing activity in the world, in near-real time, and make it accessible to researchers, regulators, decision-makers, and the public.
4 | 
5 | This repository contains code to build TensorFlow models to classify vessels and identify fishing behavior
6 | based on [AIS](https://en.wikipedia.org/wiki/Automatic_identification_system) data.
7 | 
8 | (This is not an official Google Product).
9 | 
10 | ## Overview
11 | 
12 | We use AIS data, and possibly VMS data in the future, to extract various types of information, including:
13 | 
14 | - Vessel types
15 | 
16 | - Vessel fishing activity
17 | 
18 | - Vessel attributes (length, tonnage, etc)
19 | 
20 | The project consists of convolutional neural networks (CNNs) that infer vessel features.
21 | 
22 | 
23 | ### Neural Networks
24 | 
25 | We have two CNNs in production, as well as several experimental nets. One net
26 | predicts vessel class (`longliner`, `cargo`, `sailing`, etc), as well as
27 | vessel length and other vessel parameters, while the second predicts whether
28 | a vessel is fishing or not at a given time point.
29 | 
30 | *We initially used a single CNN to predict everything at once,
31 | but we've moved to having two CNNs. The original
32 | hope was that we would be able to take advantage of transfer learning between
33 | the various features. However, we did not see any gains from that, and using
34 | multiple nets adds useful flexibility.*
35 | 
36 | The nets share a similar structure, consisting of a large number (currently 9)
37 | of 1-D convolutional layers, followed by a single dense layer. The net for
38 | fishing prediction is somewhat more complicated since it must predict fishing at
39 | each point. To do this, all of the layers of the net are combined, with upscaling
40 | of the upper layers, to produce a set of features at each point.
41 | The design of these nets incorporates ideas borrowed
42 | from ResNets and Inception nets, among other places, but adapted for the 1-D environment.
43 | 
44 | The code associated with the neural networks is located in
45 | `classification`. The models themselves are located
46 | in `classification/models`.
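
To make the overall shape concrete, here is a rough, hypothetical sketch of the vessel-characterization style of network using Keras layers. It is illustrative only: the production models are built with the Estimator API in `classification/models`, and the layer widths, window size, and class count below are placeholders rather than production values.

    import tensorflow as tf

    def toy_vessel_net(num_features=14, window_max_points=512, num_classes=32):
        # Placeholder hyperparameters; the real models live in
        # classification/models/vessel_characterization.py and friends.
        layers = [tf.keras.layers.Conv1D(
            64, 3, padding='same', activation='relu',
            input_shape=(window_max_points, num_features))]
        for _ in range(8):  # together with the first layer, roughly 9 conv layers
            layers.append(tf.keras.layers.MaxPool1D(2, padding='same'))
            layers.append(tf.keras.layers.Conv1D(64, 3, padding='same', activation='relu'))
        layers.append(tf.keras.layers.GlobalAveragePooling1D())
        layers.append(tf.keras.layers.Dense(num_classes, activation='softmax'))
        return tf.keras.Sequential(layers)

The fishing-detection net additionally upsamples and combines the intermediate layers so that it can emit a prediction at every time point; that decoder-style part is not shown in this sketch.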
47 |
48 | ## Data layout
49 |
50 | *The data layout is currently in flux as we move data generation to Python-Dataflow
51 | managed by Airflow.*
52 |
53 | ### Common parameters
54 |
55 | In order to support the above layout, all our programs need the following common parameters:
56 |
57 | * `env`: to specify the environment - either development or production.
58 | * `job-name`: for the name (or date) of the current job.
59 | * Additionally, if the job is a dev job, the programs will read the $USER environment variable
60 | in order to choose the appropriate subdirectory for the output data.
61 |
62 |
63 | # Neural Net Classification
64 |
65 | ## Running Stuff
66 |
67 | - `python -m train.deploy_cloudml` -- launches a training run on CloudML. Use `--help` to see options.
68 |
69 | If not running in the SkyTruth/GFW environment, you will need to edit `deploy_cloudml.yaml`
70 | to set the GCS paths correctly.
71 |
72 | For example, to run vessel classification in the dev environment with the name `test`:
73 |
74 | python -m train.deploy_cloudml \
75 | --env dev \
76 | --model_name vessel_characterization \
77 | --job_name test_deploy_v20200601 \
78 | --config train/deploy_v_py3.yaml \
79 | --feature_path gs://machine-learning-dev-ttl-120d/features/vessel_char_track_id_features_v20200428/features \
80 | --vessel_info char_info_tid_v20200428.csv \
81 | --fishing_ranges det_ranges_tid_v20200428.csv
82 |
83 |
84 | **IMPORTANT**: Even though there is a maximum number of training steps specified, the CloudML
85 | process does not shut down reliably. You need to periodically check on the process and kill it
86 | manually if it has completed and is hanging. In addition, there are occasionally other problems
87 | where either the master or chief will hang or die so that new checkpoints aren't written, or
88 | new validation data isn't written out. Again, killing and restarting the training is the solution.
89 | (This will pick up at the last checkpoint saved.)
90 |
91 | - *running training locally* -- this is primarily for testing, as it will be quite slow unless you
92 | have a heavy-duty machine:
93 |
94 | python -m classification.run_training \
95 | fishing_range_classification \
96 | --feature_dimensions 14 \
97 | --root_feature_path FEATURE_PATH \
98 | --training_output_path OUTPUT_PATH \
99 | --fishing_range_training_upweight 1 \
100 | --metadata_file VESSEL_INFO_FILE_NAME \
101 | --fishing_ranges_file FISHING_RANGES_FILE_NAME \
102 | --split {0, 1, 2, 3, 4, -1} \
103 | --metrics minimal
104 |
105 | - `python -m train.compute_metrics` -- evaluates results and dumps vessel lists. Use `--help` to see options.
106 |
107 |
108 | * Inference is now run solely through Apache Beam. See the README in pipe-features for details.
109 |
110 |
111 | ## Local Environment Setup
112 |
113 | * Python 3.7+
114 | * Tensorflow version >1.14.0,<2.0 (see https://www.tensorflow.org/get_started/os_setup)
115 | * `pip install google-api-python-client pyyaml pytz newlinejson python-dateutil yattag`
116 |
117 |
118 |
119 |
120 |
121 |
122 |
--------------------------------------------------------------------------------
/classification/run_training.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. and Skytruth Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | import argparse 17 | import logging 18 | import os 19 | import sys 20 | import importlib 21 | import numpy as np 22 | import tensorflow as tf 23 | from pkg_resources import resource_filename 24 | from . import metadata 25 | 26 | def compute_approx_norms(model_fn, count=100): 27 | dataset = model_fn() 28 | print(dataset) 29 | iter = model_fn().make_initializable_iterator() 30 | print(iter) 31 | el = iter.get_next() 32 | means = [] 33 | vars = [] 34 | with tf.Session() as sess: 35 | sess.run(iter.initializer) 36 | for _ in range(count): 37 | x = sess.run(el)[0]['features'] 38 | means.append(x.mean(axis=(0, 1))) 39 | vars.append(x.var(axis=(0, 1))) 40 | return np.mean(means, axis=0), np.sqrt(np.mean(vars, axis=0)) 41 | 42 | 43 | def main(args): 44 | logging.getLogger().setLevel(logging.DEBUG) 45 | tf.logging.set_verbosity(tf.logging.DEBUG) 46 | 47 | logging.info("Running with Tensorflow version: %s", tf.__version__) 48 | 49 | logging.info("Loading model: %s", args.model_name) 50 | 51 | module = "classification.models.{}".format(args.model_name) 52 | try: 53 | Model = importlib.import_module(module).Model 54 | except: 55 | logging.fatal("Could not load model: {}".format(module)) 56 | raise 57 | 58 | metadata_file = os.path.abspath( 59 | resource_filename('classification.data', args.metadata_file)) 60 | if not os.path.exists(metadata_file): 61 | logging.fatal("Could not find metadata file: %s.", metadata_file) 62 | sys.exit(-1) 63 | 64 | if args.fishing_ranges_file: 65 | fishing_ranges_file = os.path.abspath( 66 | resource_filename('classification.data', args.fishing_ranges_file)) 67 | if not os.path.exists(fishing_ranges_file): 68 | logging.fatal("Could not find fishing range file: %s.", 69 | fishing_ranges_file) 70 | sys.exit(-1) 71 | fishing_ranges = metadata.read_fishing_ranges(fishing_ranges_file) 72 | else: 73 | fishing_ranges = {} 74 | 75 | all_available_ids = metadata.find_available_ids(args.root_feature_path) 76 | 77 | split = None if (args.split == -1) else args.split 78 | logging.info("Using split: %s", split) 79 | 80 | vessel_metadata = Model.read_metadata( 81 | all_available_ids, metadata_file, 82 | fishing_ranges, split=split) 83 | 84 | 85 | feature_dimensions = int(args.feature_dimensions) 86 | chosen_model = Model(feature_dimensions, vessel_metadata, args.metrics) 87 | 88 | train_input_fn = chosen_model.make_training_input_fn(args.root_feature_path, 89 | args.num_parallel_readers) 90 | 91 | test_input_fn = chosen_model.make_test_input_fn(args.root_feature_path, 92 | args.num_parallel_readers) 93 | 94 | estimator = chosen_model.make_estimator(args.training_output_path) 95 | train_spec = tf.estimator.TrainSpec( 96 | input_fn=train_input_fn, 97 | max_steps=chosen_model.number_of_steps 98 | ) 99 | eval_spec = tf.estimator.EvalSpec( 100 | steps=10, 101 | input_fn=test_input_fn, 102 | start_delay_secs=120, 103 | throttle_secs=600 104 | ) 105 | 106 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 107 | 108 | 109 | 110 | def parse_args(): 111 | """ Parses command-line 
arguments for training.""" 112 | argparser = argparse.ArgumentParser('Train fishing classification model.') 113 | 114 | argparser.add_argument('model_name') 115 | 116 | argparser.add_argument( 117 | '--root_feature_path', 118 | required=True, 119 | help='The root path to the vessel movement feature directories.') 120 | 121 | argparser.add_argument( 122 | '--training_output_path', 123 | required=True, 124 | help='The working path for model statistics and checkpoints.') 125 | 126 | argparser.add_argument( 127 | '--feature_dimensions', 128 | required=True, 129 | help='The number of dimensions of a classification feature.') 130 | 131 | argparser.add_argument('--metadata_file', help='Path to metadata.') 132 | 133 | argparser.add_argument( 134 | '--fishing_ranges_file', help='Path to fishing range file.') 135 | 136 | argparser.add_argument( 137 | '--metrics', 138 | default='all', 139 | help='How many metrics to dump ["all" | "minimal"]') 140 | 141 | argparser.add_argument( 142 | '--num_parallel_readers', 143 | default=1, type=int, 144 | help='How many parallel readers to employ reading data') 145 | 146 | argparser.add_argument( 147 | '--split', 148 | default=0, type=int, 149 | help='Which split to train/test on') 150 | 151 | return argparser.parse_args() 152 | 153 | 154 | if __name__ == '__main__': 155 | args = parse_args() 156 | main(args) 157 | -------------------------------------------------------------------------------- /classification/feature_generation/fishing_feature_generation.py: -------------------------------------------------------------------------------- 1 | import calendar 2 | import numpy as np 3 | import os 4 | import tensorflow as tf 5 | import six 6 | from . import feature_generation 7 | from . import feature_utilities 8 | 9 | def input_fn(metadata, 10 | filenames, 11 | num_features, 12 | max_time_delta, 13 | window_size, 14 | min_timeslice_size, 15 | parallelism=4, 16 | num_slices_per_id=8, 17 | num_parallel_reads=1): 18 | 19 | random_state = np.random.RandomState() 20 | 21 | weights = [] 22 | for p in filenames: 23 | id_, _ = os.path.splitext(os.path.basename(p)) 24 | id_ = six.ensure_binary(id_) 25 | weights.append(metadata.vessel_weight(id_)) 26 | 27 | def xform(id_, movement_features): 28 | 29 | def _xform(id_, features): 30 | # Extract several random windows from each vessel track 31 | id_ = metadata.id_map_int2bytes[id_] 32 | ranges = metadata.fishing_ranges_map.get(id_, {}) 33 | return feature_utilities.extract_n_random_fixed_points( 34 | random_state, features, num_slices_per_id, 35 | window_size, id_, ranges) 36 | 37 | features, timestamps, time_ranges, id_ = tf.compat.v1.py_func( 38 | _xform, 39 | [id_, movement_features], 40 | [tf.float32, tf.int32, tf.int32, tf.string]) 41 | features = tf.squeeze(features, axis=1) 42 | return (features, timestamps, time_ranges, id_) 43 | 44 | fishing_ranges_map = {} 45 | for k, v in metadata.fishing_ranges_map.items(): 46 | fishing_ranges_map[k] = [] 47 | for sel_range in v: 48 | start_range = calendar.timegm(sel_range.start_time.utctimetuple()) 49 | end_range = calendar.timegm(sel_range.end_time.utctimetuple()) 50 | fishing_ranges_map[k].append((start_range, end_range, sel_range.is_fishing)) 51 | 52 | def add_labels(features, timestamps, time_bounds, id_): 53 | 54 | def _add_labels(id_, timestamps): 55 | dense_labels = np.empty_like(timestamps, dtype=np.float32) 56 | dense_labels.fill(-1.0) 57 | if id_ in fishing_ranges_map: 58 | for start_range, end_range, is_fishing in fishing_ranges_map[id_]: 59 | start_ndx = 
np.searchsorted(timestamps, start_range, side='left') 60 | end_ndx = np.searchsorted(timestamps, end_range, side='right') 61 | dense_labels[start_ndx:end_ndx] = is_fishing 62 | return dense_labels 63 | 64 | [labels] = tf.compat.v1.py_func( 65 | _add_labels, 66 | [id_, timestamps], 67 | [tf.float32]) 68 | return ((features, timestamps, time_bounds, id_), labels) 69 | 70 | def set_shapes(all_features, labels): 71 | feature_generation.set_feature_shapes(all_features, num_features, window_size) 72 | labels.set_shape([window_size]) 73 | return all_features, labels 74 | 75 | def features_as_dict(features, labels): 76 | features, timestamps, time_bounds, id_ = features 77 | d = {'features' : features, 'timestamps' : timestamps, 'time_ranges' : time_bounds, 'id' : id_} 78 | return d, labels 79 | 80 | raw_data = feature_generation.read_input_fn_infinite( 81 | filenames, 82 | num_features, 83 | num_parallel_reads=num_parallel_reads, 84 | random_state=random_state, 85 | weights=weights) 86 | 87 | return (raw_data 88 | .map(xform, num_parallel_calls=parallelism) 89 | .flat_map(feature_generation.flatten_features) 90 | .map(add_labels, num_parallel_calls=parallelism) 91 | .map(set_shapes) 92 | .map(features_as_dict) 93 | ) 94 | 95 | 96 | 97 | 98 | def predict_input_fn(paths, 99 | num_features, 100 | window_size, 101 | start_date, 102 | end_date, 103 | window, 104 | parallelism=4): 105 | 106 | if window is None: 107 | b, e = 0, window_size 108 | else: 109 | b, e = window 110 | shift = e - b - 1 111 | 112 | random_state = np.random.RandomState() 113 | 114 | 115 | 116 | # TODO: use paths to build hashlist and test 117 | # Look again at differences between fishing and vessel inference 118 | 119 | 120 | 121 | def xform(id_, movement_features): 122 | 123 | def _xform(id_, features): 124 | return feature_utilities.process_fixed_window_features( 125 | random_state, features, id_, num_features, 126 | window_size, shift, start_date, end_date, b, e) 127 | 128 | features, timestamps, time_ranges_tensor, id_ = tf.compat.v1.py_func( 129 | _xform, 130 | [id_, movement_features], 131 | [tf.float32, tf.int32, tf.int32, tf.string]) 132 | features = tf.squeeze(features, axis=1) 133 | return (features, timestamps, time_ranges_tensor, id_) 134 | 135 | def set_shapes(features, timestamps, time_bounds, id_): 136 | all_features = features, timestamps, time_bounds, id_ 137 | feature_generation.set_feature_shapes(all_features, num_features, window_size) 138 | return all_features 139 | 140 | def features_as_dict(features, timestamps, time_bounds, id_): 141 | d = {'features' : features, 'timestamps' : timestamps, 'time_ranges' : time_bounds, 'id' : id_} 142 | return d 143 | 144 | raw_data = feature_generation.read_input_fn_one_shot(paths, num_features, num_parallel_reads=parallelism) 145 | 146 | return (raw_data 147 | .map(xform, num_parallel_calls=parallelism) 148 | .flat_map(feature_generation.flatten_features) 149 | .map(set_shapes) 150 | .map(features_as_dict) 151 | ) 152 | 153 | 154 | -------------------------------------------------------------------------------- /classification/metadata_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import csv 17 | import numpy as np 18 | from . import metadata 19 | import tensorflow as tf 20 | from datetime import datetime 21 | import six 22 | 23 | 24 | class VesselMetadataFileReaderTest(tf.test.TestCase): 25 | raw_lines = [ 26 | 'id,label,length,split,idhash\n', 27 | '100001,drifting_longlines,10.0,Test,2\n', 28 | '100002,drifting_longlines,24.0,Training,3\n', 29 | '100003,drifting_longlines,7.0,Training,4\n', 30 | '100004,drifting_longlines,8.0,Test,5\n', 31 | '100005,trawlers,10.0,Test,6\n', 32 | '100006,trawlers,24.0,Test,7\n', 33 | '100007,passenger,24.0,Training,8\n', 34 | '100008,trawlers,24.0,Training,9\n', 35 | '100009,trawlers,10.0,Test,10\n', 36 | '100010,trawlers,24.0,Training,11\n', 37 | '100011,tug,60.0,Test,12\n', 38 | '100012,tug,5.0,Training,13\n', 39 | '100014,tug,24.0,Test,14\n', 40 | '100013,tug|trawlers,5.0,Training,15\n', 41 | ] 42 | 43 | fishing_range_dict = { 44 | b'100001': [metadata.FishingRange( 45 | datetime(2015, 3, 1), datetime(2015, 3, 2), 1.0)], 46 | b'100002': [metadata.FishingRange( 47 | datetime(2015, 3, 1), datetime(2015, 3, 2), 1.0)], 48 | b'100003': [metadata.FishingRange( 49 | datetime(2015, 3, 1), datetime(2015, 3, 2), 1.0)], 50 | b'100004': [metadata.FishingRange( 51 | datetime(2015, 3, 1), datetime(2015, 3, 2), 1.0)], 52 | b'100005': [metadata.FishingRange( 53 | datetime(2015, 3, 1), datetime(2015, 3, 2), 1.0)], 54 | b'100006': [metadata.FishingRange( 55 | datetime(2015, 3, 1), datetime(2015, 3, 2), 1.0)], 56 | b'100007': [metadata.FishingRange( 57 | datetime(2015, 3, 1), datetime(2015, 3, 2), 1.0)], 58 | b'100008': [metadata.FishingRange( 59 | datetime(2015, 3, 1), datetime(2015, 3, 2), 1.0)], 60 | b'100009': 61 | [metadata.FishingRange(datetime(2015, 3, 1), datetime(2015, 3, 4), 1.0) 62 | ], # Thrice as much fishing 63 | b'100010': [], 64 | b'100011': [], 65 | b'100012': [], 66 | b'100013': [], 67 | } 68 | 69 | def test_metadata_file_reader(self): 70 | parsed_lines = csv.DictReader(self.raw_lines) 71 | available_vessels = set(six.ensure_binary(str(x)) for x in range(100001, 100014)) 72 | result = metadata.read_vessel_multiclass_metadata_lines( 73 | available_vessels, parsed_lines, {}) 74 | 75 | # First one is test so weighted as 1 for now 76 | self.assertEqual(1.0, result.vessel_weight(b'100001')) 77 | self.assertEqual(1.118033988749895, result.vessel_weight(b'100002')) 78 | self.assertEqual(1.0, result.vessel_weight(b'100008')) 79 | self.assertEqual(1.2909944487358056, result.vessel_weight(b'100012')) 80 | self.assertEqual(1.5811388300841898, result.vessel_weight(b'100007')) 81 | self.assertEqual(1.1454972243679027, result.vessel_weight(b'100013')) 82 | 83 | self._check_splits(result) 84 | 85 | def test_fixed_time_reader(self): 86 | parsed_lines = csv.DictReader(self.raw_lines) 87 | available_vessels = set(six.ensure_binary(str(x)) for x in range(100001, 100014)) 88 | result = metadata.read_vessel_time_weighted_metadata_lines( 89 | available_vessels, parsed_lines, self.fishing_range_dict, 90 | 'Test') 91 | 92 | self.assertEqual(1.0, result.vessel_weight(b'100001')) 93 | 
self.assertEqual(1.0, result.vessel_weight(b'100002')) 94 | self.assertEqual(3.0, result.vessel_weight(b'100009')) 95 | self.assertEqual(0.0, result.vessel_weight(b'100012')) 96 | 97 | self._check_splits(result) 98 | 99 | def _check_splits(self, result): 100 | 101 | self.assertTrue('Training' in result.metadata_by_split) 102 | self.assertTrue('Test' in result.metadata_by_split) 103 | self.assertTrue('passenger', result.vessel_label('label', b'100007')) 104 | 105 | print(result.metadata_by_split['Test'][b'100001'][0]) 106 | self.assertEqual(result.metadata_by_split['Test'][b'100001'][0], 107 | {'label': 'drifting_longlines', 108 | 'length': '10.0', 109 | 'id': '100001', 110 | 'split': 'Test', 111 | 'idhash' : '2'}) 112 | self.assertEqual(result.metadata_by_split['Test'][b'100005'][0], 113 | {'label': 'trawlers', 114 | 'length': '10.0', 115 | 'id': '100005', 116 | 'split': 'Test', 117 | 'idhash' : '6'}) 118 | self.assertEqual(result.metadata_by_split['Training'][b'100002'][0], 119 | {'label': 'drifting_longlines', 120 | 'length': '24.0', 121 | 'id': '100002', 122 | 'split': 'Training', 123 | 'idhash' : '3'}) 124 | self.assertEqual(result.metadata_by_split['Training'][b'100003'][0], 125 | {'label': 'drifting_longlines', 126 | 'length': '7.0', 127 | 'id': '100003', 128 | 'split': 'Training', 129 | 'idhash' : '4'}) 130 | 131 | 132 | def _get_metadata_files(): 133 | from pkg_resources import resource_filename 134 | for name in ["training_classes.csv"]: 135 | # TODO: rework to test encounters as well. 136 | yield os.path.abspath(resource_filename('classification.data', name)) 137 | 138 | 139 | class MetadataConsistencyTest(tf.test.TestCase): 140 | def test_metadata_consistency(self): 141 | for metadata_file in _get_metadata_files(): 142 | self.assertTrue(os.path.exists(metadata_file)) 143 | # By putting '' in these sets we can safely remove it later 144 | labels = set(['']) 145 | for row in metadata.metadata_file_reader(metadata_file): 146 | label_str = row['label'] 147 | for lbl in label_str.split('|'): 148 | labels.add(lbl.strip()) 149 | labels.remove('') 150 | 151 | expected = set([lbl for (lbl, _) in metadata.VESSEL_CATEGORIES]) 152 | assert expected >= labels, (expected - labels, labels - expected) 153 | 154 | 155 | class MultihotLabelConsistencyTest(tf.test.TestCase): 156 | def test_fine_label_consistency(self): 157 | names = [] 158 | for coarse, fine_list in metadata.VESSEL_CATEGORIES: 159 | for fine in fine_list: 160 | if fine not in names: 161 | names.append(fine) 162 | self.assertEqual( 163 | sorted(names), sorted(metadata.VESSEL_CLASS_DETAILED_NAMES)) 164 | 165 | 166 | if __name__ == '__main__': 167 | tf.test.main() 168 | -------------------------------------------------------------------------------- /classification/models/fishing_detection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
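
# Overview: this module defines a misconception-style tower of stride-2 1-D
# convolutions (layers.misconception_fishing) trained with a single
# cross-entropy objective, 'fishing_localisation', that scores every point of
# a fixed 1024-point window as fishing or not. Vessel metadata is read with
# metadata.read_vessel_time_weighted_metadata, and the training file list is
# restricted to vessels with known fishing ranges
# (vessel_metadata.fishing_range_only_list).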
14 | 15 | from __future__ import absolute_import 16 | import argparse 17 | import json 18 | from .model import ModelBase 19 | from . import layers 20 | from classification import metadata 21 | from .objectives import ( 22 | FishingLocalizationObjectiveCrossEntropy, TrainNetInfo) 23 | from classification.feature_generation import fishing_feature_generation 24 | import logging 25 | import math 26 | import numpy as np 27 | import six 28 | import os 29 | 30 | import tensorflow as tf 31 | 32 | 33 | class Model(ModelBase): 34 | 35 | window_size = 3 36 | stride = 2 37 | feature_depths = [48, 64, 96, 128, 192, 256, 384, 512, 768] 38 | strides = [2] * len(feature_depths) 39 | 40 | initial_learning_rate = 1e-3 41 | learning_decay_rate = 0.5 42 | decay_examples = 10000 43 | 44 | window = (256, 1024) 45 | 46 | @property 47 | def number_of_steps(self): 48 | return 50000 49 | 50 | @property 51 | def max_window_duration_seconds(self): 52 | # A fixed-length rather than fixed-duration window. 53 | return 0 54 | 55 | @property 56 | def window_max_points(self): 57 | return 1024 58 | 59 | 60 | @property 61 | def batch_size(self): 62 | return 16 63 | 64 | @staticmethod 65 | def read_metadata(all_available_ids, 66 | metadata_file, 67 | fishing_ranges, 68 | split): 69 | return metadata.read_vessel_time_weighted_metadata( 70 | all_available_ids, metadata_file, fishing_ranges, 71 | split=split) 72 | 73 | def __init__(self, num_feature_dimensions, vessel_metadata, metrics): 74 | super(Model, self).__init__(num_feature_dimensions, vessel_metadata) 75 | 76 | def length_or_none(id_): 77 | length = vessel_metadata.vessel_label('length', id_) 78 | if length == '': 79 | return None 80 | 81 | return np.float32(length) 82 | 83 | self.fishing_localisation_objective = FishingLocalizationObjectiveCrossEntropy( 84 | 'fishing_localisation', 85 | 'Fishing-localisation', 86 | vessel_metadata, 87 | metrics=metrics, 88 | window=self.window) 89 | 90 | self.objectives = [self.fishing_localisation_objective] 91 | self.objective_map = {obj.name : obj for obj in self.objectives} 92 | 93 | 94 | def build_training_file_list(self, base_feature_path, split): 95 | random_state = np.random.RandomState() 96 | training_ids = self.vessel_metadata.fishing_range_only_list( 97 | random_state, split) 98 | return [ 99 | '%s/%s.tfrecord' % (base_feature_path, six.ensure_text(id_)) 100 | for id_ in training_ids 101 | ] 102 | 103 | def _build_net(self, features, timestamps, ids, is_training): 104 | layers.misconception_fishing( 105 | features, 106 | filters_list=self.feature_depths, 107 | kernel_size=self.window_size, 108 | strides_list=self.strides, 109 | objective_function=self.fishing_localisation_objective, 110 | training=is_training, 111 | pre_filters=128, 112 | post_filters=128, 113 | post_layers=1 114 | ) 115 | 116 | 117 | def make_model_fn(self): 118 | def _model_fn(features, labels, mode, params): 119 | is_train = (mode == tf.estimator.ModeKeys.TRAIN) 120 | ids = features['id'] 121 | time_ranges = features['time_ranges'] 122 | timestamps = features['timestamps'] 123 | features = features['features'] 124 | self._build_net(features, timestamps, ids, is_train) 125 | 126 | if mode == tf.estimator.ModeKeys.PREDICT: 127 | predictions = { 128 | "id" : ids, 129 | "time_ranges": time_ranges, 130 | "timestamps" : timestamps, 131 | self.fishing_localisation_objective.name : self.fishing_localisation_objective.prediction 132 | } 133 | return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) 134 | 135 | global_step = tf.train.get_global_step() 
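            # With initial_learning_rate = 1e-3, learning_decay_rate = 0.5 and
            # decay_examples = 10000 (class constants above), the schedule below
            # halves the learning rate every 10,000 global steps, e.g. roughly
            # 2.5e-4 by step 20,000.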
136 | 137 | total_loss = self.fishing_localisation_objective.create_loss(labels) 138 | 139 | learning_rate = tf.train.exponential_decay( 140 | self.initial_learning_rate, global_step, 141 | self.decay_examples, self.learning_decay_rate) 142 | 143 | if mode == tf.estimator.ModeKeys.TRAIN: 144 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 145 | 146 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 147 | with tf.control_dependencies(update_ops): 148 | train_op = optimizer.minimize(loss=total_loss, 149 | global_step=global_step) 150 | 151 | return tf.estimator.EstimatorSpec( 152 | mode=mode, loss=total_loss, train_op=train_op) 153 | 154 | assert mode == tf.estimator.ModeKeys.EVAL 155 | 156 | eval_metrics = self.fishing_localisation_objective.create_metrics(labels) 157 | 158 | return tf.estimator.EstimatorSpec( 159 | mode=mode, 160 | loss=total_loss, 161 | eval_metric_ops=eval_metrics) 162 | return _model_fn 163 | 164 | def make_estimator(self, checkpoint_dir): 165 | session_config = tf.ConfigProto(allow_soft_placement=True) 166 | return tf.estimator.Estimator( 167 | config=tf.estimator.RunConfig( 168 | model_dir=checkpoint_dir, 169 | save_summary_steps=20, 170 | save_checkpoints_secs=300, 171 | keep_checkpoint_max=10, 172 | session_config=session_config), 173 | model_fn=self.make_model_fn(), 174 | params={ 175 | }) 176 | 177 | def make_input_fn(self, base_feature_path, split, parallelism, prefetch): 178 | def input_fn(): 179 | return (fishing_feature_generation.input_fn( 180 | self.vessel_metadata, 181 | self.build_training_file_list(base_feature_path, split), 182 | self.num_feature_dimensions + 1, 183 | self.max_window_duration_seconds, 184 | self.window_max_points, 185 | self.min_viable_timeslice_length, 186 | parallelism=parallelism) 187 | .prefetch(prefetch) 188 | .shuffle(prefetch) 189 | .batch(self.batch_size) 190 | ) 191 | return input_fn 192 | 193 | def make_training_input_fn(self, base_feature_path, num_parallel_reads, prefetch=1024): 194 | return self.make_input_fn(base_feature_path, metadata.TRAINING_SPLIT, num_parallel_reads, prefetch) 195 | 196 | def make_test_input_fn(self, base_feature_path, num_parallel_reads, prefetch=1024): 197 | return self.make_input_fn(base_feature_path, metadata.TEST_SPLIT, num_parallel_reads, prefetch) 198 | 199 | def make_prediction_input_fn(self, paths, range_info, parallelism): 200 | start_date, end_date = range_info 201 | def input_fn(): 202 | return fishing_feature_generation.predict_input_fn( 203 | paths, 204 | self.num_feature_dimensions + 1, 205 | self.window_max_points, 206 | start_date, 207 | end_date, 208 | self.window, 209 | parallelism=parallelism 210 | ).batch(1) 211 | return input_fn 212 | 213 | -------------------------------------------------------------------------------- /classification/models/vessel_characterization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import, division 16 | import argparse 17 | import json 18 | from .model import ModelBase 19 | from . import layers 20 | from classification import metadata 21 | from .objectives import ( 22 | TrainNetInfo, MultiClassificationObjective, LogRegressionObjectiveMAE) 23 | from classification.feature_generation import vessel_feature_generation 24 | import logging 25 | import math 26 | import numpy as np 27 | import os 28 | 29 | import tensorflow as tf 30 | 31 | # These are the approximate means and standard deviations of the features 32 | # generated by looking at a subset of features generated by cook_features 33 | # operating on the real training data. 34 | # approx_means = [ 6.9139094e+00, 4.3506036e+00, 6.1346573e-01, 5.7922113e-01, 35 | # -1.0719098e-03, -3.7928432e-02, 6.0396111e-03, -2.2346964e-01, 36 | # -1.9947174e-01, -8.9961719e-03, 1.5873306e+00, 0.0000000e+00, 37 | # 0.0000000e+00, 0.0000000e+00] 38 | # approx_stds = [0.8707472 , 3.2674177 , 0.7677098 , 0.80428886, 0.3563173 , 39 | # 0.7140226 , 0.6984951 , 0.68056667, 0.6656215 , 0.54217416, 40 | # 1.8521472 , 0. , 0. , 0. ] 41 | 42 | class Model(ModelBase): 43 | 44 | window_size = 3 45 | feature_depths = [48, 64, 96, 128, 192, 256, 384, 512, 768] 46 | strides = [2] * 9 47 | assert len(strides) == len(feature_depths) 48 | feature_sub_depths = 1024 49 | 50 | initial_learning_rate = 100e-5 51 | learning_decay_rate = 0.5 52 | decay_examples = 100000 53 | 54 | @property 55 | def number_of_steps(self): 56 | return 800000 57 | 58 | @property 59 | def max_window_duration_seconds(self): 60 | return 180 * 24 * 3600 61 | 62 | @property 63 | def window_max_points(self): 64 | nominal_max_points = (self.max_window_duration_seconds / (5 * 60)) / 4 65 | layer_reductions = np.prod(self.strides) 66 | final_size = int(round(nominal_max_points / layer_reductions)) 67 | max_points = final_size * layer_reductions 68 | logging.info('Using %s points', max_points) 69 | return max_points 70 | 71 | @property 72 | def min_viable_timeslice_length(self): 73 | return 500 74 | 75 | def __init__(self, num_feature_dimensions, vessel_metadata, metrics): 76 | super(Model, self).__init__(num_feature_dimensions, vessel_metadata) 77 | 78 | class XOrNan: 79 | def __init__(self, key): 80 | self.key = key 81 | 82 | def __call__(self, id_): 83 | x = vessel_metadata.vessel_label(self.key, id_) 84 | if x == '': 85 | x = np.nan 86 | return np.float32(x) 87 | 88 | self.training_objectives = [ 89 | LogRegressionObjectiveMAE( 90 | 'length', 91 | 'Vessel-length', 92 | XOrNan('length'), 93 | metrics=metrics, 94 | loss_weight=0.1), 95 | LogRegressionObjectiveMAE( 96 | 'tonnage', 97 | 'Vessel-tonnage', 98 | XOrNan('tonnage'), 99 | metrics=metrics, 100 | loss_weight=0.1), 101 | LogRegressionObjectiveMAE( 102 | 'engine_power', 103 | 'Vessel-engine-Power', 104 | XOrNan('engine_power'), 105 | metrics=metrics, 106 | loss_weight=0.1), 107 | LogRegressionObjectiveMAE( 108 | 'crew_size', 109 | 'Vessel-Crew-Size', 110 | XOrNan('crew_size'), 111 | metrics=metrics, 112 | loss_weight=0.1), 113 | MultiClassificationObjective( 114 | "Multiclass", "Vessel-class", vessel_metadata, metrics=metrics, loss_weight=1) 115 | ] 116 | 117 | self.objective_map = {obj.name : obj for obj in self.training_objectives} 118 | 119 | def _build_net(self, features, timestamps, ids, is_training): 120 | outputs, _ = layers.misconception_model( 121 | features, 122 | 
filters_list=self.feature_depths, 123 | kernel_size=self.window_size, 124 | strides_list=self.strides, 125 | objective_functions=self.training_objectives, 126 | training=is_training, 127 | sub_filters=self.feature_sub_depths, 128 | sub_layers=2, 129 | # feature_means=approx_means, 130 | # feature_stds=approx_stds 131 | ) 132 | return outputs 133 | 134 | def make_model_fn(self): 135 | def _model_fn(features, labels, mode, params): 136 | is_train = (mode == tf.estimator.ModeKeys.TRAIN) 137 | ids = features['id'] 138 | time_ranges = features['time_ranges'] 139 | timestamps = features['timestamps'] 140 | features = features['features'] 141 | self._build_net(features, timestamps, ids, is_train) 142 | 143 | if mode == tf.estimator.ModeKeys.PREDICT: 144 | predictions = { 145 | "id" : ids, 146 | "time_ranges" : time_ranges, 147 | "timestamps" : timestamps 148 | } 149 | for obj in self.training_objectives: 150 | predictions[obj.name] = obj.prediction 151 | return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) 152 | 153 | global_step = tf.train.get_global_step() 154 | 155 | total_loss = 0 156 | for obj in self. training_objectives: 157 | total_loss += obj.create_loss(labels[obj.name]) 158 | 159 | learning_rate = tf.train.exponential_decay( 160 | self.initial_learning_rate, global_step, 161 | self.decay_examples, self.learning_decay_rate) 162 | 163 | if mode == tf.estimator.ModeKeys.TRAIN: 164 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 165 | 166 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 167 | with tf.control_dependencies(update_ops): 168 | train_op = optimizer.minimize(loss=total_loss, 169 | global_step=global_step) 170 | 171 | return tf.estimator.EstimatorSpec( 172 | mode=mode, loss=total_loss, train_op=train_op) 173 | 174 | assert mode == tf.estimator.ModeKeys.EVAL 175 | 176 | eval_metrics = {} 177 | for obj in self.training_objectives: 178 | eval_metrics.update(obj.create_metrics(labels[obj.name])) 179 | 180 | return tf.estimator.EstimatorSpec( 181 | mode=mode, 182 | loss=total_loss, 183 | eval_metric_ops=eval_metrics) 184 | return _model_fn 185 | 186 | def make_estimator(self, checkpoint_dir): 187 | session_config = tf.ConfigProto(allow_soft_placement=True) 188 | return tf.estimator.Estimator( 189 | config=tf.estimator.RunConfig( 190 | model_dir=checkpoint_dir, 191 | save_summary_steps=20, 192 | save_checkpoints_secs=300, 193 | keep_checkpoint_max=10, 194 | session_config=session_config), 195 | model_fn=self.make_model_fn(), 196 | params={ 197 | }) 198 | 199 | def make_input_fn(self, base_feature_path, split, parallelism, prefetch): 200 | def input_fn(): 201 | return (vessel_feature_generation.input_fn( 202 | self.vessel_metadata, 203 | self.build_training_file_list(base_feature_path, split), 204 | self.num_feature_dimensions + 1, 205 | self.max_window_duration_seconds, 206 | self.window_max_points, 207 | self.min_viable_timeslice_length, 208 | objectives=self.training_objectives, 209 | parallelism=parallelism) 210 | .prefetch(prefetch) 211 | .shuffle(prefetch) 212 | .batch(self.batch_size) 213 | ) 214 | return input_fn 215 | 216 | def make_training_input_fn(self, base_feature_path, parallelism, prefetch=1024): 217 | return self.make_input_fn(base_feature_path, metadata.TRAINING_SPLIT, parallelism, prefetch) 218 | 219 | def make_test_input_fn(self, base_feature_path, parallelism, prefetch=1024): 220 | return self.make_input_fn(base_feature_path, metadata.TEST_SPLIT, parallelism, prefetch) 221 | 222 | def make_prediction_input_fn(self, paths, 
range_info, parallelism): 223 | time_ranges = range_info 224 | def input_fn(): 225 | return vessel_feature_generation.predict_input_fn( 226 | paths, 227 | self.num_feature_dimensions + 1, 228 | time_ranges, 229 | self.window_max_points, 230 | self.min_viable_timeslice_length, 231 | parallelism=parallelism 232 | ).batch(1) 233 | return input_fn 234 | 235 | -------------------------------------------------------------------------------- /classification/models/vessel_characterization_depth.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import, division 16 | import argparse 17 | import json 18 | from .model import ModelBase 19 | from . import layers 20 | from classification import metadata 21 | from .objectives import ( 22 | TrainNetInfo, MultiClassificationObjective, LogRegressionObjectiveMAE) 23 | from classification.feature_generation import vessel_feature_generation 24 | import logging 25 | import math 26 | import numpy as np 27 | import os 28 | 29 | import tensorflow as tf 30 | 31 | # These are the approximate means and standard deviations of the features 32 | # generated by looking at a subset of features generated by cook_features 33 | # operating on the real training data. 34 | # approx_means = [ 6.9139094e+00, 4.3506036e+00, 6.1346573e-01, 5.7922113e-01, 35 | # -1.0719098e-03, -3.7928432e-02, 6.0396111e-03, -2.2346964e-01, 36 | # -1.9947174e-01, -8.9961719e-03, 1.5873306e+00, 0.0000000e+00, 37 | # 0.0000000e+00, 0.0000000e+00] 38 | # approx_stds = [0.8707472 , 3.2674177 , 0.7677098 , 0.80428886, 0.3563173 , 39 | # 0.7140226 , 0.6984951 , 0.68056667, 0.6656215 , 0.54217416, 40 | # 1.8521472 , 0. , 0. , 0. 
] 41 | 42 | class Model(ModelBase): 43 | 44 | window_size = 3 45 | feature_depths = [48, 64, 96, 128, 192, 256, 384, 512, 768] 46 | strides = [2] * 9 47 | assert len(strides) == len(feature_depths) 48 | feature_sub_depths = 1024 49 | 50 | initial_learning_rate = 100e-5 51 | learning_decay_rate = 0.5 52 | decay_examples = 100000 53 | 54 | @property 55 | def number_of_steps(self): 56 | return 800000 57 | 58 | @property 59 | def max_window_duration_seconds(self): 60 | return 180 * 24 * 3600 61 | 62 | @property 63 | def window_max_points(self): 64 | nominal_max_points = (self.max_window_duration_seconds / (5 * 60)) / 4 65 | layer_reductions = np.prod(self.strides) 66 | final_size = int(round(nominal_max_points / layer_reductions)) 67 | max_points = final_size * layer_reductions 68 | logging.info('Using %s points', max_points) 69 | return max_points 70 | 71 | @property 72 | def min_viable_timeslice_length(self): 73 | return 500 74 | 75 | def __init__(self, num_feature_dimensions, vessel_metadata, metrics): 76 | super(Model, self).__init__(num_feature_dimensions, vessel_metadata) 77 | 78 | class XOrNan: 79 | def __init__(self, key): 80 | self.key = key 81 | 82 | def __call__(self, id_): 83 | x = vessel_metadata.vessel_label(self.key, id_) 84 | if x == '': 85 | x = np.nan 86 | return np.float32(x) 87 | 88 | self.training_objectives = [ 89 | LogRegressionObjectiveMAE( 90 | 'length', 91 | 'Vessel-length', 92 | XOrNan('length'), 93 | metrics=metrics, 94 | loss_weight=0.1), 95 | LogRegressionObjectiveMAE( 96 | 'tonnage', 97 | 'Vessel-tonnage', 98 | XOrNan('tonnage'), 99 | metrics=metrics, 100 | loss_weight=0.1), 101 | LogRegressionObjectiveMAE( 102 | 'engine_power', 103 | 'Vessel-engine-Power', 104 | XOrNan('engine_power'), 105 | metrics=metrics, 106 | loss_weight=0.1), 107 | LogRegressionObjectiveMAE( 108 | 'crew_size', 109 | 'Vessel-Crew-Size', 110 | XOrNan('crew_size'), 111 | metrics=metrics, 112 | loss_weight=0.1), 113 | MultiClassificationObjective( 114 | "Multiclass", "Vessel-class", vessel_metadata, metrics=metrics, loss_weight=1) 115 | ] 116 | 117 | self.objective_map = {obj.name : obj for obj in self.training_objectives} 118 | 119 | def _build_net(self, features, timestamps, ids, is_training): 120 | outputs, _ = layers.misconception_model( 121 | features, 122 | filters_list=self.feature_depths, 123 | kernel_size=self.window_size, 124 | strides_list=self.strides, 125 | objective_functions=self.training_objectives, 126 | training=is_training, 127 | sub_filters=self.feature_sub_depths, 128 | sub_layers=2, 129 | # feature_means=approx_means, 130 | # feature_stds=approx_stds 131 | ) 132 | return outputs 133 | 134 | def make_model_fn(self): 135 | def _model_fn(features, labels, mode, params): 136 | is_train = (mode == tf.estimator.ModeKeys.TRAIN) 137 | ids = features['id'] 138 | time_ranges = features['time_ranges'] 139 | timestamps = features['timestamps'] 140 | features = features['features'] 141 | self._build_net(features, timestamps, ids, is_train) 142 | 143 | if mode == tf.estimator.ModeKeys.PREDICT: 144 | predictions = { 145 | "id" : ids, 146 | "time_ranges" : time_ranges, 147 | "timestamps" : timestamps 148 | } 149 | for obj in self.training_objectives: 150 | predictions[obj.name] = obj.prediction 151 | return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) 152 | 153 | global_step = tf.train.get_global_step() 154 | 155 | total_loss = 0 156 | for obj in self. 
training_objectives: 157 | total_loss += obj.create_loss(labels[obj.name]) 158 | 159 | learning_rate = tf.train.exponential_decay( 160 | self.initial_learning_rate, global_step, 161 | self.decay_examples, self.learning_decay_rate) 162 | 163 | if mode == tf.estimator.ModeKeys.TRAIN: 164 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 165 | 166 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 167 | with tf.control_dependencies(update_ops): 168 | train_op = optimizer.minimize(loss=total_loss, 169 | global_step=global_step) 170 | 171 | return tf.estimator.EstimatorSpec( 172 | mode=mode, loss=total_loss, train_op=train_op) 173 | 174 | assert mode == tf.estimator.ModeKeys.EVAL 175 | 176 | eval_metrics = {} 177 | for obj in self.training_objectives: 178 | eval_metrics.update(obj.create_metrics(labels[obj.name])) 179 | 180 | return tf.estimator.EstimatorSpec( 181 | mode=mode, 182 | loss=total_loss, 183 | eval_metric_ops=eval_metrics) 184 | return _model_fn 185 | 186 | def make_estimator(self, checkpoint_dir): 187 | session_config = tf.ConfigProto(allow_soft_placement=True) 188 | return tf.estimator.Estimator( 189 | config=tf.estimator.RunConfig( 190 | model_dir=checkpoint_dir, 191 | save_summary_steps=20, 192 | save_checkpoints_secs=300, 193 | keep_checkpoint_max=10, 194 | session_config=session_config), 195 | model_fn=self.make_model_fn(), 196 | params={ 197 | }) 198 | 199 | def make_input_fn(self, base_feature_path, split, parallelism, prefetch): 200 | def input_fn(): 201 | return (vessel_feature_generation.input_fn( 202 | self.vessel_metadata, 203 | self.build_training_file_list(base_feature_path, split), 204 | self.num_feature_dimensions + 1, 205 | self.max_window_duration_seconds, 206 | self.window_max_points, 207 | self.min_viable_timeslice_length, 208 | objectives=self.training_objectives, 209 | parallelism=parallelism) 210 | .prefetch(prefetch) 211 | .shuffle(prefetch) 212 | .batch(self.batch_size) 213 | ) 214 | return input_fn 215 | 216 | def make_training_input_fn(self, base_feature_path, parallelism, prefetch=1024): 217 | return self.make_input_fn(base_feature_path, metadata.TRAINING_SPLIT, parallelism, prefetch) 218 | 219 | def make_test_input_fn(self, base_feature_path, parallelism, prefetch=1024): 220 | return self.make_input_fn(base_feature_path, metadata.TEST_SPLIT, parallelism, prefetch) 221 | 222 | def make_prediction_input_fn(self, paths, range_info, parallelism): 223 | time_ranges = range_info 224 | def input_fn(): 225 | return vessel_feature_generation.predict_input_fn( 226 | paths, 227 | self.num_feature_dimensions + 1, 228 | time_ranges, 229 | self.window_max_points, 230 | self.min_viable_timeslice_length, 231 | parallelism=parallelism 232 | ).batch(1) 233 | return input_fn 234 | 235 | -------------------------------------------------------------------------------- /classification/models/vessel_characterization_shakex2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import, division 16 | import argparse 17 | import json 18 | from .model import ModelBase 19 | from . import layers_shakex2 20 | from classification import metadata 21 | from .objectives import ( 22 | TrainNetInfo, MultiClassificationObjective, LogRegressionObjectiveMAE) 23 | from classification.feature_generation import vessel_feature_generation 24 | import logging 25 | import math 26 | import numpy as np 27 | import os 28 | 29 | import tensorflow as tf 30 | 31 | # These are the approximate means and standard deviations of the features 32 | # generated by looking at a subset of features generated by cook_features 33 | # operating on the real training data. 34 | # approx_means = [ 6.9139094e+00, 4.3506036e+00, 6.1346573e-01, 5.7922113e-01, 35 | # -1.0719098e-03, -3.7928432e-02, 6.0396111e-03, -2.2346964e-01, 36 | # -1.9947174e-01, -8.9961719e-03, 1.5873306e+00, 0.0000000e+00, 37 | # 0.0000000e+00, 0.0000000e+00] 38 | # approx_stds = [0.8707472 , 3.2674177 , 0.7677098 , 0.80428886, 0.3563173 , 39 | # 0.7140226 , 0.6984951 , 0.68056667, 0.6656215 , 0.54217416, 40 | # 1.8521472 , 0. , 0. , 0. ] 41 | 42 | class Model(ModelBase): 43 | 44 | window_size = 3 45 | feature_depths = [48, 64, 96, 128, 192, 256, 384, 512, 768] 46 | strides = [2] * 9 47 | assert len(strides) == len(feature_depths) 48 | feature_sub_depths = 1024 49 | 50 | initial_learning_rate = 100e-5 51 | learning_decay_rate = 0.5 52 | decay_examples = 100000 53 | 54 | @property 55 | def number_of_steps(self): 56 | return 800000 57 | 58 | @property 59 | def max_window_duration_seconds(self): 60 | return 180 * 24 * 3600 61 | 62 | @property 63 | def window_max_points(self): 64 | nominal_max_points = (self.max_window_duration_seconds / (5 * 60)) / 4 65 | layer_reductions = np.prod(self.strides) 66 | final_size = int(round(nominal_max_points / layer_reductions)) 67 | max_points = final_size * layer_reductions 68 | logging.info('Using %s points', max_points) 69 | return max_points 70 | 71 | @property 72 | def min_viable_timeslice_length(self): 73 | return 500 74 | 75 | def __init__(self, num_feature_dimensions, vessel_metadata, metrics): 76 | super(Model, self).__init__(num_feature_dimensions, vessel_metadata) 77 | 78 | class XOrNan: 79 | def __init__(self, key): 80 | self.key = key 81 | 82 | def __call__(self, id_): 83 | x = vessel_metadata.vessel_label(self.key, id_) 84 | if x == '': 85 | x = np.nan 86 | return np.float32(x) 87 | 88 | self.training_objectives = [ 89 | LogRegressionObjectiveMAE( 90 | 'length', 91 | 'Vessel-length', 92 | XOrNan('length'), 93 | metrics=metrics, 94 | loss_weight=0.1), 95 | LogRegressionObjectiveMAE( 96 | 'tonnage', 97 | 'Vessel-tonnage', 98 | XOrNan('tonnage'), 99 | metrics=metrics, 100 | loss_weight=0.1), 101 | LogRegressionObjectiveMAE( 102 | 'engine_power', 103 | 'Vessel-engine-Power', 104 | XOrNan('engine_power'), 105 | metrics=metrics, 106 | loss_weight=0.1), 107 | LogRegressionObjectiveMAE( 108 | 'crew_size', 109 | 'Vessel-Crew-Size', 110 | XOrNan('crew_size'), 111 | metrics=metrics, 112 | loss_weight=0.1), 113 | MultiClassificationObjective( 114 | "Multiclass", "Vessel-class", vessel_metadata, metrics=metrics, loss_weight=1) 115 | ] 116 | 117 | self.objective_map = {obj.name : obj for obj in self.training_objectives} 118 | 119 | def _build_net(self, features, timestamps, ids, is_training): 120 | outputs = layers_shakex2.shake2_model( 121 | features, 
122 | filters_list=self.feature_depths, 123 | kernel_size=self.window_size, 124 | strides_list=self.strides, 125 | objective_functions=self.training_objectives, 126 | training=is_training, 127 | sub_filters=self.feature_sub_depths, 128 | sub_layers=2, 129 | # feature_means=approx_means, 130 | # feature_stds=approx_stds 131 | ) 132 | return outputs 133 | 134 | def make_model_fn(self): 135 | def _model_fn(features, labels, mode, params): 136 | is_train = (mode == tf.estimator.ModeKeys.TRAIN) 137 | ids = features['id'] 138 | time_ranges = features['time_ranges'] 139 | timestamps = features['timestamps'] 140 | features = features['features'] 141 | self._build_net(features, timestamps, ids, is_train) 142 | 143 | if mode == tf.estimator.ModeKeys.PREDICT: 144 | predictions = { 145 | "id" : ids, 146 | "time_ranges" : time_ranges, 147 | "timestamps" : timestamps 148 | } 149 | for obj in self.training_objectives: 150 | predictions[obj.name] = obj.prediction 151 | return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) 152 | 153 | global_step = tf.train.get_global_step() 154 | 155 | total_loss = 0 156 | for obj in self. training_objectives: 157 | total_loss += obj.create_loss(labels[obj.name]) 158 | 159 | learning_rate = tf.train.exponential_decay( 160 | self.initial_learning_rate, global_step, 161 | self.decay_examples, self.learning_decay_rate) 162 | 163 | if mode == tf.estimator.ModeKeys.TRAIN: 164 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 165 | 166 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 167 | with tf.control_dependencies(update_ops): 168 | train_op = optimizer.minimize(loss=total_loss, 169 | global_step=global_step) 170 | 171 | return tf.estimator.EstimatorSpec( 172 | mode=mode, loss=total_loss, train_op=train_op) 173 | 174 | assert mode == tf.estimator.ModeKeys.EVAL 175 | 176 | eval_metrics = {} 177 | for obj in self.training_objectives: 178 | eval_metrics.update(obj.create_metrics(labels[obj.name])) 179 | 180 | return tf.estimator.EstimatorSpec( 181 | mode=mode, 182 | loss=total_loss, 183 | eval_metric_ops=eval_metrics) 184 | return _model_fn 185 | 186 | def make_estimator(self, checkpoint_dir): 187 | session_config = tf.ConfigProto(allow_soft_placement=True) 188 | return tf.estimator.Estimator( 189 | config=tf.estimator.RunConfig( 190 | model_dir=checkpoint_dir, 191 | save_summary_steps=20, 192 | save_checkpoints_secs=300, 193 | keep_checkpoint_max=10, 194 | session_config=session_config), 195 | model_fn=self.make_model_fn(), 196 | params={ 197 | }) 198 | 199 | def make_input_fn(self, base_feature_path, split, parallelism, prefetch): 200 | def input_fn(): 201 | return (vessel_feature_generation.input_fn( 202 | self.vessel_metadata, 203 | self.build_training_file_list(base_feature_path, split), 204 | self.num_feature_dimensions + 1, 205 | self.max_window_duration_seconds, 206 | self.window_max_points, 207 | self.min_viable_timeslice_length, 208 | objectives=self.training_objectives, 209 | parallelism=parallelism) 210 | .prefetch(prefetch) 211 | .shuffle(prefetch) 212 | .batch(self.batch_size) 213 | ) 214 | return input_fn 215 | 216 | def make_training_input_fn(self, base_feature_path, parallelism, prefetch=1024): 217 | return self.make_input_fn(base_feature_path, metadata.TRAINING_SPLIT, parallelism, prefetch) 218 | 219 | def make_test_input_fn(self, base_feature_path, parallelism, prefetch=1024): 220 | return self.make_input_fn(base_feature_path, metadata.TEST_SPLIT, parallelism, prefetch) 221 | 222 | def make_prediction_input_fn(self, 
paths, range_info, parallelism): 223 | time_ranges = range_info 224 | def input_fn(): 225 | return vessel_feature_generation.predict_input_fn( 226 | paths, 227 | self.num_feature_dimensions + 1, 228 | time_ranges, 229 | self.window_max_points, 230 | self.min_viable_timeslice_length, 231 | parallelism=parallelism 232 | ).batch(1) 233 | return input_fn 234 | 235 | -------------------------------------------------------------------------------- /classification/models/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import tensorflow as tf 16 | import tensorflow.layers as ly 17 | import numpy as np 18 | 19 | 20 | def zero_pad_features(features, depth): 21 | """ Zero-pad features in the depth dimension to match requested feature depth. """ 22 | 23 | n = int(features.get_shape().dims[-1]) 24 | extra_feature_count = depth - n 25 | assert n >= 0 26 | if n > 0: 27 | padding = tf.tile(features[:, :, :1] * 0, 28 | [1, 1, extra_feature_count]) 29 | features = tf.concat([features, padding], 2) 30 | return features 31 | 32 | def repeat_tensor(input, n): 33 | batch_size, width, depth = input.get_shape() 34 | repeated = tf.concat([input] * n, 2) 35 | return tf.reshape(repeated, [-1, int(width) * n, int(depth)]) 36 | 37 | 38 | def misconception_layer(inputs, 39 | filters, 40 | kernel_size, 41 | strides, 42 | training, 43 | scope=None, 44 | virtual_batch_size=None): 45 | """ A single layer of the misconception convolutional network. 46 | 47 | Args: 48 | input: a tensor of size [batch_size, 1, width, depth] 49 | window_size: the width of the conv and pooling filters to apply. 50 | stride: the downsampling to apply when filtering. 51 | depth: the depth of the output tensor. 52 | 53 | Returns: 54 | a tensor of size [batch_size, 1, width/stride, depth]. 
55 | """ 56 | with tf.name_scope(scope): 57 | extra = kernel_size - strides 58 | p0 = extra // 2 59 | p1 = extra - p0 60 | padded = tf.pad(inputs, [[0, 0], [p0, p1], [0, 0]]) 61 | stage_conv = ly.conv1d( 62 | padded, filters, kernel_size, strides=strides, padding="valid", activation=None, use_bias=False) 63 | stage_conv = ly.batch_normalization(stage_conv, training=training, virtual_batch_size=virtual_batch_size) 64 | stage_conv = tf.nn.relu(stage_conv) 65 | stage_max_pool_reduce = tf.layers.max_pooling1d( 66 | padded, kernel_size, strides=strides, padding="valid") 67 | concat = tf.concat([stage_conv, stage_max_pool_reduce], 2) 68 | 69 | total = ly.conv1d(concat, filters, 1, activation=None, use_bias=False) 70 | total = ly.batch_normalization(total, training=training, virtual_batch_size=virtual_batch_size) 71 | total = tf.nn.relu(total) 72 | return total 73 | 74 | 75 | 76 | def misconception_with_bypass(inputs, 77 | filters, 78 | kernel_size, 79 | strides, 80 | training, 81 | scope=None, 82 | virtual_batch_size=None): 83 | with tf.name_scope(scope): 84 | residual = misconception_layer(inputs, filters, kernel_size, strides, training, scope, virtual_batch_size) 85 | if strides > 1: 86 | inputs = tf.layers.max_pooling1d( 87 | inputs, strides, strides=strides, padding="valid") 88 | inputs = zero_pad_features(inputs, filters) 89 | return inputs + residual 90 | 91 | 92 | def misconception_model(inputs, 93 | filters_list, 94 | kernel_size, 95 | strides_list, 96 | training, 97 | objective_functions, 98 | sub_filters=128, 99 | sub_layers=2, 100 | dropout_rate=0.5, 101 | virtual_batch_size=None, 102 | feature_means=None, 103 | feature_stds=None): 104 | """ A misconception tower. 105 | 106 | Args: 107 | input: a tensor of size [batch_size, 1, width, depth]. 108 | window_size: the width of the conv and pooling filters to apply. 109 | depth: the depth of the output tensor. 110 | levels: the height of the tower in misconception layers. 111 | objective_functions: a list of objective functions to add to the top of 112 | the network. 113 | is_training: whether the network is training. 114 | 115 | Returns: 116 | a tensor of size [batch_size, num_classes]. 117 | """ 118 | layers = [] 119 | net = inputs 120 | if feature_means is not None: 121 | net = net - tf.constant(feature_means)[None, None, :] 122 | if feature_stds is not None: 123 | net = net / (tf.constant(feature_stds) + 1e-6) 124 | layers.append(net) 125 | for filters, strides in zip(filters_list, strides_list): 126 | net = misconception_with_bypass(net, filters, kernel_size, strides, training, virtual_batch_size=virtual_batch_size) 127 | layers.append(net) 128 | outputs = [] 129 | for ofunc in objective_functions: 130 | onet = net 131 | for _ in range(sub_layers - 1): 132 | onet = ly.conv1d(onet, sub_filters, 1, activation=None, use_bias=False) 133 | onet = ly.batch_normalization(onet, training=training, virtual_batch_size=virtual_batch_size) 134 | onet = tf.nn.relu(onet) 135 | onet = ly.conv1d(onet, sub_filters, 1, activation=tf.nn.relu) 136 | onet = ly.flatten(onet) 137 | # 138 | onet = ly.dropout(onet, training=training, rate=dropout_rate) 139 | outputs.append(ofunc.build(onet)) 140 | 141 | return outputs, layers 142 | 143 | 144 | def misconception_model_2(inputs, 145 | filters_list, 146 | kernel_size, 147 | strides_list, 148 | training, 149 | objective_functions, 150 | sub_filters=128, 151 | sub_layers=2, 152 | dropout_rate=0.5): 153 | """ A misconception tower. 154 | 155 | Args: 156 | input: a tensor of size [batch_size, 1, width, depth]. 
157 | window_size: the width of the conv and pooling filters to apply. 158 | depth: the depth of the output tensor. 159 | levels: the height of the tower in misconception layers. 160 | objective_functions: a list of objective functions to add to the top of 161 | the network. 162 | is_training: whether the network is training. 163 | 164 | Returns: 165 | a tensor of size [batch_size, num_classes]. 166 | """ 167 | layers = [] 168 | net = inputs 169 | layers.append(net) 170 | for filters, strides in zip(filters_list, strides_list): 171 | net = misconception_with_bypass(net, filters, kernel_size, strides, training) 172 | layers.append(net) 173 | onet = net 174 | for _ in range(sub_layers - 1): 175 | onet = ly.conv1d(onet, sub_filters, 1, activation=None, use_bias=False) 176 | onet = ly.batch_normalization(onet, training=training) 177 | onet = tf.nn.relu(onet) 178 | onet = ly.conv1d(onet, sub_filters, 1, activation=tf.nn.relu) 179 | snet = ly.conv1d(onet, 1, 1, activation=tf.nn.relu)[:, :, 0] 180 | selector = tf.expand_dims(tf.nn.softmax(snet), 2) 181 | 182 | outputs = [] 183 | for ofunc in objective_functions: 184 | onet = net 185 | for _ in range(sub_layers - 1): 186 | onet = ly.conv1d(onet, sub_filters, 1, activation=None, use_bias=False) 187 | onet = ly.batch_normalization(onet, training=training) 188 | onet = tf.nn.relu(onet) 189 | onet = ly.conv1d(onet, sub_filters, 1, activation=tf.nn.relu) 190 | 191 | onet = onet * selector 192 | n = int(onet.get_shape().dims[1]) 193 | onet = ly.average_pooling1d(onet, n, n) 194 | onet = ly.flatten(onet) 195 | # 196 | onet = ly.dropout(onet, training=training, rate=dropout_rate) 197 | outputs.append(ofunc.build(onet)) 198 | 199 | return outputs, layers 200 | 201 | 202 | def misconception_fishing(inputs, 203 | filters_list, 204 | kernel_size, 205 | strides_list, 206 | objective_function, 207 | training, 208 | pre_filters=128, 209 | post_filters=128, 210 | post_layers=1, 211 | dropout_rate=0.5, 212 | internal_dropout_rate=0.5, 213 | other_objectives=(), 214 | feature_means=None, 215 | feature_stds=None): 216 | 217 | _, layers = misconception_model( 218 | inputs, 219 | filters_list, 220 | kernel_size, 221 | strides_list, 222 | training, 223 | other_objectives, 224 | sub_filters=post_filters, 225 | sub_layers=2, 226 | dropout_rate=internal_dropout_rate, 227 | feature_means=feature_means, 228 | feature_stds=feature_stds 229 | ) 230 | 231 | expanded_layers = [] 232 | for i, lyr in enumerate(layers): 233 | lyr = ly.conv1d(lyr, pre_filters, 1, activation=None) 234 | lyr = ly.batch_normalization(lyr, training=training) 235 | lyr = tf.nn.relu(lyr) 236 | expanded_layers.append(repeat_tensor(lyr, 2**i)) 237 | 238 | embedding = tf.add_n(expanded_layers) 239 | 240 | for _ in range(post_layers - 1): 241 | embedding = ly.conv1d(embedding, post_filters, 1, activation=None, use_bias=False) 242 | embedding = ly.batch_normalization(embedding, training=training) 243 | embedding = tf.nn.relu(embedding) 244 | 245 | embedding = ly.conv1d(embedding, post_filters, 1, activation=tf.nn.relu) 246 | embedding = ly.dropout(embedding, training=training, rate=dropout_rate) 247 | 248 | fishing_outputs = ly.conv1d(embedding, 1, 1, activation=None) 249 | 250 | return objective_function.build(fishing_outputs) 251 | 252 | 253 | 254 | -------------------------------------------------------------------------------- /train/training_log.yaml: -------------------------------------------------------------------------------- 1 | - date: 2019-5-27 2 | notes: | 3 | Redo everything with features 
sharded with only 100 vessels needed. Also 4 | change to weighting to more closely match old runs. 5 | commit: d549bd2d143e82c0cf710c0f7aa6adbac6b37f41 6 | data_creation_command: | 7 | python -m train.create_train_info \ 8 | --vessel-database vessel_database.all_vessels_20190102 \ 9 | --fishing-table machine_learning_production.fishing_ranges_by_mmsi_v20190506 \ 10 | --id-type vessel-id \ 11 | --id-list gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190523/ids/part-00000-of-00001.txt \ 12 | --dataset pipe_production_b \ 13 | --charinfo-file classification/data/char_info_v20190527.csv \ 14 | --detinfo-file classification/data/det_info_v20190527.csv \ 15 | --detranges-file classification/data/det_ranges_v20190527.csv \ 16 | --charinfo-table machine_learning_dev_ttl_120d.char_info_v20190527 \ 17 | --detinfo-table machine_learning_dev_ttl_120d.det_info_v20190527 \ 18 | --detranges-table machine_learning_dev_ttl_120d.det_ranges_v20190527 19 | detection_training_commands: 20 | - | 21 | python -m train.deploy_cloudml \ 22 | --env dev \ 23 | --model_name fishing_detection \ 24 | --job_name pv3_0527_0_lin100 \ 25 | --config train/deploy_v.yaml \ 26 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190523/features \ 27 | --vessel_info det_info_v20190527.csv \ 28 | --fishing_ranges det_ranges_v20190527.csv 29 | - | 30 | python -m train.deploy_cloudml \ 31 | --env dev \ 32 | --model_name fishing_detection \ 33 | --job_name pv3_0527_1_lin100 \ 34 | --config train/deploy_v.yaml \ 35 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190523/features \ 36 | --vessel_info det_info_v20190527.csv \ 37 | --fishing_ranges det_ranges_v20190527.csv \ 38 | --split 1 39 | - | 40 | python -m train.deploy_cloudml \ 41 | --env dev \ 42 | --model_name fishing_detection \ 43 | --job_name pv3_0527_all_lin100 \ 44 | --config train/deploy_v.yaml \ 45 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190523/features \ 46 | --vessel_info det_info_v20190527.csv \ 47 | --fishing_ranges det_ranges_v20190527.csv \ 48 | --split -1 49 | detection_inference_commands: 50 | - | 51 | python -m pipe_features.fishing_inference \ 52 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190523/features \ 53 | --checkpoint_path gs://machine-learning-dev-ttl-120d/data-production/classification/timothyhochberg/pv3_0527_0_lin100/models/fishing_detection \ 54 | --results_table machine_learning_dev_ttl_120d.fishing_detection_fishing_inference_v20190527_0_lin100_ \ 55 | --start_date 2012-01-01 \ 56 | --end_date 2018-12-31 \ 57 | --feature_dimensions 14 \ 58 | --temp_location=gs://machine-learning-dev-ttl-30d/scratch/nnet-char \ 59 | --runner DataflowRunner \ 60 | --project=world-fishing-827 \ 61 | --job_name=fishing-test \ 62 | --max_num_workers 100 \ 63 | --requirements_file=./requirements.txt \ 64 | --setup_file=./setup.py \ 65 | --worker_machine_type=custom-1-13312-ext \ 66 | --id_field_name vessel_id 67 | - | 68 | python -m pipe_features.fishing_inference \ 69 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190523/features \ 70 | --checkpoint_path gs://machine-learning-dev-ttl-120d/data-production/classification/timothyhochberg/pv3_0527_1_lin100/models/fishing_detection \ 71 | --results_table machine_learning_dev_ttl_120d.fishing_detection_fishing_inference_v20190527_1_lin100_ \ 72 | --start_date 2012-01-01 \ 73 | --end_date 2018-12-31 \ 74 | --feature_dimensions 14 \ 75 | 
--temp_location=gs://machine-learning-dev-ttl-30d/scratch/nnet-char \ 76 | --runner DataflowRunner \ 77 | --project=world-fishing-827 \ 78 | --job_name=fishing-test-1 \ 79 | --max_num_workers 100 \ 80 | --requirements_file=./requirements.txt \ 81 | --setup_file=./setup.py \ 82 | --worker_machine_type=custom-1-13312-ext \ 83 | --id_field_name vessel_id 84 | - | 85 | python -m pipe_features.fishing_inference \ 86 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190523/features \ 87 | --checkpoint_path gs://machine-learning-dev-ttl-120d/data-production/classification/timothyhochberg/pv3_0527_all_lin100/models/fishing_detection \ 88 | --results_table machine_learning_dev_ttl_120d.fishing_detection_fishing_inference_v20190527_all_lin100_ \ 89 | --start_date 2012-01-01 \ 90 | --end_date 2018-12-31 \ 91 | --feature_dimensions 14 \ 92 | --temp_location=gs://machine-learning-dev-ttl-30d/scratch/nnet-char \ 93 | --runner DataflowRunner \ 94 | --project=world-fishing-827 \ 95 | --job_name=fishing-test-all \ 96 | --max_num_workers 100 \ 97 | --requirements_file=./requirements.txt \ 98 | --setup_file=./setup.py \ 99 | --worker_machine_type=custom-1-13312-ext \ 100 | --id_field_name vessel_id 101 | detection_metrics_commands: 102 | - | 103 | python -m classification.metrics.compute_fishing_metrics \ 104 | --inference-table machine_learning_dev_ttl_120d.fishing_detection_fishing_inference_v20190527_0_lin100_ \ 105 | --dest-path ./test_fishing_inference_0527_0_lin100.html \ 106 | --label-path classification/data/det_info_v20190527.csv \ 107 | --fishing-ranges classification/data/det_ranges_v20190527.csv \ 108 | --split 0 109 | - | 110 | python -m classification.metrics.compute_fishing_metrics \ 111 | --inference-table machine_learning_dev_ttl_120d.fishing_detection_fishing_inference_v20190527_1_lin100_ \ 112 | --dest-path ./test_fishing_inference_0527_1_lin100.html \ 113 | --label-path classification/data/det_info_v20190527.csv \ 114 | --fishing-ranges classification/data/det_ranges_v20190527.csv \ 115 | --split 1 116 | 117 | 118 | 119 | - date: 2019-5-20 120 | notes: | 121 | Previous run was missing fishing vessels from training set 122 | because `create_train_info.py` was too picky. 
123 | commit: 9bec5ef08fb1eea032ebb30bee8e86cd818289f5 124 | data_creation_command: | 125 | python -m train.create_train_info \ 126 | --vessel-database vessel_database.all_vessels_20190102 \ 127 | --fishing-table machine_learning_production.fishing_ranges_by_mmsi_v20190506 \ 128 | --id-type vessel-id \ 129 | --id-list gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190503b/ids/part-00000-of-00001.txt \ 130 | --dataset pipe_production_b \ 131 | --charinfo-file classification/data/char_info_v20190520.csv \ 132 | --detinfo-file classification/data/det_info_v20190520.csv \ 133 | --detranges-file classification/data/det_ranges_v20190520.csv 134 | characterization_training_command: ~ 135 | characterization_inference_command: ~ 136 | characterization_metrics_command: ~ 137 | detection_training_command: | 138 | python -m train.deploy_cloudml \ 139 | --env dev \ 140 | --model_name fishing_detection \ 141 | --job_name pv3_0520 \ 142 | --config train/deploy_v.yaml \ 143 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190515/features \ 144 | --vessel_info det_info_v20190520.csv \ 145 | --fishing_ranges det_ranges_v20190520.csv 146 | detection_inference_command: | 147 | python -m pipe_features.fishing_inference \ 148 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190515/features \ 149 | --checkpoint_path gs://machine-learning-dev-ttl-120d/data-production/classification/timothyhochberg/pv3_0520/models/fishing_detection \ 150 | --results_table machine_learning_dev_ttl_120d.fishing_detection_fishing_inference_v20190521B_ \ 151 | --start_date 2012-01-01 \ 152 | --end_date 2018-12-31 \ 153 | --feature_dimensions 14 \ 154 | --temp_location=gs://machine-learning-dev-ttl-30d/scratch/nnet-char \ 155 | --runner DataflowRunner \ 156 | --project=world-fishing-827 \ 157 | --job_name=fishing-test \ 158 | --max_num_workers 100 \ 159 | --requirements_file=./requirements.txt \ 160 | --setup_file=./setup.py \ 161 | --worker_machine_type=custom-1-13312-ext \ 162 | --id_field_name vessel_id 163 | detection_metrics_command: | 164 | python -m classification.metrics.compute_fishing_metrics \ 165 | --inference-table machine_learning_dev_ttl_120d.fishing_detection_fishing_inference_v20190521_ \ 166 | --dest-path ./test_fishing_inference_0520.html \ 167 | --label-path classification/data/det_info_v20190520.csv \ 168 | --fishing-ranges classification/data/det_ranges_v20190520.csv 169 | 170 | 171 | - date: 2019-5-16 172 | notes: Previous run didn't have full date range sharded 173 | commit: e9a1023d032b7d1d7dd7aade6558a5ee2ebcf7a1 174 | data_creation_command: | 175 | python -m train.create_train_info \ 176 | --vessel-database vessel_database.all_vessels_20190102 \ 177 | --fishing-table machine_learning_production.fishing_ranges_by_mmsi_v20190506 \ 178 | --id-type vessel-id \ 179 | --dataset pipe_production_b \ 180 | --id-list gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190515/ids/part-00000-of-00001.txt \ 181 | --charinfo-file classification/data/char_info_v20190516.csv \ 182 | --detinfo-file classification/data/det_info_v20190516.csv \ 183 | --detranges-file classification/data/det_ranges_v20190516.csv \ 184 | --charinfo-table machine_learning_dev_ttl_120d.char_info_v20190516 \ 185 | --detinfo-table machine_learning_dev_ttl_120d.det_info_v20190516 \ 186 | --detranges-table machine_learning_dev_ttl_120d.det_ranges_v20190516 187 | characterization_training_command: | 188 | python -m train.deploy_cloudml \ 189 | --env dev \ 190 | 
--model_name vessel_characterization \ 191 | --job_name pv3_0516 \ 192 | --config train/deploy_v.yaml \ 193 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190515/features \ 194 | --vessel_info char_info_v20190516.csv 195 | characterization_inference_command: | 196 | python -m pipe_features.vessel_inference \ 197 | --feature_path gs://machine-learning-dev-ttl-120d/features/v3_vid_features_v20190503b/features \ 198 | --checkpoint_path gs://machine-learning-dev-ttl-120d/data-production/classification/timothyhochberg/pv3_0516/models/vessel_characterization \ 199 | --results_table machine_learning_dev_ttl_120d.vessel_char_vid_features_v20190520 \ 200 | --start_date 2012-01-01 \ 201 | --end_date 2018-12-31 \ 202 | --feature_dimensions 14 \ 203 | --temp_location=gs://machine-learning-dev-ttl-30d/scratch/nnet-char \ 204 | --runner DataflowRunner \ 205 | --project=world-fishing-827 \ 206 | --job_name=vessel-test \ 207 | --max_num_workers 50 \ 208 | --requirements_file=./requirements.txt \ 209 | --setup_file=./setup.py \ 210 | --worker_machine_type=custom-1-13312-ext \ 211 | --id_field_name vessel_id 212 | characterization_metrics_command: 213 | python -m classification.metrics.compute_vessel_metrics \ 214 | --inference-table machine_learning_dev_ttl_120d.vessel_char_vid_features_v20190520 \ 215 | --label-table machine_learning_dev_ttl_120d.char_info_v20190516 \ 216 | --dest-path ./untracked/metric_results/test_inference_metrics_0516.html 217 | detection_training_command: ~ 218 | detection_inference_command: ~ 219 | detection_metrics_command: ~ 220 | 221 | 222 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/classification/metrics/ydump.py:
--------------------------------------------------------------------------------
1 | css = """
2 | 
3 | table {
4 |   text-align: center;
5 |   border-collapse: collapse;
6 | }
7 | 
8 | .confusion-matrix th.col {
9 |   height: 140px;
10 |   white-space: nowrap;
11 | }
12 | 
13 | .confusion-matrix th.col div {
14 |   transform: translate(16px, 49px) rotate(315deg);
15 |   width: 30px;
16 | }
17 | 
18 | .confusion-matrix th.col span {
19 |   border-bottom: 1px solid #ccc;
20 |   padding: 5px 10px;
21 |   text-align: left;
22 | }
23 | 
24 | .confusion-matrix th.row {
25 |   text-align: right;
26 | }
27 | 
28 | .confusion-matrix td.diagonal {
29 |   border: 1px solid black;
30 | }
31 | 
32 | .confusion-matrix td.offdiagonal {
33 |   border: 1px dotted grey;
34 | }
35 | 
36 | .unbreakable {
37 |   page-break-inside: avoid;
38 | }
39 | 
40 | 
41 | 
42 | 
43 | """
44 | import logging
45 | import numpy as np
46 | 
47 | # basic metrics
48 | def precision_score(y_true, y_pred):
49 |     y_true = np.asarray(y_true, dtype=bool)
50 |     y_pred = np.asarray(y_pred, dtype=bool)
51 | 
52 |     true_pos = y_true & y_pred
53 |     all_pos = y_pred
54 | 
55 |     return true_pos.sum() / all_pos.sum()
56 | 
57 | 
58 | def recall_score(y_true, y_pred):
59 |     y_true = np.asarray(y_true, dtype=bool)
60 |     y_pred = np.asarray(y_pred, dtype=bool)
61 | 
62 |     true_pos = y_true & y_pred
63 |     all_true = y_true
64 | 
65 |     return true_pos.sum() / all_true.sum()
66 | 
67 | 
68 | def f1_score(y_true, y_pred):
69 |     prec = precision_score(y_true, y_pred)
70 |     recall = recall_score(y_true, y_pred)
71 | 
72 |     return 2 / (1 / prec + 1 / recall)
73 | 
74 | 
75 | def accuracy_score(y_true, y_pred, weights=None):
76 |     y_true = np.asarray(y_true)
77 |     y_pred = np.asarray(y_pred)
78 |     if weights is None:
79 |         weights = np.ones_like(y_pred).astype(float)
80 |     weights = np.asarray(weights)
81 | 
82 |     correct = (y_true == y_pred)
83 | 
84 |     return (weights * correct).sum() / weights.sum()
85 | 
86 | 
87 | def weights(labels, y_true, y_pred, max_weight=200):
88 |     y_true = np.asarray(y_true)
89 |     y_pred = np.asarray(y_pred)
90 | 
91 |     weights = np.zeros([len(y_true)])
92 |     for lbl in labels:
93 |         trues = (y_true == lbl)
94 |         if trues.sum():
95 |             wt = min(len(trues) / trues.sum(), max_weight)
96 |             weights += trues * wt
97 | 
98 |     return weights / weights.sum()
99 | 
100 | 
101 | def base_confusion_matrix(y_true, y_pred, labels):
102 |     n = len(labels)
103 |     label_map = {lbl: i for i, lbl in enumerate(labels)}
104 |     cm = np.zeros([n, n], dtype=int)
105 | 
106 |     for yt, yp in zip(y_true, y_pred):
107 |         if yt not in label_map:
108 |             logging.warn('%s not in label_map', yt)
109 |             continue
110 |         if yp not in label_map:
111 |             logging.warn('%s not in label_map', yp)
112 |             continue
113 |         cm[label_map[yt], label_map[yp]] += 1
114 | 
115 |     return cm
116 | 
117 | # Helper function formatting as HTML (using yattag)
118 | 
119 | 
120 | def ydump_confusion_matrix(doc, cm, labels, **kwargs):
121 |     """Dump a confusion matrix to HTML using yattag
122 | 
123 |     Args:
124 |         doc: yattag Doc instance
125 |         cm: ConfusionMatrix instance
126 |         labels: list of str
127 |             labels for confusion matrix
128 |     """
129 |     doc, tag, text, line = doc.ttl()
130 |     with tag('table', klass='confusion-matrix', **kwargs):
131 |         with tag('tr'):
132 |             line('th', '')
133 |             for x in labels:
134 |                 with tag('th', klass='col'):
135 |                     with tag('div'):
136 |                         line('span', x)
137 |         for i, (l, row) in enumerate(zip(labels, cm.scaled)):
138 |             with tag('tr'):
139 |                 line('th', str(l),
klass='row') 140 | for j, x in enumerate(row): 141 | if i == j: 142 | if x == -1: 143 | # No values present in this row, column 144 | color = '#FFFFFF' 145 | elif x > 0.5: 146 | cval = np.clip(int(round(512 * (x - 0.5))), 0, 255) 147 | invhexcode = '{:02x}'.format(255 - cval) 148 | color = '#{}FF00'.format(invhexcode) 149 | else: 150 | cval = np.clip(int(round(512 * x)), 0, 255) 151 | hexcode = '{:02x}'.format(cval) 152 | color = '#FF{}00'.format(hexcode) 153 | klass = 'diagonal' 154 | else: 155 | cval = np.clip(int(round(255 * x)), 0, 255) 156 | hexcode = '{:02x}'.format(cval) 157 | invhexcode = '{:02x}'.format(255 - cval) 158 | color = '#FF{}{}'.format(invhexcode, invhexcode) 159 | klass = 'offdiagonal' 160 | with tag('td', klass=klass, bgcolor=color): 161 | raw = cm.raw[i, j] 162 | with tag('font', 163 | color='#000000', 164 | title='{0:.3f}'.format(x)): 165 | text(str(raw)) 166 | 167 | 168 | def ydump_table(doc, headings, rows, **kwargs): 169 | """Dump an html table using yatag 170 | 171 | Args: 172 | doc: yatag Doc instance 173 | headings: [str] 174 | rows: [[str]] 175 | 176 | """ 177 | doc, tag, text, line = doc.ttl() 178 | with tag('table', **kwargs): 179 | with tag('tr'): 180 | for x in headings: 181 | line('th', str(x)) 182 | for row in rows: 183 | with tag('tr'): 184 | for x in row: 185 | line('td', str(x)) 186 | 187 | 188 | def ydump_attrs(doc, results): 189 | """dump metrics for `results` to html using yatag 190 | 191 | Args: 192 | doc: yatag Doc instance 193 | results: InferenceResults instance 194 | 195 | """ 196 | doc, tag, text, line = doc.ttl() 197 | 198 | def RMS(a, b): 199 | return np.sqrt(np.square(a - b).mean()) 200 | 201 | def MAE(a, b): 202 | return abs(a - b).mean() 203 | 204 | # TODO: move computations out of loops for speed. 
205 | # true_mask = np.array([(x is not None) for x in results.true_attrs]) 206 | # infer_mask = np.array([(x is not None) for x in results.inferred_attrs]) 207 | true_mask = ~np.isnan(results.true_attrs) 208 | infer_mask = ~np.isnan(results.inferred_attrs) 209 | rows = [] 210 | for dt in np.unique(results.start_dates): 211 | mask = true_mask & infer_mask & (results.start_dates == dt) 212 | rows.append( 213 | [dt, RMS(results.true_attrs[mask], results.inferred_attrs[mask]), 214 | MAE(results.true_attrs[mask], results.inferred_attrs[mask])]) 215 | 216 | with tag('div', klass='unbreakable'): 217 | line('h3', 'RMS Error by Date') 218 | ydump_table(doc, ['Start Date', 'RMS Error', 'Abs Error'], 219 | [(a.date(), '{:.2f}'.format(b), '{:.2f}'.format(c)) 220 | for (a, b, c) in rows]) 221 | 222 | logging.info(' Consolidating attributes') 223 | consolidated = consolidate_attribute_across_dates(results) 224 | # true_mask = np.array([(x is not None) for x in consolidated.true_attrs]) 225 | # infer_mask = np.array([(x is not None) for x in consolidated.inferred_attrs]) 226 | true_mask = ~np.isnan(consolidated.true_attrs) 227 | infer_mask = ~np.isnan(consolidated.inferred_attrs) 228 | 229 | logging.info(' RMS Error') 230 | with tag('div', klass='unbreakable'): 231 | line('h3', 'Overall RMS Error') 232 | text('{:.2f}'.format( 233 | RMS(consolidated.true_attrs[true_mask & infer_mask], 234 | consolidated.inferred_attrs[true_mask & infer_mask]))) 235 | 236 | logging.info(' ABS Error') 237 | with tag('div', klass='unbreakable'): 238 | line('h3', 'Overall Abs Error') 239 | text('{:.2f}'.format( 240 | MAE(consolidated.true_attrs[true_mask & infer_mask], 241 | consolidated.inferred_attrs[true_mask & infer_mask]))) 242 | 243 | def RMS_MAE_by_label(true_attrs, pred_attrs, true_labels): 244 | results = [] 245 | labels = sorted(set(true_labels)) 246 | for lbl in labels: 247 | mask = true_mask & infer_mask & (lbl == true_labels) 248 | if mask.sum(): 249 | err = RMS(true_attrs[mask], pred_attrs[mask]) 250 | abs_err = MAE(true_attrs[mask], pred_attrs[mask]) 251 | count = mask.sum() 252 | results.append( 253 | (lbl, count, err, abs_err, true_attrs[mask].mean(), 254 | true_attrs[mask].std())) 255 | return results 256 | 257 | logging.info(' Error by Label') 258 | with tag('div', klass='unbreakable'): 259 | line('h3', 'RMS Error by Label') 260 | ydump_table( 261 | doc, 262 | ['Label', 'Count', 'RMS Error', 'Abs Error', 'Mean', 'StdDev' 263 | ], # TODO: pass in length and units 264 | [ 265 | (a, count, '{:.2f}'.format(b), '{:.2f}'.format(ab), 266 | '{:.2f}'.format(c), '{:.2f}'.format(d)) 267 | for (a, count, b, ab, c, d) in RMS_MAE_by_label( 268 | consolidated.true_attrs, consolidated.inferred_attrs, 269 | consolidated.true_labels) 270 | ]) 271 | 272 | 273 | def ydump_metrics(doc, results): 274 | """dump metrics for `results` to html using yatag 275 | 276 | Args: 277 | doc: yatag Doc instance 278 | results: InferenceResults instance 279 | 280 | """ 281 | doc, tag, text, line = doc.ttl() 282 | 283 | rows = [ 284 | (x, accuracy_score(results.true_labels, results.inferred_labels, 285 | (results.start_dates == x))) 286 | for x in np.unique(results.start_dates) 287 | ] 288 | 289 | with tag('div', klass='unbreakable'): 290 | line('h3', 'Accuracy by Date') 291 | ydump_table(doc, ['Start Date', 'Accuracy'], 292 | [(a.date(), '{:.2f}'.format(b)) for (a, b) in rows]) 293 | 294 | consolidated = consolidate_across_dates(results) 295 | 296 | with tag('div', klass='unbreakable'): 297 | line('h3', 'Overall Accuracy') 298 | 
text('{:.2f}'.format( 299 | accuracy_score(consolidated.true_labels, 300 | consolidated.inferred_labels))) 301 | 302 | cm = confusion_matrix(consolidated) 303 | 304 | with tag('div', klass='unbreakable'): 305 | line('h3', 'Confusion Matrix') 306 | ydump_confusion_matrix(doc, cm, results.label_list) 307 | 308 | with tag('div', klass='unbreakable'): 309 | line('h3', 'Metrics by Label') 310 | row_vals = precision_recall_f1(consolidated.label_list, 311 | consolidated.true_labels, 312 | consolidated.inferred_labels) 313 | ydump_table(doc, ['Label (mmsi:true/total)', 'Precision', 'Recall', 'F1-Score'], [ 314 | (a, '{:.2f}'.format(b), '{:.2f}'.format(c), '{:.2f}'.format(d)) 315 | for (a, b, c, d) in row_vals 316 | ]) 317 | wts = weights(consolidated.label_list, consolidated.true_labels, 318 | consolidated.inferred_labels) 319 | line('h4', 'Accuracy with equal class weight') 320 | text( 321 | str( 322 | accuracy_score(consolidated.true_labels, 323 | consolidated.inferred_labels, wts))) 324 | 325 | fishing_category_map = { 326 | 'drifting_longlines' : 'drifting_longlines', 327 | 'trawlers' : 'trawlers', 328 | 'purse_seines' : 'purse_seines', 329 | 'pots_and_traps' : 'stationary_gear', 330 | 'set_gillnets' : 'stationary_gear', 331 | 'set_longlines' : 'stationary_gear' 332 | } 333 | 334 | 335 | def ydump_fishing_localisation(doc, results): 336 | doc, tag, text, line = doc.ttl() 337 | 338 | y_true = np.concatenate(results.true_fishing_by_mmsi.values()) 339 | y_pred = np.concatenate(results.pred_fishing_by_mmsi.values()) 340 | 341 | header = ['Gear Type (mmsi:true/total)', 'Precision', 'Recall', 'Accuracy', 'F1-Score'] 342 | rows = [] 343 | logging.info('Overall localisation accuracy %s', 344 | accuracy_score(y_true, y_pred)) 345 | logging.info('Overall localisation precision %s', 346 | precision_score(y_true, y_pred)) 347 | logging.info('Overall localisation recall %s', 348 | recall_score(y_true, y_pred)) 349 | 350 | for cls in sorted(set(fishing_category_map.values())) + ['other'] : 351 | true_chunks = [] 352 | pred_chunks = [] 353 | mmsi_list = [] 354 | for mmsi in results.label_map: 355 | if mmsi not in results.true_fishing_by_mmsi: 356 | continue 357 | if fishing_category_map.get(results.label_map[mmsi], 'other') != cls: 358 | continue 359 | mmsi_list.append(mmsi) 360 | true_chunks.append(results.true_fishing_by_mmsi[mmsi]) 361 | pred_chunks.append(results.pred_fishing_by_mmsi[mmsi]) 362 | if len(true_chunks): 363 | logging.info('MMSI for {}: {}'.format(cls, mmsi_list)) 364 | y_true = np.concatenate(true_chunks) 365 | y_pred = np.concatenate(pred_chunks) 366 | rows.append(['{} ({}:{}/{})'.format(cls, len(true_chunks), sum(y_true), len(y_true)), 367 | precision_score(y_true, y_pred), 368 | recall_score(y_true, y_pred), 369 | accuracy_score(y_true, y_pred), 370 | f1_score(y_true, y_pred), ]) 371 | 372 | rows.append(['', '', '', '', '']) 373 | 374 | y_true = np.concatenate(results.true_fishing_by_mmsi.values()) 375 | y_pred = np.concatenate(results.pred_fishing_by_mmsi.values()) 376 | 377 | rows.append(['Overall', 378 | precision_score(y_true, y_pred), 379 | recall_score(y_true, y_pred), 380 | accuracy_score(y_true, y_pred), 381 | f1_score(y_true, y_pred), ]) 382 | 383 | with tag('div', klass='unbreakable'): 384 | ydump_table( 385 | doc, header, 386 | [[('{:.2f}'.format(x) if isinstance(x, float) else x) for x in row] 387 | for row in rows]) -------------------------------------------------------------------------------- /classification/metrics/compute_fishing_metrics.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | 16 | 17 | Example: 18 | 19 | python -m classification.metrics.compute_fishing_metrics \ 20 | --inference-table machine_learning_dev_ttl_120d.fishing_detection_vid_features_v20190509_ \ 21 | --label-path classification/data/det_info_v20190507.csv \ 22 | --dest-path ./test_fishing_inference_0509.html \ 23 | --fishing-ranges classification/data/det_ranges_v20190507.csv 24 | 25 | """ 26 | from __future__ import division 27 | from __future__ import absolute_import 28 | from __future__ import print_function 29 | import os 30 | import csv 31 | import subprocess 32 | import numpy as np 33 | import pandas as pd 34 | import pandas_gbq 35 | import dateutil.parser 36 | import logging 37 | import argparse 38 | from collections import namedtuple, defaultdict 39 | import sys 40 | import yattag 41 | from classification.metadata import VESSEL_CLASS_DETAILED_NAMES, VESSEL_CATEGORIES, schema, atomic 42 | import gzip 43 | import dateutil.parser 44 | import datetime 45 | import pytz 46 | from .ydump import css, ydump_table 47 | import six 48 | 49 | 50 | 51 | coarse_categories = [ 52 | 'cargo_or_tanker', 'passenger', 'seismic_vessel', 'tug', 'other_fishing', 53 | 'drifting_longlines', 'seiners', 'fixed_gear', 'squid_jigger', 'trawlers', 54 | 'other_not_fishing'] 55 | 56 | coarse_mapping = defaultdict(set) 57 | for k0, extra in [('fishing', 'other_fishing'), 58 | ('non_fishing', 'other_not_fishing')]: 59 | for k1, v1 in schema['unknown'][k0].items(): 60 | key = k1 if (k1 in coarse_categories) else extra 61 | if v1 is None: 62 | coarse_mapping[key] |= {k1} 63 | else: 64 | coarse_mapping[key] |= set(atomic(v1)) 65 | 66 | coarse_mapping = [(k, coarse_mapping[k]) for k in coarse_categories] 67 | 68 | fishing_mapping = [ 69 | ['fishing', set(atomic(schema['unknown']['fishing']))], 70 | ['non_fishing', set(atomic(schema['unknown']['non_fishing']))], 71 | ] 72 | 73 | 74 | fishing_category_map = {} 75 | atomic_fishing = fishing_mapping[0][1] 76 | for coarse, fine in coarse_mapping: 77 | for atomic in fine: 78 | if atomic in atomic_fishing: 79 | fishing_category_map[atomic] = coarse 80 | 81 | 82 | # Faster than using dateutil 83 | def _parse(x): 84 | if isinstance(x, datetime.datetime): 85 | return x 86 | # 2014-08-28T13:56:16+00:00 87 | # TODO: fix generation to generate consistent datetimes 88 | if x[-6:] == '+00:00': 89 | x = x[:-6] 90 | if x.endswith('.999999'): 91 | x = x[:-7] 92 | if x.endswith('Z'): 93 | x = x[:-1] 94 | try: 95 | dt = datetime.datetime.strptime(x, '%Y-%m-%dT%H:%M:%S') 96 | except: 97 | logging.fatal('Could not parse "%s"', x) 98 | raise 99 | return dt.replace(tzinfo=pytz.UTC) 100 | 101 | 102 | LocalisationResults = namedtuple('LocalisationResults', 103 | ['true_fishing_by_id', 104 | 'pred_fishing_by_id', 'label_map']) 105 | 106 | FishingRange = namedtuple('FishingRange', 
107 | ['is_fishing', 'start_time', 'end_time']) 108 | 109 | 110 | def ydump_fishing_localisation(doc, results): 111 | doc, tag, text, line = doc.ttl() 112 | 113 | y_true = np.concatenate(list(results.true_fishing_by_id.values())) 114 | y_pred = np.concatenate(list(results.pred_fishing_by_id.values())) 115 | 116 | header = ['Gear Type (id:true/total)', 'Precision', 'Recall', 'Accuracy', 'F1-Score'] 117 | rows = [] 118 | logging.info('Overall localisation accuracy %s', 119 | accuracy_score(y_true, y_pred)) 120 | logging.info('Overall localisation precision %s', 121 | precision_score(y_true, y_pred)) 122 | logging.info('Overall localisation recall %s', 123 | recall_score(y_true, y_pred)) 124 | 125 | for cls in sorted(set(fishing_category_map.values())) + ['other'] : 126 | true_chunks = [] 127 | pred_chunks = [] 128 | id_list = [] 129 | for id_ in results.label_map: 130 | if id_ not in results.true_fishing_by_id: 131 | continue 132 | if fishing_category_map.get(results.label_map[id_], 'other') != cls: 133 | continue 134 | id_list.append(id_) 135 | true_chunks.append(results.true_fishing_by_id[id_]) 136 | pred_chunks.append(results.pred_fishing_by_id[id_]) 137 | if len(true_chunks): 138 | logging.info('ID for {}: {}'.format(cls, id_list)) 139 | y_true = np.concatenate(true_chunks) 140 | y_pred = np.concatenate(pred_chunks) 141 | rows.append(['{} ({}:{}/{})'.format(cls, len(true_chunks), sum(y_true), len(y_true)), 142 | precision_score(y_true, y_pred), 143 | recall_score(y_true, y_pred), 144 | accuracy_score(y_true, y_pred), 145 | f1_score(y_true, y_pred), ]) 146 | 147 | rows.append(['', '', '', '', '']) 148 | 149 | y_true = np.concatenate(list(results.true_fishing_by_id.values())) 150 | y_pred = np.concatenate(list(results.pred_fishing_by_id.values())) 151 | 152 | rows.append(['Overall', 153 | precision_score(y_true, y_pred), 154 | recall_score(y_true, y_pred), 155 | accuracy_score(y_true, y_pred), 156 | f1_score(y_true, y_pred), ]) 157 | 158 | with tag('div', klass='unbreakable'): 159 | ydump_table( 160 | doc, header, 161 | [[('{:.2f}'.format(x) if isinstance(x, float) else x) for x in row] 162 | for row in rows]) 163 | 164 | 165 | 166 | 167 | def precision_score(y_true, y_pred): 168 | y_true = np.asarray(y_true, dtype=bool) 169 | y_pred = np.asarray(y_pred, dtype=bool) 170 | 171 | true_pos = y_true & y_pred 172 | all_pos = y_pred 173 | 174 | return true_pos.sum() / all_pos.sum() 175 | 176 | 177 | def recall_score(y_true, y_pred): 178 | y_true = np.asarray(y_true, dtype=bool) 179 | y_pred = np.asarray(y_pred, dtype=bool) 180 | 181 | true_pos = y_true & y_pred 182 | all_true = y_true 183 | 184 | return true_pos.sum() / all_true.sum() 185 | 186 | 187 | def f1_score(y_true, y_pred): 188 | prec = precision_score(y_true, y_pred) 189 | recall = recall_score(y_true, y_pred) 190 | 191 | return 2 / (1 / prec + 1 / recall) 192 | 193 | def accuracy_score(y_true, y_pred, weights=None): 194 | y_true = np.asarray(y_true) 195 | y_pred = np.asarray(y_pred) 196 | if weights is None: 197 | weights = np.ones_like(y_pred).astype(float) 198 | weights = np.asarray(weights) 199 | 200 | correct = (y_true == y_pred) 201 | 202 | return (weights * correct).sum() / weights.sum() 203 | 204 | 205 | def load_inferred_fishing(table, id_list, project_id, threshold=True): 206 | """Load inferred data and generate comparison data 207 | 208 | """ 209 | query_template = """ 210 | SELECT vessel_id as id, start_time, end_time, nnet_score FROM 211 | TABLE_DATE_RANGE([{table}], 212 | TIMESTAMP('{year}-01-01'), 
TIMESTAMP('{year}-12-31')) 213 | WHERE vessel_id in ({ids}) 214 | """ 215 | ids = ','.join('"{}"'.format(x) for x in id_list) 216 | ranges = defaultdict(list) 217 | for year in range(2012, 2019): 218 | query = query_template.format(table=table, year=year, ids=ids) 219 | try: 220 | df = pd.read_gbq(query, project_id=project_id, dialect='legacy') 221 | except pandas_gbq.gbq.GenericGBQException as err: 222 | if 'matches no table' in err.args[0]: 223 | print('skipping', year) 224 | continue 225 | else: 226 | print(query) 227 | raise 228 | for x in df.itertuples(): 229 | score = x.nnet_score 230 | if threshold: 231 | score = score > 0.5 232 | start = x.start_time.replace(tzinfo=pytz.utc) 233 | end = x.end_time.replace(tzinfo=pytz.utc) 234 | ranges[x.id].append(FishingRange(score, start, end)) 235 | return ranges 236 | 237 | def load_true_fishing_ranges_by_id(fishing_range_path, 238 | split_map, 239 | split, 240 | threshold=True): 241 | ranges_by_id = defaultdict(list) 242 | parse = dateutil.parser.parse 243 | with open(fishing_range_path) as f: 244 | for row in csv.DictReader(f): 245 | id_ = row['id'].strip() 246 | if not split_map.get(id_) == str(split): 247 | continue 248 | val = float(row['is_fishing']) 249 | if threshold: 250 | val = val > 0.5 251 | rng = (val, parse(row['start_time']).replace(tzinfo=pytz.UTC), 252 | parse(row['end_time']).replace(tzinfo=pytz.UTC)) 253 | ranges_by_id[id_].append(rng) 254 | return ranges_by_id 255 | 256 | 257 | def datetime_to_minute(dt): 258 | timestamp = (dt - datetime.datetime( 259 | 1970, 1, 1, tzinfo=pytz.utc)).total_seconds() 260 | return int(timestamp // 60) 261 | 262 | 263 | def compare_fishing_localisation(inferred_ranges, fishing_range_path, 264 | label_map, split_map, split): 265 | 266 | logging.debug('loading fishing ranges') 267 | true_ranges_by_id = load_true_fishing_ranges_by_id(fishing_range_path, 268 | split_map, split) 269 | print("TRUE", sorted(true_ranges_by_id.keys())[:10]) 270 | print("INF", sorted(inferred_ranges.keys())[:10]) 271 | print(repr(sorted(true_ranges_by_id.keys())[0])) 272 | print(repr(sorted(inferred_ranges.keys())[0])) 273 | true_by_id = {} 274 | pred_by_id = {} 275 | 276 | for id_ in sorted(true_ranges_by_id.keys()): 277 | id_ = six.ensure_text(id_) 278 | logging.debug('processing %s', id_) 279 | if id_ not in inferred_ranges: 280 | continue 281 | true_ranges = true_ranges_by_id[id_] 282 | if not true_ranges: 283 | continue 284 | 285 | # Determine minutes from start to finish of this id, create an array to 286 | # hold results and fill with -1 (unknown) 287 | logging.debug('processing %s true ranges', len(true_ranges)) 288 | logging.debug('finding overall range') 289 | _, start, end = true_ranges[0] 290 | for (_, s, e) in true_ranges[1:]: 291 | start = min(start, s) 292 | end = max(end, e) 293 | start_min = datetime_to_minute(start) 294 | end_min = datetime_to_minute(end) 295 | minutes = np.empty([end_min - start_min + 1, 2], dtype=int) 296 | minutes.fill(-1) 297 | 298 | # Fill in minutes[:, 0] with known true / false values 299 | logging.debug('filling 0s') 300 | for (is_fishing, s, e) in true_ranges: 301 | s_min = datetime_to_minute(s) 302 | e_min = datetime_to_minute(e) 303 | for m in range(s_min - start_min, e_min - start_min + 1): 304 | minutes[m, 0] = is_fishing 305 | 306 | # fill in minutes[:, 1] with inferred true / false values 307 | logging.debug('filling 1s') 308 | for (is_fishing, s, e) in inferred_ranges[str(id_)]: 309 | s_min = datetime_to_minute(s) 310 | e_min = datetime_to_minute(e) 311 | for m in 
range(s_min - start_min, e_min - start_min + 1): 312 | if 0 <= m < len(minutes): 313 | minutes[m, 1] = is_fishing 314 | 315 | mask = ((minutes[:, 0] != -1) & (minutes[:, 1] != -1)) 316 | if mask.sum(): 317 | accuracy = ( 318 | (minutes[:, 0] == minutes[:, 1]) * mask).sum() / mask.sum() 319 | logging.debug('Accuracy for ID %s: %s', id_, accuracy) 320 | 321 | true_by_id[id_] = minutes[mask, 0] 322 | pred_by_id[id_] = minutes[mask, 1] 323 | 324 | return LocalisationResults(true_by_id, pred_by_id, label_map) 325 | 326 | 327 | def compute_results(args): 328 | logging.info('Loading label maps') 329 | maps = defaultdict(dict) 330 | with open(args.label_path) as f: 331 | for row in csv.DictReader(f): 332 | id_ = row['id'].strip() 333 | if not row['split'] == str(args.split): 334 | continue 335 | for field in ['label', 'split']: 336 | if row[field]: 337 | if field == 'label': 338 | if row[field].strip( 339 | ) not in VESSEL_CLASS_DETAILED_NAMES: 340 | continue 341 | maps[field][id_] = row[field] 342 | 343 | # Sanity check the attribute mappings 344 | for field in ['length', 'tonnage', 'engine_power', 'crew_size']: 345 | for id_, value in maps[field].items(): 346 | assert float(value) > 0, (id_, value) 347 | 348 | logging.info('Loading inference data') 349 | ids = set([x for x in maps['split'] if maps['split'][x] == str(args.split)]) 350 | 351 | fishing_ranges = load_inferred_fishing(args.inference_table, ids, args.project_id) 352 | logging.info('Comparing localisation') 353 | results = {} 354 | results['localisation'] = compare_fishing_localisation( 355 | fishing_ranges, args.fishing_ranges, maps['label'], 356 | maps['split'], args.split) 357 | 358 | 359 | return results 360 | 361 | 362 | def dump_html(args, results): 363 | 364 | doc = yattag.Doc() 365 | 366 | with doc.tag('style', type='text/css'): 367 | doc.asis(css) 368 | 369 | logging.info('Dumping Localisation') 370 | doc.line('h2', 'Fishing Localisation') 371 | ydump_fishing_localisation(doc, results['localisation']) 372 | doc.stag('hr') 373 | 374 | with open(args.dest_path, 'w') as f: 375 | logging.info('Writing output') 376 | f.write(yattag.indent(doc.getvalue(), indent_text=True)) 377 | 378 | 379 | """ 380 | 381 | python -m classification.metrics.compute_fishing_metrics \ 382 | --inference-table machine_learning_dev_ttl_120d.test_dataflow_2016_ \ 383 | --label-path classification/data/fishing_classes.csv \ 384 | --dest-path test_fishing.html \ 385 | --fishing-ranges classification/data/combined_fishing_ranges.csv \ 386 | 387 | 388 | """ 389 | 390 | 391 | if __name__ == '__main__': 392 | logging.getLogger().setLevel(logging.DEBUG) 393 | 394 | parser = argparse.ArgumentParser( 395 | description='Test fishing inference results and output metrics.\n') 396 | parser.add_argument( 397 | '--inference-table', help='table of inference results', required=True) 398 | parser.add_argument( 399 | '--project-id', help='Google Cloud project id', 400 | default='world-fishing-827') 401 | parser.add_argument( 402 | '--label-path', help='path to test data', required=True) 403 | parser.add_argument('--fishing-ranges', help='path to fishing range data', required=True) 404 | parser.add_argument( 405 | '--dest-path', help='path to write results to', required=True) 406 | parser.add_argument('--split', type=int, default=0) 407 | 408 | 409 | args = parser.parse_args() 410 | 411 | results = compute_results(args) 412 | 413 | dump_html(args, results) 414 | 415 | -------------------------------------------------------------------------------- 
/classification/models/objectives.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import abc 16 | import calendar 17 | from collections import namedtuple, OrderedDict 18 | import datetime 19 | import logging 20 | import numpy as np 21 | import tensorflow as tf 22 | import tensorflow.metrics as metrics 23 | from classification import metadata 24 | import pytz 25 | import six 26 | """ Terminology in the context of objectives. 27 | 28 | Net: the raw input to an objective function, an embeddeding that has not 29 | yet been shaped for the predictive task in hand. 30 | Logits: the input to a softmax classifier. 31 | Prediction: the output of an objective function, be it class probabilities 32 | from a categorical function, or a continuous output vector for 33 | a regression. 34 | """ 35 | 36 | Trainer = namedtuple("Trainer", ["loss", "update_ops"]) 37 | TrainNetInfo = namedtuple("TrainNetInfo", ["optimizer", "objective_trainers"]) 38 | 39 | EPSILON = 1e-20 40 | 41 | 42 | def f1(recall, precision): 43 | rval, rop = recall 44 | pval, pop = precision 45 | f1 = 2.0 / (1.0 / rval + 1.0 / pval) 46 | return (f1, f1) 47 | 48 | 49 | class ObjectiveBase(object): 50 | __metaclass__ = abc.ABCMeta 51 | 52 | def __init__(self, metadata_label, name, loss_weight, metrics): 53 | """ 54 | args: 55 | metadata_label: 56 | name: name of this objective (for metrics) 57 | loss_weight: weight of this objective (so that can increase decrease relative to other objectives) 58 | metrics: which metrics to include. 
Options are currently ['all', 'minimal']
59 | 
60 |         """
61 |         self.metadata_label = metadata_label
62 |         self.name = name
63 |         self.loss_weight = loss_weight
64 |         self.prediction = None
65 |         self.metrics = metrics
66 | 
67 |     @abc.abstractmethod
68 |     def build(self, net):
69 |         pass
70 | 
71 |     def create_metrics(self, labels):
72 |         raw_metrics = self.create_raw_metrics(labels)
73 |         try:
74 |             eval_metrics = {"{}/{}".format(self.name, k) : v for (k, v) in raw_metrics.items()}
75 |         except:
76 |             logging.warning("Problem creating eval_metrics in {}".format(self))
77 |             return {}
78 |         for k, v in eval_metrics.items():
79 |             tf.summary.scalar(k, v[1])
80 |         return eval_metrics
81 | 
82 | 
83 | 
84 | class RegressionObjective(ObjectiveBase):
85 |     def __init__(self,
86 |                  metadata_label,
87 |                  name,
88 |                  value_from_id,
89 |                  loss_weight=1.0,
90 |                  metrics='all'):
91 |         super(RegressionObjective, self).__init__(metadata_label, name,
92 |                                                    loss_weight, metrics)
93 |         self.value_from_id = value_from_id
94 |         self.output_shape = []
95 | 
96 |     def create_label(self, id_, timestamps):
97 |         return self.value_from_id(id_)
98 | 
99 |     def build(self, net):
100 |         self.prediction = tf.layers.dense(net, 1, activation=None)[:, 0]
101 | 
102 |     def expected_and_mask(self, labels):
103 |         mask = ~tf.is_nan(labels)
104 |         valid = tf.boolean_mask(labels, mask)
105 |         idx = tf.to_int32(tf.where(mask))
106 |         expected = tf.scatter_nd(idx, valid, tf.shape(labels))
107 |         return expected, mask
108 | 
109 |     def masked_mean_error(self, labels):
110 |         expected, mask = self.expected_and_mask(labels)
111 |         mask = tf.cast(mask, tf.float32)
112 |         count = tf.reduce_sum(mask)
113 |         diff = tf.abs((expected - self.prediction) * mask)
114 |         error = tf.reduce_sum(diff) / tf.maximum(count, EPSILON)
115 |         return error
116 | 
117 |     def create_loss(self, labels):
118 |         raw_loss = self.masked_mean_error(labels)
119 |         return raw_loss * self.loss_weight
120 | 
121 |     def create_raw_metrics(self, labels):
122 |         error = self.masked_mean_error(labels)
123 |         loss = self.create_loss(labels)
124 |         return {
125 |             'loss': tf.metrics.mean(loss), 'error': tf.metrics.mean(error)
126 |         }
127 | 
128 | 
129 | 
130 | class LogRegressionObjective(ObjectiveBase):
131 |     def __init__(self,
132 |                  metadata_label,
133 |                  name,
134 |                  value_from_id,
135 |                  loss_weight=1.0,
136 |                  metrics='all'):
137 |         super(LogRegressionObjective, self).__init__(metadata_label, name,
138 |                                                       loss_weight, metrics)
139 |         self.value_from_id = value_from_id
140 |         self.output_shape = []
141 | 
142 |     def create_label(self, id_, timestamps):
143 |         return self.value_from_id(id_)
144 | 
145 |     def build(self, net):
146 |         self.prediction = tf.layers.dense(net, 1, activation=None)[:, 0]
147 | 
148 |     def expected_and_mask(self, labels):
149 |         mask = ~tf.is_nan(labels)
150 |         valid = tf.boolean_mask(labels, mask)
151 |         idx = tf.to_int32(tf.where(mask))
152 |         expected = tf.scatter_nd(idx, valid, tf.shape(labels))
153 |         return expected, mask
154 | 
155 |     def masked_mean_loss(self, labels):
156 |         expected, mask = self.expected_and_mask(labels)
157 |         mask = tf.cast(mask, tf.float32)
158 |         count = tf.reduce_sum(mask)
159 |         squared_error = (
160 |             (tf.log(expected + EPSILON) - self.prediction)**2 * mask)
161 |         loss = tf.reduce_sum(squared_error) / tf.maximum(count, EPSILON)
162 |         return loss
163 | 
164 |     def masked_mean_error(self, labels):
165 |         expected, mask = self.expected_and_mask(labels)
166 |         mask = tf.cast(mask, tf.float32)
167 |         count = tf.reduce_sum(mask)
168 |         diff = tf.abs((expected - tf.exp(self.prediction)) * mask)
169 |         error = 
tf.reduce_sum(diff) / tf.maximum(count, EPSILON) 170 | return error 171 | 172 | def create_loss(self, labels): 173 | raw_loss = self.masked_mean_loss(labels) 174 | print(raw_loss, self.loss_weight) 175 | return raw_loss * self.loss_weight 176 | 177 | def create_raw_metrics(self, labels): 178 | loss = self.masked_mean_loss(labels) 179 | error = self.masked_mean_error(labels) 180 | return { 181 | 'loss': tf.metrics.mean(loss), 182 | 'error': tf.metrics.mean(error) 183 | } 184 | 185 | def build_json_results(self, prediction, timestamps): 186 | return {'name': self.name, 'value': np.exp(float(prediction))} 187 | 188 | 189 | 190 | class LogRegressionObjectiveMAE(LogRegressionObjective): 191 | 192 | def __init__(self, 193 | metadata_label, 194 | name, 195 | value_from_id, 196 | loss_weight=1.0, 197 | metrics='all'): 198 | super(LogRegressionObjectiveMAE, self).__init__(metadata_label, name, value_from_id, 199 | loss_weight, metrics) 200 | 201 | def create_label(self, id_, timestamps): 202 | return self.value_from_id(id_) 203 | 204 | def masked_mean_loss(self, labels): 205 | expected, mask = self.expected_and_mask(labels) 206 | mask = tf.cast(mask, tf.float32) 207 | count = tf.reduce_sum(mask) 208 | mean_absolute_error = tf.abs( 209 | (tf.log(expected + EPSILON) - self.prediction) * mask) 210 | loss = tf.reduce_sum(mean_absolute_error) / tf.maximum(count, EPSILON) 211 | return loss 212 | 213 | 214 | 215 | class MultiClassificationObjective(ObjectiveBase): 216 | def __init__(self, 217 | metadata_label, 218 | name, 219 | vessel_metadata, 220 | loss_weight=1.0, 221 | metrics='all'): 222 | super(MultiClassificationObjective, self).__init__( 223 | metadata_label, name, loss_weight, metrics) 224 | self.vessel_metadata = vessel_metadata 225 | self.classes = metadata.VESSEL_CLASS_DETAILED_NAMES 226 | self.num_classes = metadata.multihot_lookup_table.shape[-1] 227 | self.class_indices = {k[0]: i for (i, k) in enumerate(metadata.VESSEL_CATEGORIES)} 228 | self.output_shape = [self.num_classes] 229 | 230 | 231 | def build(self, net): 232 | self.logits = tf.layers.dense( 233 | net, self.num_classes, activation=None) 234 | self.prediction = tf.nn.softmax(self.logits) 235 | 236 | def create_label(self, id_, timestamps): 237 | encoded = np.zeros([self.num_classes], dtype=np.int32) 238 | lbl_str = self.vessel_metadata.vessel_label('label', id_).strip() 239 | if lbl_str: 240 | for lbl in lbl_str.split('|'): 241 | j = self.class_indices[lbl] 242 | # Use '|' rather than '+' since classes might not be disjoint 243 | encoded |= metadata.multihot_lookup_table[j] 244 | return encoded.astype(np.float32) 245 | 246 | def create_loss(self, labels): 247 | with tf.variable_scope("custom-loss"): 248 | mask = tf.to_float(tf.greater_equal(tf.reduce_sum(labels, axis=1), 1)) 249 | positives = tf.reduce_sum( 250 | tf.to_float(labels) * self.prediction, reduction_indices=[1]) 251 | raw_loss = -tf.reduce_mean(mask * tf.log(positives + EPSILON)) 252 | return raw_loss * self.loss_weight 253 | 254 | def create_raw_metrics(self, labels): 255 | mask = tf.to_float(tf.equal(tf.reduce_sum(labels, axis=1), 1)) 256 | encoded_labels = tf.to_int32(tf.argmax(labels, axis=1)) 257 | predictions = tf.to_int32(tf.argmax(self.prediction, axis=1)) 258 | loss = self.create_loss(labels) 259 | return { 260 | 'accuracy' : metrics.accuracy(predictions, encoded_labels, weights=mask), 261 | 'loss' : tf.metrics.mean(loss) 262 | } 263 | 264 | def build_json_results(self, class_probabilities, timestamps): 265 | max_prob_index = np.argmax(class_probabilities) 266 
| max_probability = float(class_probabilities[max_prob_index]) 267 | max_label = self.classes[max_prob_index] 268 | full_scores = dict( 269 | zip(self.classes, [float(v) for v in class_probabilities])) 270 | 271 | return { 272 | 'name': self.name, 273 | 'max_label': max_label, 274 | 'max_label_probability': max_probability, 275 | 'label_scores': full_scores 276 | } 277 | 278 | 279 | class FishingLocalizationObjectiveCrossEntropy(ObjectiveBase): 280 | def __init__(self, 281 | metadata_label, 282 | name, 283 | vessel_metadata, 284 | loss_weight=1.0, 285 | metrics='all', 286 | window=None): 287 | super(FishingLocalizationObjectiveCrossEntropy, self).__init__(metadata_label, name, loss_weight, 288 | metrics) 289 | self.vessel_metadata = vessel_metadata 290 | self.window = window 291 | self.pos_weight = 1.0 292 | 293 | def loss_function(self, dense_labels): 294 | fishing_mask = tf.to_float(tf.not_equal(dense_labels, -1)) 295 | fishing_targets = tf.to_float(dense_labels > 0.5) 296 | logits = self.logits 297 | if self.window: 298 | b, e = self.window 299 | fishing_mask = fishing_mask[:, b:e] 300 | fishing_targets = fishing_targets[:, b:e] 301 | logits = logits[:, b:e] 302 | return tf.reduce_sum(fishing_mask * 303 | tf.nn.weighted_cross_entropy_with_logits( 304 | targets=fishing_targets, 305 | logits=logits, 306 | pos_weight=self.pos_weight)) 307 | 308 | def build(self, net): 309 | self.logits = net[:, :, 0] 310 | self.prediction = tf.sigmoid(self.logits) 311 | 312 | def create_loss(self, dense_labels): 313 | return self.loss_weight * self.loss_function(dense_labels) 314 | 315 | def create_raw_metrics(self, dense_labels): 316 | thresholded_prediction = tf.to_int32(self.prediction > 0.5) 317 | valid = tf.to_int32(tf.not_equal(dense_labels, -1)) 318 | labels = tf.to_int32(dense_labels > 0.5) 319 | weights = tf.to_float(valid) 320 | prediction = self.prediction 321 | 322 | if self.window: 323 | b, e = self.window 324 | prediction = prediction[:, b:e] 325 | dense_labels = dense_labels[:, b:e] 326 | thresholded_prediction = thresholded_prediction[:, b:e] 327 | valid = valid[:, b:e] 328 | labels = labels[:, b:e] 329 | weights = weights[:, b:e] 330 | 331 | return { 332 | 'MSE': tf.metrics.mean_squared_error(prediction, dense_labels, weights=weights), 333 | 'accuracy': tf.metrics.accuracy(labels, thresholded_prediction, weights=weights), 334 | 'precision': tf.metrics.precision(labels, thresholded_prediction, weights=weights), 335 | 'recall': tf.metrics.recall(labels, thresholded_prediction, weights=weights) 336 | } 337 | 338 | 339 | 340 | def build_json_results(self, prediction, timestamps): 341 | InferencePoint = namedtuple('InferencePoint', ['timestamp', 'is_fishing']) 342 | InferenceRange = namedtuple('InferenceRange', ['start_time', 'end_time', 'score']) 343 | 344 | assert (len(prediction) == len(timestamps)) 345 | thresholded_prediction = prediction > 0.5 346 | combined = list(six.moves.zip(timestamps, thresholded_prediction)) 347 | if self.window: 348 | b, e = self.window 349 | combined = combined[b:e] 350 | 351 | last = None 352 | fishing_ranges = [] 353 | for ts_raw, is_fishing in combined: 354 | ts = datetime.datetime.utcfromtimestamp(int(ts_raw)) 355 | if last and last.timestamp >= ts: 356 | logging.warning("last.timestamp >= timestamp") 357 | break 358 | if last and last.is_fishing == is_fishing: 359 | if ts.date() > last.timestamp.date(): 360 | # We are crossing a day boundary here, so break into two ranges 361 | end_of_day = datetime.datetime.combine(last.timestamp.date(), 362 | 
datetime.time(hour=23, minute=59, second=59)) 363 | # TODO: are we skipping a day here if gaps is multi day? Check 364 | start_of_day = datetime.datetime.combine(ts.date(), 365 | datetime.time(hour=0, minute=0, second=0)) 366 | fishing_ranges[-1] = fishing_ranges[-1]._replace( 367 | end_time=end_of_day.isoformat()) 368 | fishing_ranges.append( 369 | InferenceRange(start_of_day.isoformat(), None, is_fishing)) 370 | fishing_ranges[-1] = fishing_ranges[-1]._replace(end_time=ts.isoformat()) 371 | else: 372 | # TODO, append min(half the distance to previous / next point) 373 | # TODO, but maybe we should drop long ranges with no points 374 | fishing_ranges.append( 375 | InferenceRange(ts.isoformat(), ts.isoformat(), is_fishing)) 376 | last = InferencePoint(timestamp=ts, is_fishing=is_fishing) 377 | 378 | return [{'start_time': x.start_time + 'Z', 379 | 'end_time': x.end_time + 'Z', 'value': float(x.score)} 380 | for x in fishing_ranges] 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | -------------------------------------------------------------------------------- /classification/metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. and Skytruth Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections import defaultdict, namedtuple 16 | import csv 17 | import datetime 18 | import dateutil.parser 19 | import pytz 20 | import logging 21 | import os 22 | import sys 23 | import tensorflow as tf 24 | import yaml 25 | import numpy as np 26 | import hashlib 27 | import six 28 | from .feature_generation.file_iterator import GCSFile 29 | 30 | 31 | """ The main column for vessel classification. """ 32 | PRIMARY_VESSEL_CLASS_COLUMN = 'label' 33 | 34 | #TODO: (bitsofbits) think about extracting to config file 35 | 36 | # The 'real' categories for multihotness are the fine categories, which 'coarse' and 'fishing' 37 | # are defined in terms of. Any number of coarse categories, even with overlapping values can 38 | # be defined in principle, although at present the interaction between the mulithot and non multihot 39 | # versions makes that more complicated. 
40 | 41 | try: 42 | yaml_load = yaml.safe_load 43 | except: 44 | yaml_load = yaml.load 45 | 46 | 47 | raw_schema = ''' 48 | unknown: 49 | non_fishing: 50 | passenger: 51 | gear: 52 | fish_factory: 53 | cargo_or_tanker: 54 | bunker_or_tanker: 55 | bunker: 56 | tanker: 57 | cargo_or_reefer: 58 | cargo: 59 | reefer: 60 | specialized_reefer: 61 | container_reefer: 62 | fish_tender: 63 | well_boat: 64 | patrol_vessel: 65 | research: 66 | dive_vessel: 67 | submarine: 68 | dredge_non_fishing: 69 | supply_vessel: 70 | tug: 71 | seismic_vessel: 72 | helicopter: 73 | other_not_fishing: 74 | 75 | fishing: 76 | squid_jigger: 77 | drifting_longlines: 78 | pole_and_line: 79 | other_fishing: 80 | trollers: 81 | fixed_gear: 82 | pots_and_traps: 83 | set_longlines: 84 | set_gillnets: 85 | trawlers: 86 | dredge_fishing: 87 | seiners: 88 | purse_seines: 89 | tuna_purse_seines: 90 | other_purse_seines: 91 | other_seines: 92 | driftnets: 93 | ''' 94 | 95 | 96 | schema = yaml.safe_load(raw_schema) 97 | 98 | 99 | def atomic(obj): 100 | for k, v in obj.items(): 101 | if v is None or isinstance(v, str): 102 | yield k 103 | else: 104 | for x in atomic(v): 105 | yield x 106 | 107 | def categories(obj, include_atomic=True): 108 | for k, v in obj.items(): 109 | if v is None or isinstance(v, str): 110 | if include_atomic: 111 | yield k, [k] 112 | else: 113 | yield (k, list(atomic(v))) 114 | for x in categories(v, include_atomic=include_atomic): 115 | yield x 116 | 117 | 118 | 119 | VESSEL_CLASS_DETAILED_NAMES = sorted(atomic(schema)) 120 | 121 | VESSEL_CATEGORIES = sorted(categories(schema)) 122 | 123 | TRAINING_SPLIT = 'Training' 124 | TEST_SPLIT = 'Test' 125 | 126 | FishingRange = namedtuple('FishingRange', 127 | ['start_time', 'end_time', 'is_fishing']) 128 | 129 | 130 | def stable_hash(x): 131 | x = six.ensure_binary(x) 132 | digest = hashlib.blake2b(six.ensure_binary(x)).hexdigest()[-8:] 133 | return int(digest, 16) 134 | 135 | class VesselMetadata(object): 136 | def __init__(self, 137 | metadata_dict, 138 | fishing_ranges_map): 139 | self.metadata_by_split = metadata_dict 140 | self.metadata_by_id = {} 141 | self.fishing_ranges_map = fishing_ranges_map 142 | self.id_map_int2bytes = {} 143 | for split, vessels in metadata_dict.items(): 144 | for id_, data in vessels.items(): 145 | id_ = six.ensure_binary(id_) 146 | self.metadata_by_id[id_] = data 147 | idhash = stable_hash(id_) 148 | self.id_map_int2bytes[idhash] = id_ 149 | 150 | intersection_ids = set(self.metadata_by_id.keys()).intersection( 151 | set(fishing_ranges_map.keys())) 152 | logging.info("Metadata for %d ids.", len(self.metadata_by_id)) 153 | logging.info("Fishing ranges for %d ids.", len(fishing_ranges_map)) 154 | logging.info("Vessels with both types of data: %d", 155 | len(intersection_ids)) 156 | 157 | def vessel_weight(self, id_): 158 | return self.metadata_by_id[id_][1] 159 | 160 | def vessel_label(self, label_name, id_): 161 | return self.metadata_by_id[id_][0][label_name] 162 | 163 | def ids_for_split(self, split): 164 | assert split in (TRAINING_SPLIT, TEST_SPLIT) 165 | # Check to make sure we don't have leakage 166 | if (set(self.metadata_by_split[TRAINING_SPLIT].keys()) & 167 | set(self.metadata_by_split[TEST_SPLIT].keys())): 168 | logging.warning('id in both training and test split') 169 | return self.metadata_by_split[split].keys() 170 | 171 | def weighted_training_list(self, 172 | random_state, 173 | split, 174 | max_replication_factor, 175 | row_filter=lambda row: True, 176 | boundary=1): 177 | replicated_ids = [] 178 | 
logging.info("Training ids: %d", len(self.ids_for_split(split))) 179 | fishing_ranges_ids = [] 180 | for id_, (row, weight) in self.metadata_by_split[split].items(): 181 | if row_filter(row): 182 | if id_ in self.fishing_ranges_map: 183 | fishing_ranges_ids.append(id_) 184 | weight = min(weight, max_replication_factor) 185 | 186 | int_n = int(weight) 187 | replicated_ids += ([id_] * int_n) 188 | frac_n = weight - float(int_n) 189 | if (random_state.uniform(0.0, 1.0) <= frac_n): 190 | replicated_ids.append(id_) 191 | missing = (-len(replicated_ids)) % boundary 192 | if missing: 193 | replicated_ids = np.concatenate( 194 | [replicated_ids, 195 | np.random.choice(replicated_ids, missing)]) 196 | random_state.shuffle(replicated_ids) 197 | logging.info("Replicated training ids: %d", len(replicated_ids)) 198 | logging.info("Fishing range ids: %d", len(fishing_ranges_ids)) 199 | 200 | return replicated_ids 201 | 202 | def fishing_range_only_list(self, random_state, split): 203 | replicated_ids = [] 204 | fishing_id_set = set( 205 | [k for (k, v) in self.fishing_ranges_map.items() if v]) 206 | fishing_range_only_ids = [id_ 207 | for id_ in self.ids_for_split(split) 208 | if id_ in fishing_id_set] 209 | logging.info("Fishing range training ids: %d / %d", 210 | len(fishing_range_only_ids), 211 | len(self.ids_for_split(split))) 212 | 213 | return fishing_range_only_ids 214 | 215 | 216 | def read_vessel_time_weighted_metadata_lines(available_ids, lines, 217 | fishing_range_dict, split): 218 | """ For a set of vessels, read metadata; use flat weights 219 | 220 | Args: 221 | available_ids: a set of all ids for which we have feature data. 222 | lines: a list of comma-separated vessel metadata lines. Columns are 223 | the id and a set of vessel type columns, containing at least one 224 | called 'label' being the primary/coarse type of the vessel e.g. 225 | (Longliner/Passenger etc.). 226 | fishing_range_dict: dictionary of mapping id to lists of fishing ranges 227 | 228 | Returns: 229 | A VesselMetadata object with weights and labels for each vessel. 230 | """ 231 | 232 | metadata_dict = {TRAINING_SPLIT : {}, TEST_SPLIT : {}} 233 | 234 | min_time_per_id = np.inf 235 | 236 | for row in lines: 237 | id_ = six.ensure_binary(row['id'].strip()) 238 | if id_ in available_ids: 239 | if id_ not in fishing_range_dict: 240 | continue 241 | # Is this id included only to supress false positives 242 | # Symptoms; fishing score for this id never different from 0 243 | item_split = raw_item_split = row['split'] 244 | if raw_item_split in '0123456789': 245 | if int(raw_item_split) == split: 246 | item_split = TEST_SPLIT 247 | else: 248 | item_split = TRAINING_SPLIT 249 | if item_split not in (TRAINING_SPLIT, TEST_SPLIT): 250 | logging.warning( 251 | 'id %s has no valid split assigned (%s); using for Training', 252 | id_, split) 253 | split = TRAINING_SPLIT 254 | time_for_this_id = 0 255 | for rng in fishing_range_dict[id_]: 256 | time_for_this_id += ( 257 | rng.end_time - rng.start_time).total_seconds() 258 | metadata_dict[item_split][id_] = (row, time_for_this_id) 259 | if split is None and raw_item_split in '0123456789': 260 | # Test on everything even though we are training on everything 261 | metadata_dict[TEST_SPLIT][id_] = (row, time_for_this_id) 262 | 263 | if time_for_this_id: 264 | min_time_per_id = min(min_time_per_id, time_for_this_id) 265 | 266 | # This weighting is fiddly. We are keeping it for now to match up 267 | # with older data, but should replace when we move to sets, etc. 
268 | MAX_WEIGHT = 100.0 269 | for split_dict in metadata_dict.values(): 270 | for id_ in split_dict: 271 | row, time = split_dict[id_] 272 | split_dict[id_] = (row, min(MAX_WEIGHT, time / min_time_per_id)) 273 | 274 | return VesselMetadata(metadata_dict, fishing_range_dict) 275 | 276 | 277 | def read_vessel_time_weighted_metadata(available_ids, 278 | metadata_file, 279 | fishing_range_dict={}, 280 | split=0): 281 | reader = metadata_file_reader(metadata_file) 282 | 283 | return read_vessel_time_weighted_metadata_lines(available_ids, reader, 284 | fishing_range_dict, 285 | split) 286 | 287 | 288 | def read_vessel_multiclass_metadata_lines(available_ids, lines, 289 | fishing_range_dict): 290 | """ For a set of vessels, read metadata and calculate class weights. 291 | 292 | Args: 293 | available_ids: a set of all ids for which we have feature data. 294 | lines: a list of comma-separated vessel metadata lines. Columns are 295 | the id and a set of vessel type columns, containing at least one 296 | called 'label' being the primary/coarse type of the vessel e.g. 297 | (Longliner/Passenger etc.). 298 | fishing_range_dict: dictionary mapping id to lists of fishing ranges 299 | Returns: 300 | A VesselMetadata object with weights and labels for each vessel. 301 | """ 302 | 303 | vessel_type_set = set() 304 | dataset_kind_counts = defaultdict(lambda: defaultdict(lambda: 0)) 305 | vessel_types = [] 306 | 307 | cat_map = {k: v for (k, v) in VESSEL_CATEGORIES} 308 | 309 | available_ids = set(available_ids) 310 | for row in lines: 311 | id_ = six.ensure_binary(row['id'].strip()) 312 | if id_ not in available_ids: 313 | continue 314 | raw_vessel_type = row[PRIMARY_VESSEL_CLASS_COLUMN] 315 | if not raw_vessel_type: 316 | continue 317 | atomic_types = set() 318 | for kind in raw_vessel_type.split('|'): 319 | try: 320 | for atm in cat_map[kind]: 321 | atomic_types.add(atm) 322 | except KeyError as err: 323 | logging.warning('unknown vessel type: {}\n{}'.format(kind, err)) 324 | if not atomic_types: 325 | continue 326 | scale = 1.0 / len(atomic_types) 327 | split = row['split'].strip() 328 | assert split in ('Training', 'Test'), repr(split) 329 | vessel_types.append((id_, split, raw_vessel_type, row)) 330 | for atm in atomic_types: 331 | dataset_kind_counts[split][atm] += scale 332 | vessel_type_set |= atomic_types 333 | # else: 334 | # logging.warning('No training data for %s, (%s) %s %s', id_, sorted(available_ids)[:10], 335 | # type(id_), type(sorted(available_ids)[0])) 336 | 337 | # Calculate weights for each vessel type per split, for 338 | # now use weights of sqrt(max_count / count) 339 | dataset_kind_weights = defaultdict(lambda: {}) 340 | for split, counts in dataset_kind_counts.items(): 341 | max_count = max(counts.values()) 342 | for atomic_vessel_type, count in counts.items(): 343 | dataset_kind_weights[split][atomic_vessel_type] = np.sqrt(max_count / float(count)) 344 | 345 | metadata_dict = defaultdict(lambda: {}) 346 | for id_, split, raw_vessel_type, row in vessel_types: 347 | if split == 'Training': 348 | weights = [] 349 | for kind in raw_vessel_type.split('|'): 350 | for atm in cat_map.get(kind, ['unknown']): 351 | weights.append(dataset_kind_weights[split][atm]) 352 | metadata_dict[split][id_] = (row, np.mean(weights)) 353 | elif split == "Test": 354 | metadata_dict[split][id_] = (row, 1.0) 355 | else: 356 | logging.warning("unknown split {}".format(split)) 357 | 358 | if len(vessel_type_set) == 0: 359 | logging.fatal('No vessel types found for training.') 360 | sys.exit(-1) 361 
| 362 | logging.info("Vessel types: %s", list(vessel_type_set)) 363 | 364 | return VesselMetadata( 365 | dict(metadata_dict), fishing_range_dict) 366 | 367 | 368 | def metadata_file_reader(metadata_file): 369 | """ 370 | 371 | 372 | """ 373 | with open(metadata_file, 'r') as f: 374 | reader = csv.DictReader(f) 375 | logging.info("Metadata columns: %s", reader.fieldnames) 376 | for row in reader: 377 | yield row 378 | 379 | 380 | def read_vessel_multiclass_metadata(available_ids, 381 | metadata_file, 382 | fishing_range_dict={}): 383 | reader = metadata_file_reader(metadata_file) 384 | 385 | return read_vessel_multiclass_metadata_lines( 386 | available_ids, reader, fishing_range_dict) 387 | 388 | 389 | def find_available_ids(feature_path): 390 | with tf.Session() as sess: 391 | logging.info('Reading id list file.') 392 | root_output_path = os.path.dirname(feature_path) 393 | # The feature pipeline stage that outputs the id list is sharded to only 394 | # produce a single file, so no need to glob or loop here. 395 | id_path = os.path.join(root_output_path, 'ids/part-00000-of-00001.txt') 396 | logging.info('Reading id list file from {}'.format(id_path)) 397 | with GCSFile(id_path) as f: 398 | els = f.read().split(b'\n') 399 | id_list = [id_.strip() for id_ in els if id_.strip() != ''] 400 | 401 | logging.info('Found %d ids.', len(id_list)) 402 | return set(id_list) 403 | 404 | 405 | def parse_date(date): 406 | try: 407 | unix_timestamp = float(date) 408 | return datetime.datetime.utcfromtimestamp(unix_timestamp).replace( 409 | tzinfo=pytz.utc) 410 | except: 411 | try: 412 | return dateutil.parser.parse(date) 413 | except: 414 | logging.fatal('could not parse date "{}"'.format(date)) 415 | raise 416 | 417 | 418 | def read_fishing_ranges(fishing_range_file): 419 | """ Read vessel fishing ranges, return a dict of id to classified fishing 420 | or non-fishing ranges for that vessel. 421 | """ 422 | fishing_range_dict = defaultdict(lambda: []) 423 | with open(fishing_range_file, 'r') as f: 424 | for l in f.readlines()[1:]: 425 | els = l.split(',') 426 | id_ = six.ensure_binary(els[0].strip()) 427 | start_time = parse_date(els[1]).replace(tzinfo=pytz.utc) 428 | end_time = parse_date(els[2]).replace(tzinfo=pytz.utc) 429 | is_fishing = float(els[3]) 430 | fishing_range_dict[id_].append( 431 | FishingRange(start_time, end_time, is_fishing)) 432 | 433 | return dict(fishing_range_dict) 434 | 435 | 436 | def build_multihot_lookup_table(): 437 | n_base = len(VESSEL_CLASS_DETAILED_NAMES) 438 | n_categories = len(VESSEL_CATEGORIES) 439 | # 440 | table = np.zeros([n_categories, n_base], dtype=np.int32) 441 | for i, (_, base_labels) in enumerate(VESSEL_CATEGORIES): 442 | for lbl in base_labels: 443 | j = VESSEL_CLASS_DETAILED_NAMES.index(lbl) 444 | table[i, j] = 1 445 | return table 446 | 447 | 448 | multihot_lookup_table = build_multihot_lookup_table() 449 | 450 | 451 | def multihot_encode(label): 452 | """Multihot encode based on fine, coarse and is_fishing label 453 | 454 | Args: 455 | label: Tensor (int) 456 | 457 | Returns: 458 | Tensor with bits set for every allowable vessel type based on the inputs 459 | 460 | 461 | """ 462 | tf_multihot_lookup_table = tf.convert_to_tensor(multihot_lookup_table) 463 | return tf.gather(tf_multihot_lookup_table, label) 464 | --------------------------------------------------------------------------------
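A minimal sketch of how the multihot lookup table defined in classification/metadata.py maps labels onto the detailed vessel classes, mirroring what MultiClassificationObjective.create_label does in objectives.py. This is not part of the repository: it assumes classification.metadata is importable from an installed checkout, and the composite label 'trawlers|set_gillnets' is only an illustrative example.

import numpy as np
from classification import metadata

category_names = [name for name, _ in metadata.VESSEL_CATEGORIES]

# Row of the lookup table for the coarse 'fishing' category: a 1 for every
# detailed class (trawlers, squid_jigger, ...) that counts as fishing.
row = metadata.multihot_lookup_table[category_names.index('fishing')]
fishing_classes = [cls for cls, bit in
                   zip(metadata.VESSEL_CLASS_DETAILED_NAMES, row) if bit]
print(fishing_classes)

# A composite label is encoded by OR-ing the rows of its '|'-separated parts,
# which is how MultiClassificationObjective.create_label builds its targets.
encoded = np.zeros(len(metadata.VESSEL_CLASS_DETAILED_NAMES), dtype=np.int32)
for lbl in 'trawlers|set_gillnets'.split('|'):
    encoded |= metadata.multihot_lookup_table[category_names.index(lbl)]
print(encoded)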