├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── NOTICE ├── README.md ├── docs └── overview.png ├── examples ├── README.md ├── bcq.py ├── bcq_parameter.py ├── quant_model_bcq.py ├── quant_model_rtn.py ├── rtn_parameter.py └── utils.py ├── lutGEMM ├── CMakeLists.txt ├── include │ ├── kernels.h │ ├── lutGEMM │ └── nQWeight_fp16.h └── src │ ├── cuda │ ├── kernels │ │ ├── cublas.cu │ │ ├── cublas.h │ │ ├── dequant.hpp │ │ ├── dequant_fp16.hpp │ │ ├── gptq_faster_fp16_bias.hpp │ │ ├── gptq_fp16_bias.hpp │ │ ├── mm_t.hpp │ │ ├── mv.hpp │ │ ├── mv_fp16.hpp │ │ └── mv_fp16_bias.hpp │ └── tmpWeight.hpp │ ├── kernels.cu │ └── nQWeight_fp16.cu ├── tests ├── CMakeLists.txt ├── include │ ├── _cublas.h │ ├── custom_random.h │ ├── tests.h │ └── timer.h ├── main.cc ├── opt │ ├── _cublas.cc │ ├── _cublas.cu │ └── fp16 │ │ └── int3_col_wise_matmul_fp16.cu └── src │ └── custom_random.cpp └── thirdparty └── googletest ├── .clang-format ├── .github ├── ISSUE_TEMPLATE │ ├── 00-bug_report.md │ ├── 10-feature_request.md │ └── config.yml └── workflows │ └── gtest-ci.yml ├── .gitignore ├── BUILD.bazel ├── CMakeLists.txt ├── CONTRIBUTING.md ├── CONTRIBUTORS ├── LICENSE ├── README.md ├── WORKSPACE ├── ci ├── linux-presubmit.sh └── macos-presubmit.sh ├── docs ├── _config.yml ├── _data │ └── navigation.yml ├── _layouts │ └── default.html ├── _sass │ └── main.scss ├── advanced.md ├── assets │ └── css │ │ └── style.scss ├── community_created_documentation.md ├── faq.md ├── gmock_cheat_sheet.md ├── gmock_cook_book.md ├── gmock_faq.md ├── gmock_for_dummies.md ├── index.md ├── pkgconfig.md ├── platforms.md ├── primer.md ├── quickstart-bazel.md ├── quickstart-cmake.md ├── reference │ ├── actions.md │ ├── assertions.md │ ├── matchers.md │ ├── mocking.md │ └── testing.md └── samples.md ├── googlemock ├── CMakeLists.txt ├── README.md ├── cmake │ ├── gmock.pc.in │ └── gmock_main.pc.in ├── docs │ └── README.md ├── include │ └── gmock │ │ ├── gmock-actions.h │ │ ├── gmock-cardinalities.h │ │ ├── 
gmock-function-mocker.h │ │ ├── gmock-matchers.h │ │ ├── gmock-more-actions.h │ │ ├── gmock-more-matchers.h │ │ ├── gmock-nice-strict.h │ │ ├── gmock-spec-builders.h │ │ ├── gmock.h │ │ └── internal │ │ ├── custom │ │ ├── README.md │ │ ├── gmock-generated-actions.h │ │ ├── gmock-matchers.h │ │ └── gmock-port.h │ │ ├── gmock-internal-utils.h │ │ ├── gmock-port.h │ │ └── gmock-pp.h ├── src │ ├── gmock-all.cc │ ├── gmock-cardinalities.cc │ ├── gmock-internal-utils.cc │ ├── gmock-matchers.cc │ ├── gmock-spec-builders.cc │ ├── gmock.cc │ └── gmock_main.cc └── test │ ├── BUILD.bazel │ ├── gmock-actions_test.cc │ ├── gmock-cardinalities_test.cc │ ├── gmock-function-mocker_test.cc │ ├── gmock-internal-utils_test.cc │ ├── gmock-matchers_test.cc │ ├── gmock-more-actions_test.cc │ ├── gmock-nice-strict_test.cc │ ├── gmock-port_test.cc │ ├── gmock-pp-string_test.cc │ ├── gmock-pp_test.cc │ ├── gmock-spec-builders_test.cc │ ├── gmock_all_test.cc │ ├── gmock_ex_test.cc │ ├── gmock_leak_test.py │ ├── gmock_leak_test_.cc │ ├── gmock_link2_test.cc │ ├── gmock_link_test.cc │ ├── gmock_link_test.h │ ├── gmock_output_test.py │ ├── gmock_output_test_.cc │ ├── gmock_output_test_golden.txt │ ├── gmock_stress_test.cc │ ├── gmock_test.cc │ └── gmock_test_utils.py └── googletest ├── CMakeLists.txt ├── README.md ├── cmake ├── Config.cmake.in ├── gtest.pc.in ├── gtest_main.pc.in ├── internal_utils.cmake └── libgtest.la.in ├── docs └── README.md ├── include └── gtest │ ├── gtest-death-test.h │ ├── gtest-matchers.h │ ├── gtest-message.h │ ├── gtest-param-test.h │ ├── gtest-printers.h │ ├── gtest-spi.h │ ├── gtest-test-part.h │ ├── gtest-typed-test.h │ ├── gtest.h │ ├── gtest_pred_impl.h │ ├── gtest_prod.h │ └── internal │ ├── custom │ ├── README.md │ ├── gtest-port.h │ ├── gtest-printers.h │ └── gtest.h │ ├── gtest-death-test-internal.h │ ├── gtest-filepath.h │ ├── gtest-internal.h │ ├── gtest-param-util.h │ ├── gtest-port-arch.h │ ├── gtest-port.h │ ├── gtest-string.h │ └── gtest-type-util.h 
├── samples ├── prime_tables.h ├── sample1.cc ├── sample1.h ├── sample10_unittest.cc ├── sample1_unittest.cc ├── sample2.cc ├── sample2.h ├── sample2_unittest.cc ├── sample3-inl.h ├── sample3_unittest.cc ├── sample4.cc ├── sample4.h ├── sample4_unittest.cc ├── sample5_unittest.cc ├── sample6_unittest.cc ├── sample7_unittest.cc ├── sample8_unittest.cc └── sample9_unittest.cc ├── src ├── gtest-all.cc ├── gtest-death-test.cc ├── gtest-filepath.cc ├── gtest-internal-inl.h ├── gtest-matchers.cc ├── gtest-port.cc ├── gtest-printers.cc ├── gtest-test-part.cc ├── gtest-typed-test.cc ├── gtest.cc └── gtest_main.cc └── test ├── BUILD.bazel ├── googletest-break-on-failure-unittest.py ├── googletest-break-on-failure-unittest_.cc ├── googletest-catch-exceptions-test.py ├── googletest-catch-exceptions-test_.cc ├── googletest-color-test.py ├── googletest-color-test_.cc ├── googletest-death-test-test.cc ├── googletest-death-test_ex_test.cc ├── googletest-env-var-test.py ├── googletest-env-var-test_.cc ├── googletest-failfast-unittest.py ├── googletest-failfast-unittest_.cc ├── googletest-filepath-test.cc ├── googletest-filter-unittest.py ├── googletest-filter-unittest_.cc ├── googletest-global-environment-unittest.py ├── googletest-global-environment-unittest_.cc ├── googletest-json-outfiles-test.py ├── googletest-json-output-unittest.py ├── googletest-list-tests-unittest.py ├── googletest-list-tests-unittest_.cc ├── googletest-listener-test.cc ├── googletest-message-test.cc ├── googletest-options-test.cc ├── googletest-output-test-golden-lin.txt ├── googletest-output-test.py ├── googletest-output-test_.cc ├── googletest-param-test-invalid-name1-test.py ├── googletest-param-test-invalid-name1-test_.cc ├── googletest-param-test-invalid-name2-test.py ├── googletest-param-test-invalid-name2-test_.cc ├── googletest-param-test-test.cc ├── googletest-param-test-test.h ├── googletest-param-test2-test.cc ├── googletest-port-test.cc ├── googletest-printers-test.cc ├── 
googletest-setuptestsuite-test.py ├── googletest-setuptestsuite-test_.cc ├── googletest-shuffle-test.py ├── googletest-shuffle-test_.cc ├── googletest-test-part-test.cc ├── googletest-throw-on-failure-test.py ├── googletest-throw-on-failure-test_.cc ├── googletest-uninitialized-test.py ├── googletest-uninitialized-test_.cc ├── gtest-typed-test2_test.cc ├── gtest-typed-test_test.cc ├── gtest-typed-test_test.h ├── gtest-unittest-api_test.cc ├── gtest_all_test.cc ├── gtest_assert_by_exception_test.cc ├── gtest_environment_test.cc ├── gtest_help_test.py ├── gtest_help_test_.cc ├── gtest_json_test_utils.py ├── gtest_list_output_unittest.py ├── gtest_list_output_unittest_.cc ├── gtest_main_unittest.cc ├── gtest_no_test_unittest.cc ├── gtest_pred_impl_unittest.cc ├── gtest_premature_exit_test.cc ├── gtest_prod_test.cc ├── gtest_repeat_test.cc ├── gtest_skip_check_output_test.py ├── gtest_skip_environment_check_output_test.py ├── gtest_skip_in_environment_setup_test.cc ├── gtest_skip_test.cc ├── gtest_sole_header_test.cc ├── gtest_stress_test.cc ├── gtest_test_macro_stack_footprint_test.cc ├── gtest_test_utils.py ├── gtest_testbridge_test.py ├── gtest_testbridge_test_.cc ├── gtest_throw_on_failure_ex_test.cc ├── gtest_unittest.cc ├── gtest_xml_outfile1_test_.cc ├── gtest_xml_outfile2_test_.cc ├── gtest_xml_outfiles_test.py ├── gtest_xml_output_unittest.py ├── gtest_xml_output_unittest_.cc ├── gtest_xml_test_utils.py ├── production.cc └── production.h /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18) 2 | project(nQmatmul CXX C) 3 | enable_language(CUDA) 4 | 5 | find_package(CUDA 10.1 REQUIRED) 6 | 7 | 8 | if(NOT DEFINED ${CMAKE_CUDA_ARCHITECTURES}) 9 | set(CMAKE_CUDA_ARCHITECTURES 80) 10 | endif() 11 | 12 | add_subdirectory(thirdparty/googletest) 13 | add_subdirectory(lutGEMM) 14 | add_subdirectory(tests) 15 | 16 | add_library(${PROJECT_NAME} 17 | INTERFACE 18 | ) 19 | 20 | 
target_link_libraries(${PROJECT_NAME} 21 | INTERFACE 22 | gtest 23 | lutGEMM 24 | ) 25 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | LUT-GEMM 2 | Copyright (c) 2024-present NAVER Cloud Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | 16 | -------------------------------------------------------------------------------------- 17 | 18 | This project contains subcomponents with separate copyright notices and license terms. 19 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 20 | 21 | ===== 22 | 23 | google/googletest 24 | https://github.com/google/googletest 25 | 26 | 27 | Copyright 2008, Google Inc. 28 | All rights reserved. 29 | 30 | Redistribution and use in source and binary forms, with or without 31 | modification, are permitted provided that the following conditions are 32 | met: 33 | 34 | * Redistributions of source code must retain the above copyright 35 | notice, this list of conditions and the following disclaimer. 36 | * Redistributions in binary form must reproduce the above 37 | copyright notice, this list of conditions and the following disclaimer 38 | in the documentation and/or other materials provided with the 39 | distribution. 40 | * Neither the name of Google Inc. 
nor the names of its 41 | contributors may be used to endorse or promote products derived from 42 | this software without specific prior written permission. 43 | 44 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 45 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 46 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 47 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 48 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 49 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 50 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 51 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 52 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 53 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 54 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 55 | 56 | ===== 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LUT-GEMM 2 | 3 | This repository provides the official implementation of LUT-GEMM from the following paper. 4 | 5 | **LUT-GEMM: Quantized Matrix Multiplication based on LUTs for Efficient Inference in Large-Scale Generative Language Models** 6 | 7 | _Gunho Park, Baeseong Park, Minsub Kim, Sungjae Lee, Jeonghoon Kim, Beomseok Kwon, Se Jung Kwon, Byeongwook Kim, Youngjoo Lee, and Dongsoo Lee_ 8 | 9 | Paper: https://arxiv.org/pdf/2206.09557.pdf 10 | 11 | Abstract: Our proposed kernel, LUT-GEMM, accelerates quantized matrix multiplication by leveraging both uniform and non-uniform quantization techniques. Utilizing sub-4-bit quantized weights, it offers flexibility and achieves high compression ratios, allowing a balance between accuracy and efficiency. 
Through the use of low-bit quantization and efficient LUT-based operations, it effectively reduces memory usage and computational costs, thereby significantly enhancing the inference speed of large-scale language models. 12 | 13 |

![image](docs/overview.png)

14 | 15 | 16 | ## Quick Start 17 | 18 | Run the following commands to get **`Kernel Evaluation`** results in Table 1. 19 | 20 | ``` sh 21 | mkdir build 22 | cd build 23 | cmake -DCMAKE_CUDA_ARCHITECTURES=80 .. 24 | make -j8 25 | ./tests/tests 26 | ``` 27 | 28 | ## Citation 29 | 30 | ``` 31 | @misc{park2023lutgemm, 32 | title={LUT-GEMM: Quantized Matrix Multiplication based on LUTs for Efficient Inference in Large-Scale Generative Language Models}, 33 | author={Gunho Park, Baeseong Park, Minsub Kim, Sungjae Lee, Jeonghoon Kim, Beomseok Kwon, Se Jung Kwon, Byeongwook Kim, Youngjoo Lee and Dongsoo Lee}, 34 | year={2023}, 35 | eprint={2206.09557}, 36 | archivePrefix={arXiv}, 37 | primaryClass={cs.DC} 38 | } 39 | ``` 40 | 41 | ## License 42 | 43 | ``` 44 | Copyright (c) 2024-present NAVER Cloud Corp. 45 | 46 | Licensed under the Apache License, Version 2.0 (the "License"); 47 | you may not use this file except in compliance with the License. 48 | You may obtain a copy of the License at 49 | 50 | http://www.apache.org/licenses/LICENSE-2.0 51 | 52 | Unless required by applicable law or agreed to in writing, software 53 | distributed under the License is distributed on an "AS IS" BASIS, 54 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 55 | See the License for the specific language governing permissions and 56 | limitations under the License. 
57 | ``` 58 | -------------------------------------------------------------------------------- /docs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-aics/lut-gemm/647d49fcf0b054b6dafee04b5ce1c8c2adb047e5/docs/overview.png -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # LUT-GEMM 2 | 3 | ## Model Quantization Examples 4 | 5 | Run the following commands to get the binary matrices and scaling factor matrices from the pre-trained weights. 6 | 7 | ``` sh 8 | python quant_model_bcq.py \ 9 | --model_name_or_path facebook/opt-125m \ 10 | --qbits 4 \ 11 | --group_size 128 12 | ``` 13 | 14 | ``` sh 15 | python quant_model_rtn.py \ 16 | --model_name_or_path facebook/opt-125m \ 17 | --qbits 4 \ 18 | --group_size 128 19 | ``` -------------------------------------------------------------------------------- /examples/bcq_parameter.py: -------------------------------------------------------------------------------- 1 | # LUT-GEMM 2 | # Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | from bcq import quantize 20 | from utils import CompressionParameter, PACKER 21 | 22 | class BCQParameter(CompressionParameter): 23 | def compress(self, do_packing=False, in_ch_wise=False, **kwargs): 24 | global PACKER 25 | _, binary, alpha, _ = quantize(self.data, transpose=in_ch_wise, **kwargs) 26 | 27 | binary_shape = binary.shape 28 | if do_packing == True: 29 | binary, binary_shape = PACKER.pack(binary) 30 | binary = binary.to(self.data.device) 31 | 32 | return alpha, binary, binary_shape 33 | 34 | def decompress(self, alpha, binary, binary_shape, offset=None, do_packing=False, in_ch_wise=False): 35 | global PACKER 36 | 37 | if do_packing == True: 38 | binary = PACKER.unpack(binary, binary_shape, dtype=self.data.dtype) 39 | binary = binary.to(self.data.device) 40 | 41 | # w.shape = [out_ch, in_ch] 42 | # in_ch_wise == True 43 | # -> binary.shape = [in_ch, out_ch//group_size, group_size, qbits] 44 | # -> alpha.shape = [in_ch, out_ch//group_size, qbits] 45 | # -> offset.shape = [in_ch, out_ch//group_size, 1] 46 | # in_ch_wise == False 47 | # -> binary.shape = [out_ch, in_ch//group_size, group_size, qbits] 48 | # -> alpha.shape = [out_ch, in_ch//group_size, qbits] 49 | # -> offset.shape = [out_ch, in_ch//group_size, 1] 50 | 51 | if in_ch_wise == True: 52 | out_ch = binary_shape[1] * binary_shape[2] 53 | decomp_w = torch.einsum('iogb,iob->iog', (binary, alpha)) 54 | if offset is not None: 55 | decomp_w = decomp_w + offset 56 | decomp_w = decomp_w.reshape([-1, out_ch]).T 57 | else: 58 | out_ch = binary_shape[0] 59 | decomp_w = torch.einsum('oigb,oib->oig', (binary, alpha)) 60 | if offset is not None: 61 | decomp_w = decomp_w + offset 62 | decomp_w = decomp_w.reshape([out_ch, -1]) 63 | self.data = decomp_w 64 | 65 | class BCQTunedParameter: 66 | def __init__(self, alpha, binary, binary_shape, do_packing=False, in_ch_wise=False): 67 | #self.alpha = nn.Parameter( 68 | pass 69 | 70 | if __name__ == '__main__': 
71 | w_org = torch.randn(1024, 256) 72 | 73 | w_bcq = BCQParameter(w_org) 74 | alpha, binary, binary_shape = w_bcq.compress(do_packing=False, in_ch_wise=True, qbits=4, rounds=15, group_size=128) 75 | w_bcq.decompress(alpha, binary, binary_shape, do_packing=False, in_ch_wise=True) 76 | print(abs(w_org-w_bcq.data).mean()) 77 | 78 | w_bcq = BCQParameter(w_org) 79 | alpha, binary, binary_shape = w_bcq.compress(do_packing=False, in_ch_wise=False, qbits=4, rounds=15, group_size=128) 80 | w_bcq.decompress(alpha, binary, binary_shape, do_packing=False, in_ch_wise=False) 81 | print(abs(w_org-w_bcq.data).mean()) 82 | -------------------------------------------------------------------------------- /examples/quant_model_bcq.py: -------------------------------------------------------------------------------- 1 | # LUT-GEMM 2 | # Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import torch 17 | import transformers 18 | import argparse 19 | 20 | from transformers import ( 21 | AutoTokenizer, 22 | AutoModelForCausalLM, 23 | set_seed, 24 | Seq2SeqTrainer, 25 | GPTQConfig, 26 | LlamaTokenizer 27 | 28 | ) 29 | 30 | from rtn_parameter import RTNParameter 31 | from bcq_parameter import BCQParameter 32 | 33 | 34 | layers = ["q_proj","k_proj","v_proj","out_proj","fc1","fc2"] 35 | 36 | def parse_args(): 37 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") 38 | parser.add_argument( 39 | "--model_name_or_path", 40 | type=str, 41 | default='facebook/opt-125m', 42 | help="Path to pretrained model or model identifier from huggingface.co/models.", 43 | ) 44 | parser.add_argument( 45 | "--cache_dir", 46 | type=str, 47 | default=None, 48 | help="Pretrained config name or path if not the same as model_name", 49 | ) 50 | parser.add_argument( 51 | "--qbits", 52 | type=int, 53 | default=4, 54 | help="Quantization Bits.", 55 | ) 56 | parser.add_argument( 57 | "--group_size", 58 | type=int, 59 | default=128, 60 | help="quantization grouping size for weights", 61 | ) 62 | args = parser.parse_args() 63 | 64 | return args 65 | 66 | #def quant_model(model, module_to_not_convert:str = "lm_head"): 67 | def quant_model(model, args): 68 | for name, module in model.named_children(): 69 | if len(list(module.children())) > 0: 70 | quant_model(module, args) 71 | 72 | if any(x in name for x in layers): 73 | print(name) 74 | original_weight = module.weight.clone().detach() 75 | # INT4 Quantization -> BCQ 76 | w_bcq = BCQParameter(original_weight) 77 | alpha, binary, binary_shape = w_bcq.compress( 78 | do_packing=False, in_ch_wise=True, qbits=args.qbits, 79 | rounds=15, group_size=args.group_size) 80 | 81 | print(alpha.size()) 82 | print(binary.size()) 83 | print("="*30) 84 | 85 | return model 86 | 87 | def main(): 88 | args = parse_args() 89 | 90 | model = AutoModelForCausalLM.from_pretrained( 91 | 
args.model_name_or_path, 92 | cache_dir=args.cache_dir, 93 | ) 94 | model = quant_model(model, args) 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /lutGEMM/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | # project(lutGEMM) 4 | project(lutGEMM LANGUAGES CXX CUDA) 5 | enable_language(CUDA) 6 | 7 | add_library(${PROJECT_NAME} INTERFACE) 8 | 9 | 10 | add_library(lutgemm STATIC src/nQWeight_fp16.cu src/kernels.cu) 11 | set_property(TARGET lutgemm PROPERTY POSITION_INDEPENDENT_CODE ON) 12 | set_property(TARGET lutgemm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 13 | target_link_libraries(lutgemm PUBLIC -lcublas -lcudart -lcurand -lgomp) 14 | 15 | 16 | if(NOT DEFINED ${CMAKE_CUDA_ARCHITECTURES}) 17 | set(CMAKE_CUDA_ARCHITECTURES 80) 18 | endif() 19 | 20 | target_include_directories(${PROJECT_NAME} 21 | INTERFACE 22 | include 23 | ) 24 | -------------------------------------------------------------------------------- /lutGEMM/include/kernels.h: -------------------------------------------------------------------------------- 1 | /* LUT-GEMM 2 | * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | 19 | #ifndef KERNELS_H 20 | #define KERNELS_H 21 | 22 | #include "nQWeight_fp16.h" 23 | 24 | namespace lutGEMM{ 25 | 26 | void matmul(void* output, nQWeight_fp16 &nqW, void* input, int n, int algo=0); 27 | void matmul(void* output, void* input, nQWeight_fp16 &nqW, int m, int algo=0); 28 | void matmul_gptq( 29 | int m, int n, int k, void *scale, void *bias, 30 | void *A, void *B, void *C); 31 | void matmul_gptq_faster( 32 | int m, int n, int k, void *scale, void *bias, 33 | void *A, void *B, void *C); 34 | } 35 | 36 | #endif 37 | 38 | -------------------------------------------------------------------------------- /lutGEMM/include/lutGEMM: -------------------------------------------------------------------------------- 1 | /* LUT-GEMM 2 | * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "nQWeight_fp16.h" 18 | #include "kernels.h" 19 | -------------------------------------------------------------------------------- /lutGEMM/include/nQWeight_fp16.h: -------------------------------------------------------------------------------- 1 | /* LUT-GEMM 2 | * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef N_Q_WEIGHT_FP16_H 18 | #define N_Q_WEIGHT_FP16_H 19 | 20 | namespace lutGEMM{ 21 | 22 | class nQWeight_fp16; 23 | 24 | void dequantize_gpu(nQWeight_fp16 &nqw, void *d_fW, int algo=0); 25 | void dequantize_cpu(nQWeight_fp16 &nqw, void *fW); 26 | 27 | class nQWeight_fp16{ 28 | public: 29 | unsigned int* bWeight; // Weight[kSize/32][nb][mSize] 30 | void* alpha; // alpha[num_alpha_groups][nb][mSize] 31 | void* q_bias; //q_bias[num_alpha_groups][mSize] 32 | int num_groups; 33 | int group_size; 34 | int mSize; 35 | int kSize; 36 | int nb; 37 | bool is_row_wise_quantize; 38 | nQWeight_fp16() {} 39 | 40 | /* uint32 bW[kSize/32][nb][mSize] alpha[num_alpha_groups][mSize][nb] */ 41 | nQWeight_fp16(unsigned int *bW, float *A, int row, int col, int num_bits, 42 | bool is_row_wise_quantize, int num_alpha_groups=1, float* q_bias=nullptr){ 43 | parsing(bW, A, row, col, num_bits, is_row_wise_quantize, num_alpha_groups, q_bias); 44 | } 45 | 46 | void parsing(unsigned int *bW, float *A, int row, int col, int num_bits, 47 | bool is_row_wise_quantize, int num_alpha_groups=1, float* q_bias=nullptr); 48 | 49 | ~nQWeight_fp16(); 50 | 51 | void* getDequantiedWeight(bool onGPU=true); 52 | }; 53 | 54 | } 55 | #endif // N_Q_WEIGHT_FP16_H 56 | -------------------------------------------------------------------------------- /lutGEMM/src/cuda/kernels/cublas.cu: -------------------------------------------------------------------------------- 1 | /* LUT-GEMM 2 | * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | namespace kernel{ 24 | 25 | template 26 | inline cublasStatus_t cublas_gemm_ex(T *A, T *B, S *C, 27 | int m, int n, int k); 28 | 29 | 30 | 31 | template 32 | inline cublasStatus_t cublas_gemm_ex(T *A, T *B, S *C, 33 | int m, int n, int k) { 34 | static S alpha = 1; 35 | static S beta = 0; 36 | static cublasHandle_t handle = nullptr; 37 | if(handle == nullptr) cublasCreate(&handle); 38 | // NOTE: one handle per template instantiation, created lazily; the null check above is not synchronized across threads. 39 | cudaDataType_t AType, BType, CType; 40 | cublasComputeType_t ComputeType; 41 | if (std::is_same::value) { 42 | AType = BType = CType = CUDA_R_32F; 43 | ComputeType = CUBLAS_COMPUTE_32F_FAST_TF32; 44 | } else if (std::is_same::value) { 45 | AType = BType = CType = CUDA_R_16F; 46 | ComputeType = CUBLAS_COMPUTE_16F; 47 | } else if (std::is_same::value) { 48 | AType = BType = CUDA_R_8I; 49 | CType = CUDA_R_32I; 50 | ComputeType = CUBLAS_COMPUTE_32I; 51 | } else { 52 | printf("Not supported data type."); 53 | return CUBLAS_STATUS_NOT_SUPPORTED; 54 | } // cuBLAS is column-major: passing (B, A) with dims (n, m, k) computes row-major C[m x n] = A[m x k] * B[k x n]. 55 | return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 56 | n, m, k, 57 | &alpha, 58 | B, BType, n, 59 | A, AType, k, 60 | &beta, 61 | C, CType, n, 62 | ComputeType, 63 | CUBLAS_GEMM_DFALT); 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- 
/lutGEMM/src/cuda/kernels/cublas.h: -------------------------------------------------------------------------------- 1 | /* LUT-GEMM 2 | * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | namespace kernel{ 24 | template 25 | inline cublasStatus_t cublas_gemm_ex(T *A, T *B, S *C, 26 | int m, int n, int k) { 27 | static S alpha = 1; 28 | static S beta = 0; 29 | static cublasHandle_t handle = nullptr; 30 | if(handle == nullptr) cublasCreate(&handle); 31 | // NOTE: one handle per template instantiation, created lazily; the null check above is not synchronized across threads. 32 | cudaDataType_t AType, BType, CType; 33 | cublasComputeType_t ComputeType; 34 | if (std::is_same::value) { 35 | AType = BType = CType = CUDA_R_32F; 36 | ComputeType = CUBLAS_COMPUTE_32F_FAST_TF32; 37 | } else if (std::is_same::value) { 38 | AType = BType = CType = CUDA_R_16F; 39 | ComputeType = CUBLAS_COMPUTE_16F; 40 | } else if (std::is_same::value) { 41 | AType = BType = CUDA_R_8I; 42 | CType = CUDA_R_32I; 43 | ComputeType = CUBLAS_COMPUTE_32I; 44 | } else { 45 | return CUBLAS_STATUS_NOT_SUPPORTED; 46 | } // cuBLAS is column-major: passing (B, A) with dims (n, m, k) computes row-major C[m x n] = A[m x k] * B[k x n]. 47 | return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 48 | n, m, k, 49 | &alpha, 50 | B, BType, n, 51 | A, AType, k, 52 | &beta, 53 | C, CType, n, 54 | ComputeType, 55 | CUBLAS_GEMM_DFALT); 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- 
/lutGEMM/src/cuda/kernels/mm_t.hpp: -------------------------------------------------------------------------------- 1 | /* LUT-GEMM 2 | * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef KERNELS_MM_T_HPP 18 | #define KERNELS_MM_T_HPP 19 | 20 | namespace kernel{ 21 | // _nqmm_t: LUT-based quantized matmul in float. Phase 1 builds a shared-memory table holding, for every group of 8 inputs in this block's K tile, all 256 signed sums (each input added or subtracted); phase 2 sums alpha-scaled table lookups selected by the packed weight bytes and atomicAdds the result into output. 22 | // Template parameters used by the body: M_TILE_SIZE, N_TILE_SIZE, K_TILE_SIZE, NUM_BITS. NOTE(review): assumes blockDim.x <= N_TILE_SIZE and that blockDim.y / (K_TILE_SIZE/8) is a power of two <= 256 (otherwise the s selection below falls through to 8 and the table is not fully built) - TODO confirm at the launch site. 23 | template 24 | __global__ void _nqmm_t(uint32_t *W, float *alpha, float *input, float *output, int M, int N, int K){ 25 | // input is read as row lut_z of length K and output is written as row lut_z of length M; only threadIdx.x selects that row, so one launch covers blockDim.x rows. N is not read by the body. 26 | __shared__ float lut[K_TILE_SIZE/8][256][N_TILE_SIZE]; 27 | 28 | const int lut_y_size = K_TILE_SIZE/8; 29 | const int lut_x_size = blockDim.y / (K_TILE_SIZE/8); 30 | 31 | int lut_y = threadIdx.y/lut_x_size; 32 | int lut_x = threadIdx.y%lut_x_size; 33 | int lut_z = threadIdx.x; 34 | // Seed entry lut_x: input j is added when bit j of lut_x is set and subtracted otherwise. 
35 | float *_inp = &input[lut_z*K + (blockIdx.y * K_TILE_SIZE + lut_y * 8) ]; 36 | float base = + (2 * ((lut_x>>0) & 1) - 1) * _inp[0] 37 | + (2 * ((lut_x>>1) & 1) - 1) * _inp[1] 38 | + (2 * ((lut_x>>2) & 1) - 1) * _inp[2] 39 | + (2 * ((lut_x>>3) & 1) - 1) * _inp[3] 40 | + (2 * ((lut_x>>4) & 1) - 1) * _inp[4] 41 | + (2 * ((lut_x>>5) & 1) - 1) * _inp[5] 42 | + (2 * ((lut_x>>6) & 1) - 1) * _inp[6] 43 | + (2 * ((lut_x>>7) & 1) - 1) * _inp[7] ; 44 | 45 | lut[lut_y][lut_x][lut_z] = base; 46 | // Fill the remaining entries: setting bit s adds 2*_inp[s] to the entry without bit s. When lut_x_size is a power of two each thread only reads entries it wrote itself, so no barrier is needed inside this loop. 47 | int s = (lut_x_size==1) ?0: 48 | (lut_x_size==2) ?1: 49 | (lut_x_size==4) ?2: 50 | (lut_x_size==8) ?3: 51 | (lut_x_size==16) ?4: 52 | (lut_x_size==32) ?5: 53 | (lut_x_size==64) 
?6: 54 | (lut_x_size==128)?7: 8; 55 | for(;s<8;s++){ 56 | float iValue = 2*_inp[s]; 57 | for (int i = (1 << s); i < (1 << (s + 1)); i += lut_x_size) { 58 | lut[lut_y][i + lut_x][lut_z] = lut[lut_y][i + lut_x - (1 << s)][lut_z] + iValue; 59 | } 60 | } 61 | __syncthreads(); 62 | // Phase 2: rows m of this M tile are distributed over threadIdx.y with stride blockDim.y; each thread accumulates for its own lut_z. 63 | int m_start = blockIdx.x * M_TILE_SIZE + threadIdx.y; 64 | int m_end = (blockIdx.x + 1) * M_TILE_SIZE; 65 | m_end = (m_end < M) ? m_end : M; 66 | int m_step = blockDim.y; 67 | // W is indexed [K/32][NUM_BITS][M] and alpha [NUM_BITS][M]; byte j of a weight word is an 8-bit index into the table for the j-th group of 8 inputs within that 32-wide slice. 68 | uint32_t *bW = &W[blockIdx.y * K_TILE_SIZE/32 * NUM_BITS * M]; 69 | float *_output = &output[lut_z * M]; 70 | for(int m = m_start;m < m_end;m += m_step){ 71 | float reg_o = 0; 72 | for(int b=0;b < NUM_BITS;b++){ 73 | float reg_a = alpha[b * M + m]; 74 | float reg_t_o = 0; 75 | for(int kt=0;kt < K_TILE_SIZE/32;kt++){ 76 | uint32_t reg_w = bW[kt * NUM_BITS * M + b * M + m]; 77 | int reg_w0 = (reg_w >> 8 * 0) & 255; reg_t_o += + lut[kt*4 + 0][reg_w0][lut_z]; 78 | int reg_w1 = (reg_w >> 8 * 1) & 255; reg_t_o += + lut[kt*4 + 1][reg_w1][lut_z]; 79 | int reg_w2 = (reg_w >> 8 * 2) & 255; reg_t_o += + lut[kt*4 + 2][reg_w2][lut_z]; 80 | int reg_w3 = (reg_w >> 8 * 3) & 255; reg_t_o += + lut[kt*4 + 3][reg_w3][lut_z]; 81 | } 82 | reg_o += reg_a * reg_t_o; 83 | } 84 | atomicAdd(&_output[m], reg_o); 85 | } 86 | } 87 | 88 | 89 | 90 | } 91 | 92 | #endif //KERNELS_MM_T_HPP -------------------------------------------------------------------------------- /lutGEMM/src/cuda/tmpWeight.hpp: -------------------------------------------------------------------------------- 1 | /* LUT-GEMM 2 | * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef TMP_WEIGHT_HPP 18 | #define TMP_WEIGHT_HPP 19 | 20 | // tmpWeight: process-wide singleton owning one reusable CUDA managed-memory float buffer that grows on demand and is never shrunk. No locking is performed. 21 | class tmpWeight{ 22 | public: 23 | static tmpWeight& getInstance(){ 24 | static tmpWeight ins; 25 | return ins; 26 | } 27 | // Returns a buffer of at least Size floats. When Size exceeds the current capacity the old buffer is freed (contents discarded) and a new one is allocated; returns nullptr if that allocation fails. 28 | float* getWeight(int Size){ 29 | if(Size > size){ 30 | mem_free(); 31 | if(cudaMallocManaged(&mem, sizeof(float) * Size) != cudaSuccess){ mem = nullptr; return nullptr; } 32 | size = Size; // record the new capacity only after the allocation succeeded 33 | } 34 | return mem; 35 | } 36 | 37 | private: 38 | void mem_free(){ // release the buffer and reset state so no dangling pointer or stale capacity survives 39 | if(mem != nullptr) 40 | cudaFree(mem); 41 | mem = nullptr; size = 0; } 42 | float *mem = nullptr; 43 | int size = 0; 44 | 45 | tmpWeight(const tmpWeight&) = delete; 46 | tmpWeight& operator=(const tmpWeight&) = delete; 47 | tmpWeight(/* args */){ } 48 | ~tmpWeight(){ 49 | mem_free(); 50 | } 51 | }; 52 | 53 | 54 | #endif // TMP_WEIGHT_HPP -------------------------------------------------------------------------------- /lutGEMM/src/kernels.cu: -------------------------------------------------------------------------------- 1 | /* LUT-GEMM 2 | * Copyright (c) 2024-present NAVER Cloud Corp. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "../include/kernels.h" 18 | #include 19 | #include 20 | #include 21 | 22 | namespace lutGEMM{ 23 | 24 | #include "../src/cuda/kernels/cublas.h" 25 | #include "../src/cuda/kernels/mv_fp16.hpp" 26 | #include "../src/cuda/kernels/mv_fp16_bias.hpp" 27 | #include "../src/cuda/kernels/gptq_fp16_bias.hpp" 28 | #include "../src/cuda/kernels/gptq_faster_fp16_bias.hpp" 29 | 30 | void matmul(void* output, nQWeight_fp16 &nqW, void* input, int n, int algo); 31 | void matmul(void* output, void* input, nQWeight_fp16 &nqW, int m, int algo); 32 | 33 | // The cuBLAS fallbacks declared below are defined at the end of this file. 34 | 35 | /* float16 */ 36 | inline void matmul_useCublas(__half* output, nQWeight_fp16 &nqW, __half* input, int n); 37 | inline void matmul_useCublas(__half* output, __half* input, nQWeight_fp16 &nqW, int m); 38 | /************************** float16 ***********************/ 39 | // matmul_gptq: zero-fills C (m*n halves) and launches kernel::gptq. NOTE(review): m is used only for the memset size and is not passed to the kernel - presumably only m == 1 is supported; confirm. 40 | void matmul_gptq( 41 | int m, int n, int k, void *scale, void *bias, 42 | void *A, void *B, void *C){ 43 | cudaMemset(C, 0, sizeof(__half) * m * n); 44 | kernel::gptq(n, k, (__half*)scale, (__half*)bias, 45 | (__half*)A, (uint32_t*)B, (__half*)C); 46 | } 47 | // matmul_gptq_faster: like matmul_gptq but reinterprets A as half2* for kernel::gptq_faster; m is likewise used only for the memset size. 48 | void matmul_gptq_faster( 49 | int m, int n, int k, void *scale, void *bias, 50 | void *A, void *B, void *C){ 51 | cudaMemset(C, 0, sizeof(__half) * m * n); 52 | kernel::gptq_faster(n, k, (__half*)scale, (__half*)bias, 53 | (half2*)A, (uint32_t*)B, (__half*)C); 54 | } 55 | // Weight-first matmul: for n == 1, zero-fills output (nqW.mSize halves) and runs the LUT GEMV kernel, using the bias variant when nqW.q_bias is set; any other n falls back to cuBLAS on the dequantized weight. 56 | void matmul(void* output, nQWeight_fp16 &nqW, void* input, int n, int algo){ 57 | if(n==1){ 58 | cudaMemset(output, 0, sizeof(__half) * nqW.mSize); // 0.007ms 0.04 59 | if(nqW.q_bias == nullptr) kernel::nqmv((__half*)output, nqW, (__half*)input, algo); 60 | else kernel::nqmv_bias((__half*)output, nqW, (__half*)input, algo); 61 | } 62 | else matmul_useCublas((__half*)output, nqW, (__half*)input, n); 63 | } // Input-first matmul (next): same dispatch keyed on m. NOTE(review): its m == 1 branch calls nqmv/nqmv_bias with exactly the same arguments as the weight-first overload - confirm the operand order is intended. 64 | void matmul(void* output, void* input, nQWeight_fp16 &nqW, int m, int 
algo){ 65 | if(m==1){ 66 | cudaMemset(output, 0, sizeof(__half) * nqW.mSize); 67 | if(nqW.q_bias == nullptr) kernel::nqmv((__half*)output, nqW, (__half*)input, algo); 68 | else kernel::nqmv_bias((__half*)output, nqW, (__half*)input, algo); 69 | } 70 | else matmul_useCublas((__half*)output, (__half*)input, nqW, m); 71 | } 72 | // cuBLAS fallback (weight-first): dequantizes nqW via getDequantiedWeight(true) and passes it as the first operand to kernel::cublas_gemm_ex. 73 | inline void matmul_useCublas(__half* output, nQWeight_fp16 &nqW, __half* input, int n) { 74 | kernel::cublas_gemm_ex((__half*)nqW.getDequantiedWeight(true), input, output, nqW.mSize, n, nqW.kSize); 75 | } 76 | // cuBLAS fallback (input-first): same, but input is the first operand and the dequantized weight the second. 77 | inline void matmul_useCublas(__half* output, __half* input, nQWeight_fp16 &nqW, int m) { 78 | kernel::cublas_gemm_ex(input, (__half*)nqW.getDequantiedWeight(true), output, m, nqW.mSize, nqW.kSize); 79 | } 80 | 81 | } 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18 FATAL_ERROR) 2 | project(tests LANGUAGES CXX CUDA) 3 | enable_language(CUDA) 4 | 5 | # set(CMAKE_CXX_STANDARD 14) 6 | # set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | # set(CMAKE_CXX_EXTENSIONS OFF) 8 | # add_library(my_lib STATIC opt/_cublas.cu) 9 | # set_property(TARGET my_lib PROPERTY POSITION_INDEPENDENT_CODE ON) 10 | # set_property(TARGET my_lib PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) 11 | # target_link_libraries(my_lib PUBLIC -lcublas -lcudart -lcurand -lgomp) 12 | # target_include_directories(my_lib 13 | # PUBLIC 14 | # include 15 | # gtest 16 | # lutGEMM 17 | 18 | # ) 19 | add_executable(tests 20 | src/custom_random.cpp 21 | main.cc 22 | 23 | opt/fp16/int3_col_wise_matmul_fp16.cu 24 | 25 | ) 26 | set_target_properties(tests PROPERTIES POSITION_INDEPENDENT_CODE ON) 27 | set_target_properties(tests PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) 28 | set_target_properties(tests PROPERTIES LINKER_LANGUAGE CXX) 29 | target_link_libraries(tests -lcublasLt -lcublas -lcurand -lcudart
-lcuda -lgomp lutgemm) 30 | 31 | # set(CMAKE_CXX_FLAGS_DEBUG "-O3 -fopenmp") 32 | 33 | 34 | 35 | set(CPACK_PROJECT_NAME ${PROJECT_NAME}) 36 | set(CPACK_PROJECT_VERSION ${PROJECT_VERSION}) 37 | include(CPack) 38 | 39 | target_include_directories(${PROJECT_NAME} 40 | PUBLIC 41 | include 42 | gtest 43 | lutGEMM 44 | 45 | ) 46 | 47 | target_link_libraries(${PROJECT_NAME} 48 | # PRIVATE 49 | gtest 50 | lutGEMM 51 | ) -------------------------------------------------------------------------------- /tests/include/_cublas.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | void test_cuBlas(int m, int n, int k, bool cmp_check, int iter = 128); 3 | -------------------------------------------------------------------------------- /tests/include/custom_random.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #pragma once 4 | void random_seed(); 5 | bool rand_bool(); 6 | double rand_fp64(double max=1.0); 7 | float rand_fp32(float max=1.0); -------------------------------------------------------------------------------- /tests/include/tests.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #pragma once 7 | /* submodule */ 8 | #include "gtest/gtest.h" 9 | 10 | /* custom module */ 11 | #include "custom_random.h" 12 | #include "timer.h" 13 | 14 | #include 15 | 16 | #include "lutGEMM" 17 | 18 | #ifndef GTEST_PIRNTF 19 | #define GTEST_PIRNTF(...){\ 20 | printf("\033[32m[ ]");\ 21 | printf("\033[0m ");\ 22 | printf(__VA_ARGS__);\ 23 | printf("\n");\ 24 | } 25 | #endif 26 | -------------------------------------------------------------------------------- /tests/include/timer.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #pragma once 5 | class timer { 6 | public: 7 | std::vector arr; 8 | bool sort_flag = false; 9 | double s; 10 |
void start(){ 11 | s = omp_get_wtime(); 12 | } 13 | double end(){ 14 | double l = (omp_get_wtime() - s) * 1000.0; 15 | arr.push_back(l); 16 | sort_flag = false; 17 | return l; 18 | } 19 | 20 | double mean(){ 21 | double sum=0; 22 | for(auto it : arr) 23 | sum += it; 24 | return sum/arr.size(); 25 | } 26 | 27 | void sort(){ 28 | if(sort_flag) return; 29 | std::sort(arr.begin(), arr.end()); 30 | sort_flag = true; 31 | } 32 | 33 | 34 | double pile(float p){ 35 | sort(); 36 | int idx = (arr.size() - 1) * p; 37 | return arr[idx]; 38 | } 39 | 40 | double max(){ 41 | sort(); 42 | return arr[arr.size() - 1]; 43 | } 44 | double min(){ 45 | sort(); 46 | return arr[0]; 47 | } 48 | 49 | timer(/* args */){} 50 | ~timer(){} 51 | }; 52 | 53 | -------------------------------------------------------------------------------- /tests/main.cc: -------------------------------------------------------------------------------- 1 | #include "tests.h" 2 | 3 | int main(int argc, char **argv) { 4 | ::testing::InitGoogleTest(&argc, argv); 5 | return RUN_ALL_TESTS(); 6 | } 7 | -------------------------------------------------------------------------------- /tests/opt/_cublas.cc: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "_cublas.h" 3 | 4 | TEST(tests, cublas_att){ 5 | test_cuBlas(32, 1, 32, true, 0); 6 | EXPECT_TRUE(true); 7 | } 8 | const int hidden_size = 8192; 9 | const int num_gpu = 1; 10 | TEST(tests, cublas_ffn){ 11 | test_cuBlas(251*128, 1024, 1024, true, 128); 12 | 13 | EXPECT_TRUE(true); 14 | } 15 | -------------------------------------------------------------------------------- /tests/src/custom_random.cpp: -------------------------------------------------------------------------------- 1 | #include "custom_random.h" 2 | 3 | void random_seed(){ 4 | time_t t; 5 | srand((unsigned int)time(&t)); 6 | } 7 | bool rand_bool(){ 8 | return rand()>(RAND_MAX/2); 9 | } 10 | double rand_fp64(double max){ 11 | double 
sign[] = {-1.0,1.0}; 12 | return (double)sign[rand_bool()]*rand()/RAND_MAX*rand()/RAND_MAX*max; 13 | } 14 | 15 | float rand_fp32(float max){ 16 | return rand_fp64()*max; 17 | } -------------------------------------------------------------------------------- /thirdparty/googletest/.clang-format: -------------------------------------------------------------------------------- 1 | # Run manually to reformat a file: 2 | # clang-format -i --style=file 3 | Language: Cpp 4 | BasedOnStyle: Google 5 | -------------------------------------------------------------------------------- /thirdparty/googletest/.github/ISSUE_TEMPLATE/00-bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: 'bug' 6 | assignees: '' 7 | --- 8 | 9 | **Describe the bug** 10 | 11 | Include a clear and concise description of what the problem is, including what 12 | you expected to happen, and what actually happened. 13 | 14 | **Steps to reproduce the bug** 15 | 16 | It's important that we are able to reproduce the problem that you are 17 | experiencing. Please provide all code and relevant steps to reproduce the 18 | problem, including your `BUILD`/`CMakeLists.txt` file and build commands. Links 19 | to a GitHub branch or [godbolt.org](https://godbolt.org/) that demonstrate the 20 | problem are also helpful. 21 | 22 | **Does the bug persist in the most recent commit?** 23 | 24 | We recommend using the latest commit in the master branch in your projects. 25 | 26 | **What operating system and version are you using?** 27 | 28 | If you are using a Linux distribution please include the name and version of the 29 | distribution as well. 30 | 31 | **What compiler and version are you using?** 32 | 33 | Please include the output of `gcc -v` or `clang -v`, or the equivalent for your 34 | compiler. 
35 | 36 | **What build system are you using?** 37 | 38 | Please include the output of `bazel --version` or `cmake --version`, or the 39 | equivalent for your build system. 40 | 41 | **Additional context** 42 | 43 | Add any other context about the problem here. 44 | -------------------------------------------------------------------------------- /thirdparty/googletest/.github/ISSUE_TEMPLATE/10-feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Propose a new feature 4 | title: '' 5 | labels: 'enhancement' 6 | assignees: '' 7 | --- 8 | 9 | **Does the feature exist in the most recent commit?** 10 | 11 | We recommend using the latest commit from GitHub in your projects. 12 | 13 | **Why do we need this feature?** 14 | 15 | Ideally, explain why a combination of existing features cannot be used instead. 16 | 17 | **Describe the proposal** 18 | 19 | Include a detailed description of the feature, with usage examples. 20 | 21 | **Is the feature specific to an operating system, compiler, or build system version?** 22 | 23 | If it is, please specify which versions. 24 | 25 | -------------------------------------------------------------------------------- /thirdparty/googletest/.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | -------------------------------------------------------------------------------- /thirdparty/googletest/.github/workflows/gtest-ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | Linux: 9 | runs-on: ubuntu-latest 10 | steps: 11 | 12 | - uses: actions/checkout@v2 13 | with: 14 | fetch-depth: 0 15 | 16 | - name: Tests 17 | run: bazel test --test_output=errors //... 
18 | 19 | MacOs: 20 | runs-on: macos-latest 21 | steps: 22 | 23 | - uses: actions/checkout@v2 24 | with: 25 | fetch-depth: 0 26 | 27 | - name: Tests 28 | run: bazel test --test_output=errors //... 29 | 30 | 31 | Windows: 32 | runs-on: windows-latest 33 | steps: 34 | 35 | - uses: actions/checkout@v2 36 | with: 37 | fetch-depth: 0 38 | 39 | - name: Tests 40 | run: bazel test --test_output=errors //... 41 | -------------------------------------------------------------------------------- /thirdparty/googletest/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore CI build directory 2 | build/ 3 | xcuserdata 4 | cmake-build-debug/ 5 | .idea/ 6 | bazel-bin 7 | bazel-genfiles 8 | bazel-googletest 9 | bazel-out 10 | bazel-testlogs 11 | # python 12 | *.pyc 13 | 14 | # Visual Studio files 15 | .vs 16 | *.sdf 17 | *.opensdf 18 | *.VC.opendb 19 | *.suo 20 | *.user 21 | _ReSharper.Caches/ 22 | Win32-Debug/ 23 | Win32-Release/ 24 | x64-Debug/ 25 | x64-Release/ 26 | 27 | # Ignore autoconf / automake files 28 | Makefile.in 29 | aclocal.m4 30 | configure 31 | build-aux/ 32 | autom4te.cache/ 33 | googletest/m4/libtool.m4 34 | googletest/m4/ltoptions.m4 35 | googletest/m4/ltsugar.m4 36 | googletest/m4/ltversion.m4 37 | googletest/m4/lt~obsolete.m4 38 | googlemock/m4 39 | 40 | # Ignore generated directories. 41 | googlemock/fused-src/ 42 | googletest/fused-src/ 43 | 44 | # macOS files 45 | .DS_Store 46 | googletest/.DS_Store 47 | googletest/xcode/.DS_Store 48 | 49 | # Ignore cmake generated directories and files. 
50 | CMakeFiles 51 | CTestTestfile.cmake 52 | Makefile 53 | cmake_install.cmake 54 | googlemock/CMakeFiles 55 | googlemock/CTestTestfile.cmake 56 | googlemock/Makefile 57 | googlemock/cmake_install.cmake 58 | googlemock/gtest 59 | /bin 60 | /googlemock/gmock.dir 61 | /googlemock/gmock_main.dir 62 | /googlemock/RUN_TESTS.vcxproj.filters 63 | /googlemock/RUN_TESTS.vcxproj 64 | /googlemock/INSTALL.vcxproj.filters 65 | /googlemock/INSTALL.vcxproj 66 | /googlemock/gmock_main.vcxproj.filters 67 | /googlemock/gmock_main.vcxproj 68 | /googlemock/gmock.vcxproj.filters 69 | /googlemock/gmock.vcxproj 70 | /googlemock/gmock.sln 71 | /googlemock/ALL_BUILD.vcxproj.filters 72 | /googlemock/ALL_BUILD.vcxproj 73 | /lib 74 | /Win32 75 | /ZERO_CHECK.vcxproj.filters 76 | /ZERO_CHECK.vcxproj 77 | /RUN_TESTS.vcxproj.filters 78 | /RUN_TESTS.vcxproj 79 | /INSTALL.vcxproj.filters 80 | /INSTALL.vcxproj 81 | /googletest-distribution.sln 82 | /CMakeCache.txt 83 | /ALL_BUILD.vcxproj.filters 84 | /ALL_BUILD.vcxproj 85 | -------------------------------------------------------------------------------- /thirdparty/googletest/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Note: CMake support is community-based. The maintainers do not use CMake 2 | # internally. 3 | 4 | cmake_minimum_required(VERSION 3.5) 5 | 6 | if (POLICY CMP0048) 7 | cmake_policy(SET CMP0048 NEW) 8 | endif (POLICY CMP0048) 9 | 10 | project(googletest-distribution) 11 | set(GOOGLETEST_VERSION 1.11.0) 12 | 13 | if(NOT CYGWIN AND NOT MSYS AND NOT ${CMAKE_SYSTEM_NAME} STREQUAL QNX) 14 | set(CMAKE_CXX_EXTENSIONS OFF) 15 | endif() 16 | 17 | enable_testing() 18 | 19 | include(CMakeDependentOption) 20 | include(GNUInstallDirs) 21 | 22 | #Note that googlemock target already builds googletest 23 | option(BUILD_GMOCK "Builds the googlemock subproject" ON) 24 | option(INSTALL_GTEST "Enable installation of googletest. 
(Projects embedding googletest may want to turn this OFF.)" ON) 25 | 26 | if(BUILD_GMOCK) 27 | add_subdirectory( googlemock ) 28 | else() 29 | add_subdirectory( googletest ) 30 | endif() 31 | -------------------------------------------------------------------------------- /thirdparty/googletest/CONTRIBUTORS: -------------------------------------------------------------------------------- 1 | # This file contains a list of people who've made non-trivial 2 | # contribution to the Google C++ Testing Framework project. People 3 | # who commit code to the project are encouraged to add their names 4 | # here. Please keep the list sorted by first names. 5 | 6 | Ajay Joshi 7 | Balázs Dán 8 | Benoit Sigoure 9 | Bharat Mediratta 10 | Bogdan Piloca 11 | Chandler Carruth 12 | Chris Prince 13 | Chris Taylor 14 | Dan Egnor 15 | Dave MacLachlan 16 | David Anderson 17 | Dean Sturtevant 18 | Eric Roman 19 | Gene Volovich 20 | Hady Zalek 21 | Hal Burch 22 | Jeffrey Yasskin 23 | Jim Keller 24 | Joe Walnes 25 | Jon Wray 26 | Jói Sigurðsson 27 | Keir Mierle 28 | Keith Ray 29 | Kenton Varda 30 | Kostya Serebryany 31 | Krystian Kuzniarek 32 | Lev Makhlis 33 | Manuel Klimek 34 | Mario Tanev 35 | Mark Paskin 36 | Markus Heule 37 | Martijn Vels 38 | Matthew Simmons 39 | Mika Raento 40 | Mike Bland 41 | Miklós Fazekas 42 | Neal Norwitz 43 | Nermin Ozkiranartli 44 | Owen Carlsen 45 | Paneendra Ba 46 | Pasi Valminen 47 | Patrick Hanna 48 | Patrick Riley 49 | Paul Menage 50 | Peter Kaminski 51 | Piotr Kaminski 52 | Preston Jackson 53 | Rainer Klaffenboeck 54 | Russ Cox 55 | Russ Rufer 56 | Sean Mcafee 57 | Sigurður Ásgeirsson 58 | Sverre Sundsdal 59 | Takeshi Yoshino 60 | Tracy Bialik 61 | Vadim Berman 62 | Vlad Losev 63 | Wolfgang Klier 64 | Zhanyong Wan 65 | -------------------------------------------------------------------------------- /thirdparty/googletest/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2008, Google Inc. 
2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above 11 | copyright notice, this list of conditions and the following disclaimer 12 | in the documentation and/or other materials provided with the 13 | distribution. 14 | * Neither the name of Google Inc. nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /thirdparty/googletest/WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "com_google_googletest") 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | http_archive( 6 | name = "com_google_absl", 7 | sha256 = "aeba534f7307e36fe084b452299e49b97420667a8d28102cf9a0daeed340b859", 8 | strip_prefix = "abseil-cpp-7971fb358ae376e016d2d4fc9327aad95659b25e", 9 | urls = ["https://github.com/abseil/abseil-cpp/archive/7971fb358ae376e016d2d4fc9327aad95659b25e.zip"], # 2021-05-20T02:59:16Z 10 | ) 11 | 12 | http_archive( 13 | name = "rules_python", 14 | sha256 = "98b3c592faea9636ac8444bfd9de7f3fb4c60590932d6e6ac5946e3f8dbd5ff6", 15 | strip_prefix = "rules_python-ed6cc8f2c3692a6a7f013ff8bc185ba77eb9b4d2", 16 | urls = ["https://github.com/bazelbuild/rules_python/archive/ed6cc8f2c3692a6a7f013ff8bc185ba77eb9b4d2.zip"], # 2021-05-17T00:24:16Z 17 | ) 18 | -------------------------------------------------------------------------------- /thirdparty/googletest/ci/macos-presubmit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2020, Google Inc. 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are 8 | # met: 9 | # 10 | # * Redistributions of source code must retain the above copyright 11 | # notice, this list of conditions and the following disclaimer. 12 | # * Redistributions in binary form must reproduce the above 13 | # copyright notice, this list of conditions and the following disclaimer 14 | # in the documentation and/or other materials provided with the 15 | # distribution. 16 | # * Neither the name of Google Inc. 
nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | # 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | set -euox pipefail 33 | 34 | if [[ -z ${GTEST_ROOT:-} ]]; then 35 | GTEST_ROOT="$(realpath $(dirname ${0})/..)" 36 | fi 37 | 38 | # Test the CMake build 39 | for cmake_off_on in OFF ON; do 40 | BUILD_DIR=$(mktemp -d build_dir.XXXXXXXX) 41 | cd ${BUILD_DIR} 42 | time cmake ${GTEST_ROOT} \ 43 | -DCMAKE_CXX_STANDARD=11 \ 44 | -Dgtest_build_samples=ON \ 45 | -Dgtest_build_tests=ON \ 46 | -Dgmock_build_tests=ON \ 47 | -Dcxx_no_exception=${cmake_off_on} \ 48 | -Dcxx_no_rtti=${cmake_off_on} 49 | time make 50 | time ctest -j$(nproc) --output-on-failure 51 | done 52 | 53 | # Test the Bazel build 54 | 55 | # If we are running on Kokoro, check for a versioned Bazel binary. 
56 | KOKORO_GFILE_BAZEL_BIN="bazel-3.7.0-darwin-x86_64" 57 | if [[ ${KOKORO_GFILE_DIR:-} ]] && [[ -f ${KOKORO_GFILE_DIR}/${KOKORO_GFILE_BAZEL_BIN} ]]; then 58 | BAZEL_BIN="${KOKORO_GFILE_DIR}/${KOKORO_GFILE_BAZEL_BIN}" 59 | chmod +x ${BAZEL_BIN} 60 | else 61 | BAZEL_BIN="bazel" 62 | fi 63 | 64 | cd ${GTEST_ROOT} 65 | for absl in 0 1; do 66 | ${BAZEL_BIN} test ... \ 67 | --copt="-Wall" \ 68 | --copt="-Werror" \ 69 | --define="absl=${absl}" \ 70 | --keep_going \ 71 | --show_timestamps \ 72 | --test_output=errors 73 | done 74 | -------------------------------------------------------------------------------- /thirdparty/googletest/docs/_config.yml: -------------------------------------------------------------------------------- 1 | title: GoogleTest 2 | -------------------------------------------------------------------------------- /thirdparty/googletest/docs/_data/navigation.yml: -------------------------------------------------------------------------------- 1 | nav: 2 | - section: "Get Started" 3 | items: 4 | - title: "Supported Platforms" 5 | url: "/platforms.html" 6 | - title: "Quickstart: Bazel" 7 | url: "/quickstart-bazel.html" 8 | - title: "Quickstart: CMake" 9 | url: "/quickstart-cmake.html" 10 | - section: "Guides" 11 | items: 12 | - title: "GoogleTest Primer" 13 | url: "/primer.html" 14 | - title: "Advanced Topics" 15 | url: "/advanced.html" 16 | - title: "Mocking for Dummies" 17 | url: "/gmock_for_dummies.html" 18 | - title: "Mocking Cookbook" 19 | url: "/gmock_cook_book.html" 20 | - title: "Mocking Cheat Sheet" 21 | url: "/gmock_cheat_sheet.html" 22 | - section: "References" 23 | items: 24 | - title: "Testing Reference" 25 | url: "/reference/testing.html" 26 | - title: "Mocking Reference" 27 | url: "/reference/mocking.html" 28 | - title: "Assertions" 29 | url: "/reference/assertions.html" 30 | - title: "Matchers" 31 | url: "/reference/matchers.html" 32 | - title: "Actions" 33 | url: "/reference/actions.html" 34 | - title: "Testing FAQ" 35 | url: 
"/faq.html" 36 | - title: "Mocking FAQ" 37 | url: "/gmock_faq.html" 38 | - title: "Code Samples" 39 | url: "/samples.html" 40 | - title: "Using pkg-config" 41 | url: "/pkgconfig.html" 42 | - title: "Community Documentation" 43 | url: "/community_created_documentation.html" 44 | -------------------------------------------------------------------------------- /thirdparty/googletest/docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | {% seo %} 9 | 10 | 18 | 19 | 20 | 21 | 44 |
45 |
46 | {{ content }} 47 |
48 | 54 |
55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /thirdparty/googletest/docs/assets/css/style.scss: -------------------------------------------------------------------------------- 1 | --- 2 | --- 3 | 4 | @import "jekyll-theme-primer"; 5 | @import "main"; 6 | -------------------------------------------------------------------------------- /thirdparty/googletest/docs/community_created_documentation.md: -------------------------------------------------------------------------------- 1 | # Community-Created Documentation 2 | 3 | The following is a list, in no particular order, of links to documentation 4 | created by the Googletest community. 5 | 6 | * [Googlemock Insights](https://github.com/ElectricRCAircraftGuy/eRCaGuy_dotfiles/blob/master/googletest/insights.md), 7 | by [ElectricRCAircraftGuy](https://github.com/ElectricRCAircraftGuy) 8 | -------------------------------------------------------------------------------- /thirdparty/googletest/docs/index.md: -------------------------------------------------------------------------------- 1 | # GoogleTest User's Guide 2 | 3 | ## Welcome to GoogleTest! 4 | 5 | GoogleTest is Google's C++ testing and mocking framework. This user's guide has 6 | the following contents: 7 | 8 | * [GoogleTest Primer](primer.md) - Teaches you how to write simple tests using 9 | GoogleTest. Read this first if you are new to GoogleTest. 10 | * [GoogleTest Advanced](advanced.md) - Read this when you've finished the 11 | Primer and want to utilize GoogleTest to its full potential. 12 | * [GoogleTest Samples](samples.md) - Describes some GoogleTest samples. 13 | * [GoogleTest FAQ](faq.md) - Have a question? Want some tips? Check here 14 | first. 15 | * [Mocking for Dummies](gmock_for_dummies.md) - Teaches you how to create mock 16 | objects and use them in tests. 17 | * [Mocking Cookbook](gmock_cook_book.md) - Includes tips and approaches to 18 | common mocking use cases. 
19 | * [Mocking Cheat Sheet](gmock_cheat_sheet.md) - A handy reference for 20 | matchers, actions, invariants, and more. 21 | * [Mocking FAQ](gmock_faq.md) - Contains answers to some mocking-specific 22 | questions. 23 | -------------------------------------------------------------------------------- /thirdparty/googletest/docs/platforms.md: -------------------------------------------------------------------------------- 1 | # Supported Platforms 2 | 3 | GoogleTest requires a codebase and compiler compliant with the C++11 standard or 4 | newer. 5 | 6 | The GoogleTest code is officially supported on the following platforms. 7 | Operating systems or tools not listed below are community-supported. For 8 | community-supported platforms, patches that do not complicate the code may be 9 | considered. 10 | 11 | If you notice any problems on your platform, please file an issue on the 12 | [GoogleTest GitHub Issue Tracker](https://github.com/google/googletest/issues). 13 | Pull requests containing fixes are welcome! 14 | 15 | ### Operating systems 16 | 17 | * Linux 18 | * macOS 19 | * Windows 20 | 21 | ### Compilers 22 | 23 | * gcc 5.0+ 24 | * clang 5.0+ 25 | * MSVC 2015+ 26 | 27 | **macOS users:** Xcode 9.3+ provides clang 5.0+. 28 | 29 | ### Build systems 30 | 31 | * [Bazel](https://bazel.build/) 32 | * [CMake](https://cmake.org/) 33 | 34 | Bazel is the build system used by the team internally and in tests. CMake is 35 | supported on a best-effort basis and by the community. 36 | -------------------------------------------------------------------------------- /thirdparty/googletest/docs/samples.md: -------------------------------------------------------------------------------- 1 | # Googletest Samples 2 | 3 | If you're like us, you'd like to look at 4 | [googletest samples.](https://github.com/google/googletest/tree/master/googletest/samples) 5 | The sample directory has a number of well-commented samples showing how to use a 6 | variety of googletest features. 
7 | 8 | * Sample #1 shows the basic steps of using googletest to test C++ functions. 9 | * Sample #2 shows a more complex unit test for a class with multiple member 10 | functions. 11 | * Sample #3 uses a test fixture. 12 | * Sample #4 teaches you how to use googletest and `googletest.h` together to 13 | get the best of both libraries. 14 | * Sample #5 puts shared testing logic in a base test fixture, and reuses it in 15 | derived fixtures. 16 | * Sample #6 demonstrates type-parameterized tests. 17 | * Sample #7 teaches the basics of value-parameterized tests. 18 | * Sample #8 shows using `Combine()` in value-parameterized tests. 19 | * Sample #9 shows use of the listener API to modify Google Test's console 20 | output and the use of its reflection API to inspect test results. 21 | * Sample #10 shows use of the listener API to implement a primitive memory 22 | leak checker. 23 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/README.md: -------------------------------------------------------------------------------- 1 | # Googletest Mocking (gMock) Framework 2 | 3 | ### Overview 4 | 5 | Google's framework for writing and using C++ mock classes. It can help you 6 | derive better designs of your system and write better tests. 7 | 8 | It is inspired by: 9 | 10 | * [jMock](http://www.jmock.org/) 11 | * [EasyMock](http://www.easymock.org/) 12 | * [Hamcrest](http://code.google.com/p/hamcrest/) 13 | 14 | It is designed with C++'s specifics in mind. 15 | 16 | gMock: 17 | 18 | - Provides a declarative syntax for defining mocks. 19 | - Can define partial (hybrid) mocks, which are a cross of real and mock 20 | objects. 21 | - Handles functions of arbitrary types and overloaded functions. 22 | - Comes with a rich set of matchers for validating function arguments. 23 | - Uses an intuitive syntax for controlling the behavior of a mock. 
24 | - Does automatic verification of expectations (no record-and-replay needed). 25 | - Allows arbitrary (partial) ordering constraints on function calls to be 26 | expressed. 27 | - Lets a user extend it by defining new matchers and actions. 28 | - Does not use exceptions. 29 | - Is easy to learn and use. 30 | 31 | Details and examples can be found here: 32 | 33 | * [gMock for Dummies](https://google.github.io/googletest/gmock_for_dummies.html) 34 | * [Legacy gMock FAQ](https://google.github.io/googletest/gmock_faq.html) 35 | * [gMock Cookbook](https://google.github.io/googletest/gmock_cook_book.html) 36 | * [gMock Cheat Sheet](https://google.github.io/googletest/gmock_cheat_sheet.html) 37 | 38 | GoogleMock is a part of 39 | [GoogleTest C++ testing framework](http://github.com/google/googletest/) and is 40 | subject to the same requirements. 41 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/cmake/gmock.pc.in: -------------------------------------------------------------------------------- 1 | libdir=@CMAKE_INSTALL_FULL_LIBDIR@ 2 | includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ 3 | 4 | Name: gmock 5 | Description: GoogleMock (without main() function) 6 | Version: @PROJECT_VERSION@ 7 | URL: https://github.com/google/googletest 8 | Requires: gtest = @PROJECT_VERSION@ 9 | Libs: -L${libdir} -lgmock @CMAKE_THREAD_LIBS_INIT@ 10 | Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ 11 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/cmake/gmock_main.pc.in: -------------------------------------------------------------------------------- 1 | libdir=@CMAKE_INSTALL_FULL_LIBDIR@ 2 | includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ 3 | 4 | Name: gmock_main 5 | Description: GoogleMock (with main() function) 6 | Version: @PROJECT_VERSION@ 7 | URL: https://github.com/google/googletest 8 | Requires: gmock = @PROJECT_VERSION@ 9 | Libs: -L${libdir}
-lgmock_main @CMAKE_THREAD_LIBS_INIT@ 10 | Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ 11 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/docs/README.md: -------------------------------------------------------------------------------- 1 | # Content Moved 2 | 3 | We are working on updates to the GoogleTest documentation, which has moved to 4 | the top-level [docs](../../docs) directory. 5 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/include/gmock/gmock-more-matchers.h: -------------------------------------------------------------------------------- 1 | // Copyright 2013, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Google Mock - a framework for writing C++ mock classes. 32 | // 33 | // This file implements some matchers that depend on gmock-matchers.h. 34 | // 35 | // Note that tests are implemented in gmock-matchers_test.cc rather than 36 | // gmock-more-matchers-test.cc. 37 | 38 | #ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_MORE_MATCHERS_H_ 39 | #define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_MORE_MATCHERS_H_ 40 | 41 | #include "gmock/gmock-matchers.h" 42 | 43 | namespace testing { 44 | 45 | // Silence C4100 (unreferenced formal 46 | // parameter) for MSVC 47 | #ifdef _MSC_VER 48 | # pragma warning(push) 49 | # pragma warning(disable:4100) 50 | #if (_MSC_VER == 1900) 51 | // and silence C4800 (C4800: 'int *const ': forcing value 52 | // to bool 'true' or 'false') for MSVC 14 53 | # pragma warning(disable:4800) 54 | #endif 55 | #endif 56 | 57 | // Defines a matcher that matches an empty container. The container must 58 | // support both size() and empty(), which all STL-like containers provide. 59 | MATCHER(IsEmpty, negation ? "isn't empty" : "is empty") { 60 | if (arg.empty()) { 61 | return true; 62 | } 63 | *result_listener << "whose size is " << arg.size(); 64 | return false; 65 | } 66 | 67 | // Define a matcher that matches a value that evaluates in boolean 68 | // context to true.
Useful for types that define "explicit operator 69 | // bool" operators and so can't be compared for equality with true 70 | // and false. 71 | MATCHER(IsTrue, negation ? "is false" : "is true") { 72 | return static_cast<bool>(arg); 73 | } 74 | 75 | // Define a matcher that matches a value that evaluates in boolean 76 | // context to false. Useful for types that define "explicit operator 77 | // bool" operators and so can't be compared for equality with true 78 | // and false. 79 | MATCHER(IsFalse, negation ? "is true" : "is false") { 80 | return !static_cast<bool>(arg); 81 | } 82 | 83 | #ifdef _MSC_VER 84 | # pragma warning(pop) 85 | #endif 86 | 87 | 88 | } // namespace testing 89 | 90 | #endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_MORE_MATCHERS_H_ 91 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/include/gmock/internal/custom/README.md: -------------------------------------------------------------------------------- 1 | # Customization Points 2 | 3 | The custom directory is an injection point for custom user configurations.
4 | 5 | ## Header `gmock-port.h` 6 | 7 | The following macros can be defined: 8 | 9 | ### Flag related macros: 10 | 11 | * `GMOCK_DECLARE_bool_(name)` 12 | * `GMOCK_DECLARE_int32_(name)` 13 | * `GMOCK_DECLARE_string_(name)` 14 | * `GMOCK_DEFINE_bool_(name, default_val, doc)` 15 | * `GMOCK_DEFINE_int32_(name, default_val, doc)` 16 | * `GMOCK_DEFINE_string_(name, default_val, doc)` 17 | * `GMOCK_FLAG_GET(flag_name)` 18 | * `GMOCK_FLAG_SET(flag_name, value)` 19 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/include/gmock/internal/custom/gmock-generated-actions.h: -------------------------------------------------------------------------------- 1 | #ifndef GOOGLEMOCK_INCLUDE_GMOCK_INTERNAL_CUSTOM_GMOCK_GENERATED_ACTIONS_H_ 2 | #define GOOGLEMOCK_INCLUDE_GMOCK_INTERNAL_CUSTOM_GMOCK_GENERATED_ACTIONS_H_ 3 | 4 | #endif // GOOGLEMOCK_INCLUDE_GMOCK_INTERNAL_CUSTOM_GMOCK_GENERATED_ACTIONS_H_ 5 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/include/gmock/internal/custom/gmock-matchers.h: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. 
nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // Injection point for custom user configurations. See README for details 31 | 32 | #ifndef GOOGLEMOCK_INCLUDE_GMOCK_INTERNAL_CUSTOM_GMOCK_MATCHERS_H_ 33 | #define GOOGLEMOCK_INCLUDE_GMOCK_INTERNAL_CUSTOM_GMOCK_MATCHERS_H_ 34 | #endif // GOOGLEMOCK_INCLUDE_GMOCK_INTERNAL_CUSTOM_GMOCK_MATCHERS_H_ 35 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/include/gmock/internal/custom/gmock-port.h: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 
10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // Injection point for custom user configurations. See README for details 31 | // 32 | // ** Custom implementation starts here ** 33 | 34 | #ifndef GOOGLEMOCK_INCLUDE_GMOCK_INTERNAL_CUSTOM_GMOCK_PORT_H_ 35 | #define GOOGLEMOCK_INCLUDE_GMOCK_INTERNAL_CUSTOM_GMOCK_PORT_H_ 36 | 37 | #endif // GOOGLEMOCK_INCLUDE_GMOCK_INTERNAL_CUSTOM_GMOCK_PORT_H_ 38 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/src/gmock-all.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // 31 | // Google C++ Mocking Framework (Google Mock) 32 | // 33 | // This file #includes all Google Mock implementation .cc files. The 34 | // purpose is to allow a user to build Google Mock by compiling this 35 | // file alone. 36 | 37 | // This line ensures that gmock.h can be compiled on its own, even 38 | // when it's fused. 
39 | #include "gmock/gmock.h" 40 | 41 | // The following lines pull in the real gmock *.cc files. 42 | #include "src/gmock-cardinalities.cc" 43 | #include "src/gmock-internal-utils.cc" 44 | #include "src/gmock-matchers.cc" 45 | #include "src/gmock-spec-builders.cc" 46 | #include "src/gmock.cc" 47 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/src/gmock_main.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | #include 32 | #include "gmock/gmock.h" 33 | #include "gtest/gtest.h" 34 | 35 | #if GTEST_OS_ESP8266 || GTEST_OS_ESP32 36 | #if GTEST_OS_ESP8266 37 | extern "C" { 38 | #endif 39 | void setup() { 40 | // Since Google Mock depends on Google Test, InitGoogleMock() is 41 | // also responsible for initializing Google Test. Therefore there's 42 | // no need for calling testing::InitGoogleTest() separately. 43 | testing::InitGoogleMock(); 44 | } 45 | void loop() { RUN_ALL_TESTS(); } 46 | #if GTEST_OS_ESP8266 47 | } 48 | #endif 49 | 50 | #else 51 | 52 | // MS C++ compiler/linker has a bug on Windows (not on Windows CE), which 53 | // causes a link error when _tmain is defined in a static library and UNICODE 54 | // is enabled. For this reason instead of _tmain, main function is used on 55 | // Windows. 
See the following link to track the current status of this bug: 56 | // https://web.archive.org/web/20170912203238/connect.microsoft.com/VisualStudio/feedback/details/394464/wmain-link-error-in-the-static-library 57 | // // NOLINT 58 | #if GTEST_OS_WINDOWS_MOBILE 59 | # include // NOLINT 60 | 61 | GTEST_API_ int _tmain(int argc, TCHAR** argv) { 62 | #else 63 | GTEST_API_ int main(int argc, char** argv) { 64 | #endif // GTEST_OS_WINDOWS_MOBILE 65 | std::cout << "Running main() from gmock_main.cc\n"; 66 | // Since Google Mock depends on Google Test, InitGoogleMock() is 67 | // also responsible for initializing Google Test. Therefore there's 68 | // no need for calling testing::InitGoogleTest() separately. 69 | testing::InitGoogleMock(&argc, argv); 70 | return RUN_ALL_TESTS(); 71 | } 72 | #endif 73 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/test/gmock-port_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 
17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Google Mock - a framework for writing C++ mock classes. 32 | // 33 | // This file tests the internal cross-platform support utilities. 34 | 35 | #include "gmock/internal/gmock-port.h" 36 | #include "gtest/gtest.h" 37 | 38 | // NOTE: if this file is left without tests for some reason, put a dummy 39 | // test here to make references to symbols in the gtest library and avoid 40 | // 'undefined symbol' linker errors in gmock_main: 41 | 42 | TEST(DummyTest, Dummy) {} // Intentionally empty; exists only for the linkage reason in the NOTE above. 43 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/test/gmock_all_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2009, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer.
10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // 31 | // Tests for Google C++ Mocking Framework (Google Mock) 32 | // 33 | // Some users use a build system that Google Mock doesn't support directly, 34 | // yet they still want to build and run Google Mock's own tests. This file 35 | // includes most such tests, making it easier for these users to maintain 36 | // their build scripts (they just need to build this file, even though the 37 | // below list of actual *_test.cc files might change). 
38 | #include "test/gmock-actions_test.cc" 39 | #include "test/gmock-cardinalities_test.cc" 40 | #include "test/gmock-internal-utils_test.cc" 41 | #include "test/gmock-matchers_test.cc" 42 | #include "test/gmock-more-actions_test.cc" 43 | #include "test/gmock-nice-strict_test.cc" 44 | #include "test/gmock-port_test.cc" 45 | #include "test/gmock-spec-builders_test.cc" 46 | #include "test/gmock_test.cc" 47 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/test/gmock_ex_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2013, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Tests Google Mock's functionality that depends on exceptions. 32 | 33 | #include "gmock/gmock.h" 34 | #include "gtest/gtest.h" 35 | 36 | #if GTEST_HAS_EXCEPTIONS // Everything below uses try/catch, so it is compiled only with exception support. 37 | namespace { 38 | 39 | using testing::HasSubstr; 40 | 41 | using testing::internal::GoogleTestFailureException; 42 | 43 | // A type that cannot be default constructed. 44 | class NonDefaultConstructible { 45 | public: 46 | explicit NonDefaultConstructible(int /* dummy */) {} 47 | }; 48 | 49 | class MockFoo { 50 | public: 51 | // A mock method that returns a user-defined type. Google Mock 52 | // doesn't know what the default value for this type is. 53 | MOCK_METHOD0(GetNonDefaultConstructible, NonDefaultConstructible()); 54 | }; 55 | 56 | TEST(DefaultValueTest, ThrowsRuntimeErrorWhenNoDefaultValue) { 57 | MockFoo mock; 58 | try { 59 | // No expectation is set on this method, so Google Mock must 60 | // return the default value. However, since Google Mock knows 61 | // nothing about the return type, it doesn't know what to return, 62 | // and has to throw (when exceptions are enabled) or abort 63 | // (otherwise).
64 | mock.GetNonDefaultConstructible(); 65 | FAIL() << "GetNonDefaultConstructible()'s return type has no default " 66 | << "value, so Google Mock should have thrown."; 67 | } catch (const GoogleTestFailureException& /* unused */) { 68 | FAIL() << "Google Test does not try to catch an exception of type " 69 | << "GoogleTestFailureException, which is used for reporting " 70 | << "a failure to other testing frameworks. Google Mock should " 71 | << "not throw a GoogleTestFailureException as it will kill the " 72 | << "entire test program instead of just the current TEST."; 73 | } catch (const std::exception& ex) { 74 | EXPECT_THAT(ex.what(), HasSubstr("has no default value")); 75 | } 76 | } 77 | 78 | 79 | } // unnamed namespace 80 | #endif // GTEST_HAS_EXCEPTIONS 81 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/test/gmock_leak_test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2009, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission.
17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Google Mock - a framework for writing C++ mock classes. 32 | // 33 | // This program is for verifying that a leaked mock object can be 34 | // caught by Google Mock's leak detector. 35 | 36 | #include "gmock/gmock.h" 37 | 38 | namespace { 39 | 40 | using ::testing::Return; 41 | 42 | class FooInterface { 43 | public: 44 | virtual ~FooInterface() {} // Virtual because FooInterface is a polymorphic base (MockFoo derives from it). 45 | virtual void DoThis() = 0; 46 | }; 47 | 48 | class MockFoo : public FooInterface { 49 | public: 50 | MockFoo() {} 51 | 52 | MOCK_METHOD0(DoThis, void()); 53 | 54 | private: 55 | GTEST_DISALLOW_COPY_AND_ASSIGN_(MockFoo); 56 | }; 57 | 58 | TEST(LeakTest, LeakedMockWithExpectCallCausesFailureWhenLeakCheckingIsEnabled) { 59 | MockFoo* foo = new MockFoo; 60 | 61 | EXPECT_CALL(*foo, DoThis()); 62 | foo->DoThis(); 63 | 64 | // In order to test the leak detector, we deliberately leak foo. 65 | 66 | // Makes sure Google Mock's leak detector can change the exit code 67 | // to 1 even when the code is already exiting with 0.
68 | exit(0); 69 | } 70 | 71 | TEST(LeakTest, LeakedMockWithOnCallCausesFailureWhenLeakCheckingIsEnabled) { 72 | MockFoo* foo = new MockFoo; 73 | 74 | ON_CALL(*foo, DoThis()).WillByDefault(Return()); 75 | 76 | // In order to test the leak detector, we deliberately leak foo. 77 | 78 | // Makes sure Google Mock's leak detector can change the exit code 79 | // to 1 even when the code is already exiting with 0. 80 | exit(0); 81 | } 82 | 83 | TEST(LeakTest, CatchesMultipleLeakedMockObjects) { 84 | MockFoo* foo1 = new MockFoo; 85 | MockFoo* foo2 = new MockFoo; 86 | 87 | ON_CALL(*foo1, DoThis()).WillByDefault(Return()); 88 | EXPECT_CALL(*foo2, DoThis()); 89 | foo2->DoThis(); 90 | 91 | // In order to test the leak detector, we deliberately leak foo1 and 92 | // foo2. 93 | 94 | // Makes sure Google Mock's leak detector can change the exit code 95 | // to 1 even when the code is already exiting with 0. 96 | exit(0); 97 | } 98 | 99 | } // namespace 100 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/test/gmock_link2_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc.
nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Google Mock - a framework for writing C++ mock classes. 32 | // 33 | // This file is for verifying that various Google Mock constructs do not 34 | // produce linker errors when instantiated in different translation units. 35 | // Please see gmock_link_test.h for details. 36 | 37 | #define LinkTest LinkTest2 38 | 39 | #include "test/gmock_link_test.h" 40 | -------------------------------------------------------------------------------- /thirdparty/googletest/googlemock/test/gmock_link_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 
10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Google Mock - a framework for writing C++ mock classes. 32 | // 33 | // This file is for verifying that various Google Mock constructs do not 34 | // produce linker errors when instantiated in different translation units. 35 | // Please see gmock_link_test.h for details. 
36 | 37 | #define LinkTest LinkTest1 38 | 39 | #include "test/gmock_link_test.h" 40 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/cmake/Config.cmake.in: -------------------------------------------------------------------------------- 1 | @PACKAGE_INIT@ 2 | include(CMakeFindDependencyMacro) 3 | if (@GTEST_HAS_PTHREAD@) 4 | set(THREADS_PREFER_PTHREAD_FLAG @THREADS_PREFER_PTHREAD_FLAG@) 5 | find_dependency(Threads) 6 | endif() 7 | 8 | include("${CMAKE_CURRENT_LIST_DIR}/@targets_export_name@.cmake") 9 | check_required_components("@project_name@") 10 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/cmake/gtest.pc.in: -------------------------------------------------------------------------------- 1 | libdir=@CMAKE_INSTALL_FULL_LIBDIR@ 2 | includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ 3 | 4 | Name: gtest 5 | Description: GoogleTest (without main() function) 6 | Version: @PROJECT_VERSION@ 7 | URL: https://github.com/google/googletest 8 | Libs: -L${libdir} -lgtest @CMAKE_THREAD_LIBS_INIT@ 9 | Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ 10 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/cmake/gtest_main.pc.in: -------------------------------------------------------------------------------- 1 | libdir=@CMAKE_INSTALL_FULL_LIBDIR@ 2 | includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ 3 | 4 | Name: gtest_main 5 | Description: GoogleTest (with main() function) 6 | Version: @PROJECT_VERSION@ 7 | URL: https://github.com/google/googletest 8 | Requires: gtest = @PROJECT_VERSION@ 9 | Libs: -L${libdir} -lgtest_main @CMAKE_THREAD_LIBS_INIT@ 10 | Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ 11 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/cmake/libgtest.la.in: 
-------------------------------------------------------------------------------- 1 | # libgtest.la - a libtool library file 2 | # Generated by libtool (GNU libtool) 2.4.6 3 | 4 | # Please DO NOT delete this file! 5 | # It is necessary for linking the library. 6 | 7 | # Names of this library. 8 | library_names='libgtest.so' 9 | 10 | # Is this an already installed library? 11 | installed=yes 12 | 13 | # Should we warn about portability when linking against -modules? 14 | shouldnotlink=no 15 | 16 | # Files to dlopen/dlpreopen 17 | dlopen='' 18 | dlpreopen='' 19 | 20 | # Directory that this library needs to be installed in: 21 | libdir='@CMAKE_INSTALL_FULL_LIBDIR@' 22 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/docs/README.md: -------------------------------------------------------------------------------- 1 | # Content Moved 2 | 3 | We are working on updates to the GoogleTest documentation, which has moved to 4 | the top-level [docs](../../docs) directory. 5 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/include/gtest/gtest_prod.h: -------------------------------------------------------------------------------- 1 | // Copyright 2006, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. 
nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // 31 | // Google C++ Testing and Mocking Framework definitions useful in production code. 32 | 33 | #ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_PROD_H_ 34 | #define GOOGLETEST_INCLUDE_GTEST_GTEST_PROD_H_ 35 | 36 | // When you need to test the private or protected members of a class, 37 | // use the FRIEND_TEST macro to declare your tests as friends of the 38 | // class. For example: 39 | // 40 | // class MyClass { 41 | // private: 42 | // void PrivateMethod(); 43 | // FRIEND_TEST(MyClassTest, PrivateMethodWorks); 44 | // }; 45 | // 46 | // class MyClassTest : public testing::Test { 47 | // // ... 48 | // }; 49 | // 50 | // TEST_F(MyClassTest, PrivateMethodWorks) { 51 | // // Can call MyClass::PrivateMethod() here. 52 | // } 53 | // 54 | // Note: The test class must be in the same namespace as the class being tested. 55 | // For example, putting MyClassTest in an anonymous namespace will not work. 
56 | 57 | #define FRIEND_TEST(test_case_name, test_name)\ 58 | friend class test_case_name##_##test_name##_Test 59 | 60 | #endif // GOOGLETEST_INCLUDE_GTEST_GTEST_PROD_H_ 61 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/include/gtest/internal/custom/README.md: -------------------------------------------------------------------------------- 1 | # Customization Points 2 | 3 | The custom directory is an injection point for custom user configurations. 4 | 5 | ## Header `gtest.h` 6 | 7 | ### The following macros can be defined: 8 | 9 | * `GTEST_OS_STACK_TRACE_GETTER_` - The name of an implementation of 10 | `OsStackTraceGetterInterface`. 11 | * `GTEST_CUSTOM_TEMPDIR_FUNCTION_` - An override for `testing::TempDir()`. See 12 | `testing::TempDir` for semantics and signature. 13 | 14 | ## Header `gtest-port.h` 15 | 16 | The following macros can be defined: 17 | 18 | ### Flag related macros: 19 | 20 | * `GTEST_FLAG(flag_name)` 21 | * `GTEST_USE_OWN_FLAGFILE_FLAG_` - Define to 0 when the system provides its 22 | own flagfile flag parsing. 23 | * `GTEST_DECLARE_bool_(name)` 24 | * `GTEST_DECLARE_int32_(name)` 25 | * `GTEST_DECLARE_string_(name)` 26 | * `GTEST_DEFINE_bool_(name, default_val, doc)` 27 | * `GTEST_DEFINE_int32_(name, default_val, doc)` 28 | * `GTEST_DEFINE_string_(name, default_val, doc)` 29 | * `GTEST_FLAG_GET(flag_name)` 30 | * `GTEST_FLAG_SET(flag_name, value)` 31 | 32 | ### Logging: 33 | 34 | * `GTEST_LOG_(severity)` 35 | * `GTEST_CHECK_(condition)` 36 | * Functions `LogToStderr()` and `FlushInfoLog()` have to be provided too. 37 | 38 | ### Threading: 39 | 40 | * `GTEST_HAS_NOTIFICATION_` - Enabled if Notification is already provided. 41 | * `GTEST_HAS_MUTEX_AND_THREAD_LOCAL_` - Enabled if `Mutex` and `ThreadLocal` 42 | are already provided. 
Must also provide `GTEST_DECLARE_STATIC_MUTEX_(mutex)` 43 | and `GTEST_DEFINE_STATIC_MUTEX_(mutex)` 44 | * `GTEST_EXCLUSIVE_LOCK_REQUIRED_(locks)` 45 | * `GTEST_LOCK_EXCLUDED_(locks)` 46 | 47 | ### Underlying library support features 48 | 49 | * `GTEST_HAS_CXXABI_H_` 50 | 51 | ### Exporting API symbols: 52 | 53 | * `GTEST_API_` - Specifier for exported symbols. 54 | 55 | ## Header `gtest-printers.h` 56 | 57 | * See documentation at `gtest/gtest-printers.h` for details on how to define a 58 | custom printer. 59 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/include/gtest/internal/custom/gtest-port.h: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // Injection point for custom user configurations. See README for details 31 | // 32 | // ** Custom implementation starts here ** 33 | 34 | #ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PORT_H_ 35 | #define GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PORT_H_ 36 | 37 | #endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PORT_H_ 38 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/include/gtest/internal/custom/gtest-printers.h: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. 
nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // This file provides an injection point for custom printers in a local 31 | // installation of gTest. 32 | // It will be included from gtest-printers.h and the overrides in this file 33 | // will be visible to everyone. 34 | // 35 | // Injection point for custom user configurations. See README for details 36 | // 37 | // ** Custom implementation starts here ** 38 | 39 | #ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PRINTERS_H_ 40 | #define GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PRINTERS_H_ 41 | 42 | #endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PRINTERS_H_ 43 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/include/gtest/internal/custom/gtest.h: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. 2 | // All rights reserved. 
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // Injection point for custom user configurations. 
See README for details 31 | // 32 | // ** Custom implementation starts here ** 33 | 34 | #ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_H_ 35 | #define GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_H_ 36 | 37 | #endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_H_ 38 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/samples/sample1.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2005, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // A sample program demonstrating using Google C++ testing framework. 31 | 32 | #include "sample1.h" 33 | 34 | // Returns n! (the factorial of n). For negative n, n! is defined to be 1. 35 | int Factorial(int n) { 36 | int result = 1; 37 | for (int i = 1; i <= n; i++) { 38 | result *= i; 39 | } 40 | 41 | return result; 42 | } 43 | 44 | // Returns true if and only if n is a prime number. 45 | bool IsPrime(int n) { 46 | // Trivial case 1: small numbers 47 | if (n <= 1) return false; 48 | 49 | // Trivial case 2: even numbers 50 | if (n % 2 == 0) return n == 2; 51 | 52 | // Now, we have that n is odd and n >= 3. 53 | 54 | // Try to divide n by every odd number i, starting from 3 55 | for (int i = 3; ; i += 2) { 56 | // We only have to try i up to the square root of n 57 | if (i > n/i) break; 58 | 59 | // Now, we have i <= n/i < n. 60 | // If n is divisible by i, n is not prime. 61 | if (n % i == 0) return false; 62 | } 63 | 64 | // n has no integer factor in the range (1, n), and thus is prime. 65 | return true; 66 | } 67 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/samples/sample1.h: -------------------------------------------------------------------------------- 1 | // Copyright 2005, Google Inc. 2 | // All rights reserved. 
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // A sample program demonstrating using Google C++ testing framework. 31 | 32 | #ifndef GOOGLETEST_SAMPLES_SAMPLE1_H_ 33 | #define GOOGLETEST_SAMPLES_SAMPLE1_H_ 34 | 35 | // Returns n! (the factorial of n). For negative n, n! is defined to be 1. 36 | int Factorial(int n); 37 | 38 | // Returns true if and only if n is a prime number. 
39 | bool IsPrime(int n); 40 | 41 | #endif // GOOGLETEST_SAMPLES_SAMPLE1_H_ 42 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/samples/sample2.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2005, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | // A sample program demonstrating using Google C++ testing framework. 31 | 32 | #include "sample2.h" 33 | 34 | #include 35 | 36 | // Clones a 0-terminated C string, allocating memory using new. 37 | const char* MyString::CloneCString(const char* a_c_string) { 38 | if (a_c_string == nullptr) return nullptr; 39 | 40 | const size_t len = strlen(a_c_string); 41 | char* const clone = new char[ len + 1 ]; 42 | memcpy(clone, a_c_string, len + 1); 43 | 44 | return clone; 45 | } 46 | 47 | // Sets the 0-terminated C string this MyString object 48 | // represents. 49 | void MyString::Set(const char* a_c_string) { 50 | // Makes sure this works when c_string == c_string_ 51 | const char* const temp = MyString::CloneCString(a_c_string); 52 | delete[] c_string_; 53 | c_string_ = temp; 54 | } 55 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/samples/sample2.h: -------------------------------------------------------------------------------- 1 | // Copyright 2005, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 
17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // A sample program demonstrating using Google C++ testing framework. 31 | 32 | #ifndef GOOGLETEST_SAMPLES_SAMPLE2_H_ 33 | #define GOOGLETEST_SAMPLES_SAMPLE2_H_ 34 | 35 | #include 36 | 37 | 38 | // A simple string class. 39 | class MyString { 40 | private: 41 | const char* c_string_; 42 | const MyString& operator=(const MyString& rhs); 43 | 44 | public: 45 | // Clones a 0-terminated C string, allocating memory using new. 46 | static const char* CloneCString(const char* a_c_string); 47 | 48 | //////////////////////////////////////////////////////////// 49 | // 50 | // C'tors 51 | 52 | // The default c'tor constructs a NULL string. 53 | MyString() : c_string_(nullptr) {} 54 | 55 | // Constructs a MyString by cloning a 0-terminated C string. 56 | explicit MyString(const char* a_c_string) : c_string_(nullptr) { 57 | Set(a_c_string); 58 | } 59 | 60 | // Copy c'tor 61 | MyString(const MyString& string) : c_string_(nullptr) { 62 | Set(string.c_string_); 63 | } 64 | 65 | //////////////////////////////////////////////////////////// 66 | // 67 | // D'tor. 
MyString is intended to be a final class, so the d'tor 68 | // doesn't need to be virtual. 69 | ~MyString() { delete[] c_string_; } 70 | 71 | // Gets the 0-terminated C string this MyString object represents. 72 | const char* c_string() const { return c_string_; } 73 | 74 | size_t Length() const { return c_string_ == nullptr ? 0 : strlen(c_string_); } 75 | 76 | // Sets the 0-terminated C string this MyString object represents. 77 | void Set(const char* c_string); 78 | }; 79 | 80 | #endif // GOOGLETEST_SAMPLES_SAMPLE2_H_ 81 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/samples/sample4.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2005, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // A sample program demonstrating using Google C++ testing framework. 31 | 32 | #include <stdio.h>  // printf(), used by Print() 33 | 34 | #include "sample4.h" 35 | 36 | // Returns the current counter value, and increments it. 37 | int Counter::Increment() { 38 | return counter_++; 39 | } 40 | 41 | // Returns the current counter value, and decrements it. 42 | // The counter never goes below 0: if it is already 0, it is left unchanged and 0 is returned. 43 | int Counter::Decrement() { 44 | if (counter_ == 0) { 45 | return counter_; 46 | } else { 47 | return counter_--; 48 | } 49 | } 50 | 51 | // Prints the current counter value to STDOUT. 52 | void Counter::Print() const { 53 | printf("%d", counter_); 54 | } 55 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/samples/sample4.h: -------------------------------------------------------------------------------- 1 | // Copyright 2005, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer.
10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // A sample program demonstrating using Google C++ testing framework. 31 | #ifndef GOOGLETEST_SAMPLES_SAMPLE4_H_ 32 | #define GOOGLETEST_SAMPLES_SAMPLE4_H_ 33 | 34 | // A simple counter that can be incremented and decremented but never goes below 0. 35 | class Counter { 36 | private: 37 | int counter_; 38 | 39 | public: 40 | // Creates a counter that starts at 0. 41 | Counter() : counter_(0) {} 42 | 43 | // Returns the current counter value, and increments it. 44 | int Increment(); 45 | 46 | // Returns the current counter value, and decrements it (a counter already at 0 stays at 0). 47 | int Decrement(); 48 | 49 | // Prints the current counter value to STDOUT.
50 | void Print() const; 51 | }; 52 | 53 | #endif // GOOGLETEST_SAMPLES_SAMPLE4_H_ 54 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/samples/sample4_unittest.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2005, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | 30 | 31 | #include "sample4.h" 32 | #include "gtest/gtest.h" 33 | 34 | namespace { 35 | // Tests the Increment() and Decrement() methods. 36 | 37 | TEST(Counter, Increment) { 38 | Counter c; 39 | 40 | // Decrementing a counter that is at 0 returns 0 and leaves it at 0. 41 | EXPECT_EQ(0, c.Decrement()); 42 | 43 | // EXPECT_EQ() evaluates its arguments exactly once, so they 44 | // can have side effects. 45 | 46 | EXPECT_EQ(0, c.Increment()); 47 | EXPECT_EQ(1, c.Increment()); 48 | EXPECT_EQ(2, c.Increment()); 49 | 50 | EXPECT_EQ(3, c.Decrement());  // Decrement() returns the value before decrementing. 51 | } 52 | 53 | } // namespace 54 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/src/gtest-all.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // 31 | // Google C++ Testing and Mocking Framework (Google Test) 32 | // 33 | // Sometimes it's desirable to build Google Test by compiling a single file. 34 | // This file serves this purpose. 35 | 36 | // This line ensures that gtest.h can be compiled on its own, even 37 | // when it's fused. 38 | #include "gtest/gtest.h" 39 | 40 | // The following lines pull in the real gtest *.cc files. 41 | #include "src/gtest.cc" 42 | #include "src/gtest-death-test.cc" 43 | #include "src/gtest-filepath.cc" 44 | #include "src/gtest-matchers.cc" 45 | #include "src/gtest-port.cc" 46 | #include "src/gtest-printers.cc" 47 | #include "src/gtest-test-part.cc" 48 | #include "src/gtest-typed-test.cc" 49 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/src/gtest_main.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2006, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 
10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | #include 31 | #include "gtest/gtest.h" 32 | 33 | #if GTEST_OS_ESP8266 || GTEST_OS_ESP32 34 | #if GTEST_OS_ESP8266 35 | extern "C" { 36 | #endif 37 | void setup() { 38 | testing::InitGoogleTest(); 39 | } 40 | 41 | void loop() { RUN_ALL_TESTS(); } 42 | 43 | #if GTEST_OS_ESP8266 44 | } 45 | #endif 46 | 47 | #else 48 | 49 | GTEST_API_ int main(int argc, char **argv) { 50 | printf("Running main() from %s\n", __FILE__); 51 | testing::InitGoogleTest(&argc, argv); 52 | return RUN_ALL_TESTS(); 53 | } 54 | #endif 55 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-break-on-failure-unittest_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2006, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Unit test for Google Test's break-on-failure mode. 32 | // 33 | // A user can ask Google Test to seg-fault when an assertion fails, using 34 | // either the GTEST_BREAK_ON_FAILURE environment variable or the 35 | // --gtest_break_on_failure flag. This file is used for testing such 36 | // functionality. 37 | // 38 | // This program will be invoked from a Python unit test. It is 39 | // expected to fail. Don't run it directly. 40 | 41 | #include "gtest/gtest.h" 42 | 43 | #if GTEST_OS_WINDOWS 44 | # include 45 | # include 46 | #endif 47 | 48 | namespace { 49 | 50 | // A test that's expected to fail. 51 | TEST(Foo, Bar) { 52 | EXPECT_EQ(2, 3); 53 | } 54 | 55 | #if GTEST_HAS_SEH && !GTEST_OS_WINDOWS_MOBILE 56 | // On Windows Mobile global exception handlers are not supported. 57 | LONG WINAPI ExitWithExceptionCode( 58 | struct _EXCEPTION_POINTERS* exception_pointers) { 59 | exit(exception_pointers->ExceptionRecord->ExceptionCode); 60 | } 61 | #endif 62 | 63 | } // namespace 64 | 65 | int main(int argc, char **argv) { 66 | #if GTEST_OS_WINDOWS 67 | // Suppresses display of the Windows error dialog upon encountering 68 | // a general protection fault (segment violation). 
69 | SetErrorMode(SEM_NOGPFAULTERRORBOX | SEM_FAILCRITICALERRORS); 70 | 71 | # if GTEST_HAS_SEH && !GTEST_OS_WINDOWS_MOBILE 72 | 73 | // The default unhandled exception filter does not always exit 74 | // with the exception code as exit code - for example it exits with 75 | // 0 for EXCEPTION_ACCESS_VIOLATION and 1 for EXCEPTION_BREAKPOINT 76 | // if the application is compiled in debug mode. Thus we use our own 77 | // filter which always exits with the exception code for unhandled 78 | // exceptions. 79 | SetUnhandledExceptionFilter(ExitWithExceptionCode); 80 | 81 | # endif 82 | #endif // GTEST_OS_WINDOWS 83 | testing::InitGoogleTest(&argc, argv); 84 | 85 | return RUN_ALL_TESTS(); 86 | } 87 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-color-test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 
17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // A helper program for testing how Google Test determines whether to use 32 | // colors in the output. It prints "YES" and returns 1 if Google Test 33 | // decides to use colors, and prints "NO" and returns 0 otherwise. 34 | 35 | #include <stdio.h>  // printf() 36 | 37 | #include "gtest/gtest.h" 38 | #include "src/gtest-internal-inl.h" 39 | 40 | using testing::internal::ShouldUseColor; 41 | 42 | // The purpose of this is to ensure that the UnitTest singleton is 43 | // created before main() is entered, and thus that ShouldUseColor() 44 | // works the same way as in a real Google-Test-based test. We don't actually 45 | // run the TEST itself. 46 | TEST(GTestColorTest, Dummy) { 47 | } 48 | 49 | int main(int argc, char** argv) { 50 | testing::InitGoogleTest(&argc, argv); 51 | 52 | if (ShouldUseColor(true)) { 53 | // Google Test decides to use colors in the output (assuming it 54 | // goes to a TTY). 55 | printf("YES\n"); 56 | return 1; 57 | } else { 58 | // Google Test decides not to use colors in the output.
59 | printf("NO\n"); 60 | return 0; 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-global-environment-unittest_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2005, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 | 30 | // Unit test for Google Test global test environments. 31 | // 32 | // The program will be invoked from a Python unit test. Don't run it 33 | // directly. 34 | 35 | #include "gtest/gtest.h" 36 | 37 | namespace { 38 | 39 | // An environment that always fails in its SetUp method. 40 | class FailingEnvironment final : public ::testing::Environment { 41 | public: 42 | void SetUp() override { FAIL() << "Canned environment setup error"; } 43 | }; 44 | 45 | // Register the environment from a namespace-scope initializer, so it is added before main() runs; the variable is not otherwise used. 46 | auto* const g_environment_ = 47 | ::testing::AddGlobalTestEnvironment(new FailingEnvironment); 48 | 49 | // A test that is not expected to run; if its body is ever reached it fails. 50 | TEST(SomeTest, DoesFoo) { FAIL() << "Unexpected call"; } 51 | 52 | } // namespace 53 | 54 | int main(int argc, char** argv) { 55 | ::testing::InitGoogleTest(&argc, argv); 56 | 57 | return RUN_ALL_TESTS(); 58 | } 59 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-param-test-invalid-name1-test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2015 Google Inc. All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following disclaimer 13 | # in the documentation and/or other materials provided with the 14 | # distribution. 15 | # * Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from 17 | # this software without specific prior written permission.
18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | 31 | """Verifies that Google Test rejects a parameterized test name containing quotes.""" 32 | 33 | import gtest_test_utils 34 | 35 | binary_name = 'googletest-param-test-invalid-name1-test_' 36 | COMMAND = gtest_test_utils.GetTestExecutablePath(binary_name) 37 | 38 | 39 | def Assert(condition):  # Raises a bare AssertionError when condition is falsy. 40 | if not condition: 41 | raise AssertionError 42 | 43 | 44 | def TestExitCodeAndOutput(command): 45 | """Runs the given command and verifies it is killed by a signal and reports the invalid test name.""" 46 | 47 | err = ('Parameterized test name \'"InvalidWithQuotes"\' is invalid') 48 | 49 | p = gtest_test_utils.Subprocess(command) 50 | Assert(p.terminated_by_signal) 51 | 52 | # Verify the output reports the rejected (quoted) test name. 53 | Assert(err in p.output) 54 | 55 | 56 | class GTestParamTestInvalidName1Test(gtest_test_utils.TestCase): 57 | 58 | def testExitCodeAndOutput(self): 59 | TestExitCodeAndOutput(COMMAND) 60 | 61 | 62 | if __name__ == '__main__': 63 | gtest_test_utils.Main() 64 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-param-test-invalid-name1-test_.cc:
-------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | 31 | #include "gtest/gtest.h" 32 | 33 | namespace { 34 | class DummyTest : public ::testing::TestWithParam<const char *> {}; 35 | 36 | TEST_P(DummyTest, Dummy) { 37 | } 38 | 39 | INSTANTIATE_TEST_SUITE_P(InvalidTestName, 40 | DummyTest, 41 | ::testing::Values("InvalidWithQuotes"),  // The printed name keeps the quotes, which makes it an invalid test name. 42 | ::testing::PrintToStringParamName()); 43 | 44 | } // namespace 45 | 46 | int main(int argc, char *argv[]) { 47 | testing::InitGoogleTest(&argc, argv); 48 | return RUN_ALL_TESTS(); 49 | } 50 | 51 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-param-test-invalid-name2-test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2015 Google Inc. All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following disclaimer 13 | # in the documentation and/or other materials provided with the 14 | # distribution. 15 | # * Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from 17 | # this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | 31 | """Verifies that Google Test rejects duplicate parameterized test names.""" 32 | 33 | import gtest_test_utils 34 | 35 | binary_name = 'googletest-param-test-invalid-name2-test_' 36 | COMMAND = gtest_test_utils.GetTestExecutablePath(binary_name) 37 | 38 | 39 | def Assert(condition):  # Raises a bare AssertionError when condition is falsy. 40 | if not condition: 41 | raise AssertionError 42 | 43 | 44 | def TestExitCodeAndOutput(command): 45 | """Runs the given command and verifies it is killed by a signal and reports the duplicated test name.""" 46 | 47 | err = ('Duplicate parameterized test name \'a\'') 48 | 49 | p = gtest_test_utils.Subprocess(command) 50 | Assert(p.terminated_by_signal) 51 | 52 | # Verify the output reports the duplicated test name. 53 | Assert(err in p.output) 54 | 55 | 56 | class GTestParamTestInvalidName2Test(gtest_test_utils.TestCase): 57 | 58 | def testExitCodeAndOutput(self): 59 | TestExitCodeAndOutput(COMMAND) 60 | 61 | if __name__ == '__main__': 62 | gtest_test_utils.Main() 63 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-param-test-invalid-name2-test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2015, Google Inc. 2 | // All rights reserved.
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | 31 | #include "gtest/gtest.h" 32 | 33 | namespace { 34 | class DummyTest : public ::testing::TestWithParam<const char *> {}; 35 | 36 | std::string StringParamTestSuffix( 37 | const testing::TestParamInfo<const char*>& info) { 38 | return std::string(info.param); 39 | } 40 | 41 | TEST_P(DummyTest, Dummy) { 42 | } 43 | 44 | INSTANTIATE_TEST_SUITE_P(DuplicateTestNames, 45 | DummyTest, 46 | ::testing::Values("a", "b", "a", "c"),  // "a" appears twice, so two generated names collide. 47 | StringParamTestSuffix); 48 | } // namespace 49 | 50 | int main(int argc, char *argv[]) { 51 | testing::InitGoogleTest(&argc, argv); 52 | return RUN_ALL_TESTS(); 53 | } 54 | 55 | 56 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-param-test-test.h: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // The Google C++ Testing and Mocking Framework (Google Test) 31 | // 32 | // This header file provides classes and functions used internally 33 | // for testing Google Test itself. 34 | 35 | #ifndef GOOGLETEST_TEST_GOOGLETEST_PARAM_TEST_TEST_H_ 36 | #define GOOGLETEST_TEST_GOOGLETEST_PARAM_TEST_TEST_H_ 37 | 38 | #include "gtest/gtest.h" 39 | 40 | // Test fixture for testing definition and instantiation of a test 41 | // in separate translation units. 42 | class ExternalInstantiationTest : public ::testing::TestWithParam { 43 | }; 44 | 45 | // Test fixture for testing instantiation of a test in multiple 46 | // translation units. 47 | class InstantiationInMultipleTranslationUnitsTest 48 | : public ::testing::TestWithParam { 49 | }; 50 | 51 | #endif // GOOGLETEST_TEST_GOOGLETEST_PARAM_TEST_TEST_H_ 52 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-param-test2-test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // 31 | // Tests for Google Test itself. This verifies that the basic constructs of 32 | // Google Test work. 33 | 34 | #include "gtest/gtest.h" 35 | #include "test/googletest-param-test-test.h" 36 | 37 | using ::testing::Values; 38 | using ::testing::internal::ParamGenerator; 39 | 40 | // Tests that generators defined in a different translation unit 41 | // are functional. 
The test using extern_gen is defined 42 | // in googletest-param-test-test.cc. 43 | ParamGenerator extern_gen = Values(33); 44 | 45 | // Tests that a parameterized test case can be defined in one translation unit 46 | // and instantiated in another. The test is defined in 47 | // googletest-param-test-test.cc and ExternalInstantiationTest fixture class is 48 | // defined in gtest-param-test_test.h. 49 | INSTANTIATE_TEST_SUITE_P(MultiplesOf33, 50 | ExternalInstantiationTest, 51 | Values(33, 66)); 52 | 53 | // Tests that a parameterized test case can be instantiated 54 | // in multiple translation units. Another instantiation is defined 55 | // in googletest-param-test-test.cc and 56 | // InstantiationInMultipleTranslationUnitsTest fixture is defined in 57 | // gtest-param-test_test.h 58 | INSTANTIATE_TEST_SUITE_P(Sequence2, 59 | InstantiationInMultipleTranslationUnitsTest, 60 | Values(42*3, 42*4, 42*5)); 61 | 62 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-setuptestsuite-test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2019, Google Inc. 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are 8 | # met: 9 | # 10 | # * Redistributions of source code must retain the above copyright 11 | # notice, this list of conditions and the following disclaimer. 12 | # * Redistributions in binary form must reproduce the above 13 | # copyright notice, this list of conditions and the following disclaimer 14 | # in the documentation and/or other materials provided with the 15 | # distribution. 16 | # * Neither the name of Google Inc. 
nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | # 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | """Verifies that SetUpTestSuite and TearDownTestSuite errors are noticed.""" 33 | 34 | import gtest_test_utils 35 | 36 | COMMAND = gtest_test_utils.GetTestExecutablePath( 37 | 'googletest-setuptestsuite-test_') 38 | 39 | 40 | class GTestSetUpTestSuiteTest(gtest_test_utils.TestCase): 41 | 42 | def testSetupErrorAndTearDownError(self): 43 | p = gtest_test_utils.Subprocess(COMMAND) 44 | self.assertNotEqual(p.exit_code, 0, msg=p.output) 45 | 46 | self.assertIn( 47 | '[ FAILED ] SetupFailTest: SetUpTestSuite or TearDownTestSuite\n' 48 | '[ FAILED ] TearDownFailTest: SetUpTestSuite or TearDownTestSuite\n' 49 | '\n' 50 | ' 2 FAILED TEST SUITES\n', 51 | p.output) 52 | 53 | if __name__ == '__main__': 54 | gtest_test_utils.Main() 55 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-setuptestsuite-test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 
2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | 31 | #include "gtest/gtest.h" 32 | 33 | class SetupFailTest : public ::testing::Test { 34 | protected: 35 | static void SetUpTestSuite() { 36 | ASSERT_EQ("", "SET_UP_FAIL"); 37 | } 38 | }; 39 | 40 | TEST_F(SetupFailTest, NoopPassingTest) {} 41 | 42 | class TearDownFailTest : public ::testing::Test { 43 | protected: 44 | static void TearDownTestSuite() { 45 | ASSERT_EQ("", "TEAR_DOWN_FAIL"); 46 | } 47 | }; 48 | 49 | TEST_F(TearDownFailTest, NoopPassingTest) {} 50 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-shuffle-test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2009, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Verifies that test shuffling works. 32 | 33 | #include "gtest/gtest.h" 34 | 35 | namespace { 36 | 37 | using ::testing::EmptyTestEventListener; 38 | using ::testing::InitGoogleTest; 39 | using ::testing::Message; 40 | using ::testing::Test; 41 | using ::testing::TestEventListeners; 42 | using ::testing::TestInfo; 43 | using ::testing::UnitTest; 44 | 45 | // The test methods are empty, as the sole purpose of this program is 46 | // to print the test names before/after shuffling. 47 | 48 | class A : public Test {}; 49 | TEST_F(A, A) {} 50 | TEST_F(A, B) {} 51 | 52 | TEST(ADeathTest, A) {} 53 | TEST(ADeathTest, B) {} 54 | TEST(ADeathTest, C) {} 55 | 56 | TEST(B, A) {} 57 | TEST(B, B) {} 58 | TEST(B, C) {} 59 | TEST(B, DISABLED_D) {} 60 | TEST(B, DISABLED_E) {} 61 | 62 | TEST(BDeathTest, A) {} 63 | TEST(BDeathTest, B) {} 64 | 65 | TEST(C, A) {} 66 | TEST(C, B) {} 67 | TEST(C, C) {} 68 | TEST(C, DISABLED_D) {} 69 | 70 | TEST(CDeathTest, A) {} 71 | 72 | TEST(DISABLED_D, A) {} 73 | TEST(DISABLED_D, DISABLED_B) {} 74 | 75 | // This printer prints the full test names only, starting each test 76 | // iteration with a "----" marker. 
77 | class TestNamePrinter : public EmptyTestEventListener { 78 | public: 79 | void OnTestIterationStart(const UnitTest& /* unit_test */, 80 | int /* iteration */) override { 81 | printf("----\n"); 82 | } 83 | 84 | void OnTestStart(const TestInfo& test_info) override { 85 | printf("%s.%s\n", test_info.test_suite_name(), test_info.name()); 86 | } 87 | }; 88 | 89 | } // namespace 90 | 91 | int main(int argc, char **argv) { 92 | InitGoogleTest(&argc, argv); 93 | 94 | // Replaces the default printer with TestNamePrinter, which prints 95 | // the test name only. 96 | TestEventListeners& listeners = UnitTest::GetInstance()->listeners(); 97 | delete listeners.Release(listeners.default_result_printer()); 98 | listeners.Append(new TestNamePrinter); 99 | 100 | return RUN_ALL_TESTS(); 101 | } 102 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-throw-on-failure-test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2009, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 
17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Tests Google Test's throw-on-failure mode with exceptions disabled. 32 | // 33 | // This program must be compiled with exceptions disabled. It will be 34 | // invoked by googletest-throw-on-failure-test.py, and is expected to exit 35 | // with non-zero in the throw-on-failure mode or 0 otherwise. 36 | 37 | #include "gtest/gtest.h" 38 | 39 | #include // for fflush, fprintf, NULL, etc. 40 | #include // for exit 41 | #include // for set_terminate 42 | 43 | // This terminate handler aborts the program using exit() rather than abort(). 44 | // This avoids showing pop-ups on Windows systems and core dumps on Unix-like 45 | // ones. 
46 | void TerminateHandler() { 47 | fprintf(stderr, "%s\n", "Unhandled C++ exception terminating the program."); 48 | fflush(nullptr); 49 | exit(1); 50 | } 51 | 52 | int main(int argc, char** argv) { 53 | #if GTEST_HAS_EXCEPTIONS 54 | std::set_terminate(&TerminateHandler); 55 | #endif 56 | testing::InitGoogleTest(&argc, argv); 57 | 58 | // We want to ensure that people can use Google Test assertions in 59 | // other testing frameworks, as long as they initialize Google Test 60 | // properly and set the throw-on-failure mode. Therefore, we don't 61 | // use Google Test's constructs for defining and running tests 62 | // (e.g. TEST and RUN_ALL_TESTS) here. 63 | 64 | // In the throw-on-failure mode with exceptions disabled, this 65 | // assertion will cause the program to exit with a non-zero code. 66 | EXPECT_EQ(2, 3); 67 | 68 | // When not in the throw-on-failure mode, the control will reach 69 | // here. 70 | return 0; 71 | } 72 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-uninitialized-test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2008, Google Inc. 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are 8 | # met: 9 | # 10 | # * Redistributions of source code must retain the above copyright 11 | # notice, this list of conditions and the following disclaimer. 12 | # * Redistributions in binary form must reproduce the above 13 | # copyright notice, this list of conditions and the following disclaimer 14 | # in the documentation and/or other materials provided with the 15 | # distribution. 16 | # * Neither the name of Google Inc. 
nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | # 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | """Verifies that Google Test warns the user when not initialized properly.""" 33 | 34 | import gtest_test_utils 35 | 36 | COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-uninitialized-test_') 37 | 38 | 39 | def Assert(condition): 40 | if not condition: 41 | raise AssertionError 42 | 43 | 44 | def AssertEq(expected, actual): 45 | if expected != actual: 46 | print('Expected: %s' % (expected,)) 47 | print(' Actual: %s' % (actual,)) 48 | raise AssertionError 49 | 50 | 51 | def TestExitCodeAndOutput(command): 52 | """Runs the given command and verifies its exit code and output.""" 53 | 54 | # Verifies that 'command' exits with code 1. 
55 | p = gtest_test_utils.Subprocess(command) 56 | if p.exited and p.exit_code == 0: 57 | Assert('IMPORTANT NOTICE' in p.output); 58 | Assert('InitGoogleTest' in p.output) 59 | 60 | 61 | class GTestUninitializedTest(gtest_test_utils.TestCase): 62 | def testExitCodeAndOutput(self): 63 | TestExitCodeAndOutput(COMMAND) 64 | 65 | 66 | if __name__ == '__main__': 67 | gtest_test_utils.Main() 68 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/googletest-uninitialized-test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | #include "gtest/gtest.h" 32 | 33 | TEST(DummyTest, Dummy) { 34 | // This test doesn't verify anything. We just need it to create a 35 | // realistic stage for testing the behavior of Google Test when 36 | // RUN_ALL_TESTS() is called without 37 | // testing::InitGoogleTest() being called first. 38 | } 39 | 40 | int main() { 41 | return RUN_ALL_TESTS(); 42 | } 43 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest-typed-test2_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008 Google Inc. 2 | // All Rights Reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. 
nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | #include 32 | 33 | #include "test/gtest-typed-test_test.h" 34 | #include "gtest/gtest.h" 35 | 36 | // Tests that the same type-parameterized test case can be 37 | // instantiated in different translation units linked together. 38 | // (ContainerTest is also instantiated in gtest-typed-test_test.cc.) 39 | INSTANTIATE_TYPED_TEST_SUITE_P(Vector, ContainerTest, 40 | testing::Types >); 41 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest-typed-test_test.h: -------------------------------------------------------------------------------- 1 | // Copyright 2008 Google Inc. 2 | // All Rights Reserved. 
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | #ifndef GOOGLETEST_TEST_GTEST_TYPED_TEST_TEST_H_ 31 | #define GOOGLETEST_TEST_GTEST_TYPED_TEST_TEST_H_ 32 | 33 | #include "gtest/gtest.h" 34 | 35 | using testing::Test; 36 | 37 | // For testing that the same type-parameterized test case can be 38 | // instantiated in different translation units linked together. 
39 | // ContainerTest will be instantiated in both gtest-typed-test_test.cc 40 | // and gtest-typed-test2_test.cc. 41 | 42 | template 43 | class ContainerTest : public Test { 44 | }; 45 | 46 | TYPED_TEST_SUITE_P(ContainerTest); 47 | 48 | TYPED_TEST_P(ContainerTest, CanBeDefaultConstructed) { 49 | TypeParam container; 50 | } 51 | 52 | TYPED_TEST_P(ContainerTest, InitialSizeIsZero) { 53 | TypeParam container; 54 | EXPECT_EQ(0U, container.size()); 55 | } 56 | 57 | REGISTER_TYPED_TEST_SUITE_P(ContainerTest, 58 | CanBeDefaultConstructed, InitialSizeIsZero); 59 | 60 | #endif // GOOGLETEST_TEST_GTEST_TYPED_TEST_TEST_H_ 61 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_all_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2009, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // 31 | // Tests for Google C++ Testing and Mocking Framework (Google Test) 32 | // 33 | // Sometimes it's desirable to build most of Google Test's own tests 34 | // by compiling a single file. This file serves this purpose. 35 | #include "test/googletest-filepath-test.cc" 36 | #include "test/googletest-message-test.cc" 37 | #include "test/googletest-options-test.cc" 38 | #include "test/googletest-port-test.cc" 39 | #include "test/googletest-test-part-test.cc" 40 | #include "test/gtest-typed-test2_test.cc" 41 | #include "test/gtest-typed-test_test.cc" 42 | #include "test/gtest_pred_impl_unittest.cc" 43 | #include "test/gtest_prod_test.cc" 44 | #include "test/gtest_skip_test.cc" 45 | #include "test/gtest_unittest.cc" 46 | #include "test/production.cc" 47 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_help_test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2009, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 
10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // This program is meant to be run by gtest_help_test.py. Do not run 32 | // it directly. 33 | 34 | #include "gtest/gtest.h" 35 | 36 | // When a help flag is specified, this program should skip the tests 37 | // and exit with 0; otherwise the following test will be executed, 38 | // causing this program to exit with a non-zero code. 
39 | TEST(HelpFlagTest, ShouldNotBeRun) { 40 | ASSERT_TRUE(false) << "Tests shouldn't be run when --help is specified."; 41 | } 42 | // Compiled only when death tests are available; as its name says, the driving Python script uses this test to detect death-test support in this binary. 43 | #if GTEST_HAS_DEATH_TEST 44 | TEST(DeathTest, UsedByPythonScriptToDetectSupportForDeathTestsInThisBinary) {} 45 | #endif 46 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_json_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018, Google Inc. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are 6 | # met: 7 | # 8 | # * Redistributions of source code must retain the above copyright 9 | # notice, this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above 11 | # copyright notice, this list of conditions and the following disclaimer 12 | # in the documentation and/or other materials provided with the 13 | # distribution. 14 | # * Neither the name of Google Inc. nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | # A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 22 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | """Unit test utilities for gtest_json_output.""" 31 | 32 | import re 33 | 34 | 35 | def normalize(obj): 36 | """Normalize output object. 37 | 38 | Args: 39 | obj: Google Test's JSON output object to normalize. 40 | 41 | Returns: 42 | Normalized output without any references to transient information that may 43 | change from run to run. 44 | """ 45 | def _normalize(key, value): 46 | if key == 'time': 47 | return re.sub(r'^\d+(\.\d+)?s$', '*', value) 48 | elif key == 'timestamp': 49 | return re.sub(r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ$', '*', value) 50 | elif key == 'failure': 51 | value = re.sub(r'^.*[/\\](.*:)\d+\n', '\\1*\n', value) 52 | return re.sub(r'Stack trace:\n(.|\n)*', 'Stack trace:\n*', value) 53 | else: 54 | return normalize(value) 55 | if isinstance(obj, dict): 56 | return {k: _normalize(k, v) for k, v in obj.items()} 57 | if isinstance(obj, list): 58 | return [normalize(x) for x in obj] 59 | else: 60 | return obj 61 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_list_output_unittest_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018, Google Inc. 2 | // All rights reserved. 
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // Author: david.schuldenfrei@gmail.com (David Schuldenfrei) 31 | 32 | // Unit test for Google Test's --gtest_list_tests and --gtest_output flag. 33 | // 34 | // A user can ask Google Test to list all tests that will run, 35 | // and have the output saved in a Json/Xml file. 36 | // The tests will not be run after listing. 
37 | // 38 | // This program will be invoked from a Python unit test. 39 | // Don't run it directly. 40 | 41 | #include "gtest/gtest.h" 42 | // NOTE(review): the expected listing output presumably records source line numbers; keep tests on their current lines. 43 | TEST(FooTest, Test1) {} 44 | 45 | TEST(FooTest, Test2) {} 46 | 47 | class FooTestFixture : public ::testing::Test {}; 48 | TEST_F(FooTestFixture, Test3) {} 49 | TEST_F(FooTestFixture, Test4) {} 50 | // Value-parameterized suite instantiated with the int values 33 and 42. 51 | class ValueParamTest : public ::testing::TestWithParam<int> {}; 52 | TEST_P(ValueParamTest, Test5) {} 53 | TEST_P(ValueParamTest, Test6) {} 54 | INSTANTIATE_TEST_SUITE_P(ValueParam, ValueParamTest, ::testing::Values(33, 42)); 55 | // Typed suite, run once for each type in TypedTestTypes. 56 | template <typename T> 57 | class TypedTest : public ::testing::Test {}; 58 | typedef testing::Types<int, bool> TypedTestTypes; 59 | TYPED_TEST_SUITE(TypedTest, TypedTestTypes); 60 | TYPED_TEST(TypedTest, Test7) {} 61 | TYPED_TEST(TypedTest, Test8) {} 62 | // Type-parameterized suite, registered and then instantiated under the prefix "Single". 63 | template <typename T> 64 | class TypeParameterizedTestSuite : public ::testing::Test {}; 65 | TYPED_TEST_SUITE_P(TypeParameterizedTestSuite); 66 | TYPED_TEST_P(TypeParameterizedTestSuite, Test9) {} 67 | TYPED_TEST_P(TypeParameterizedTestSuite, Test10) {} 68 | REGISTER_TYPED_TEST_SUITE_P(TypeParameterizedTestSuite, Test9, Test10); 69 | typedef testing::Types<int, bool> TypeParameterizedTestSuiteTypes; // NOLINT 70 | INSTANTIATE_TYPED_TEST_SUITE_P(Single, TypeParameterizedTestSuite, 71 | TypeParameterizedTestSuiteTypes); 72 | 73 | int main(int argc, char **argv) { 74 | ::testing::InitGoogleTest(&argc, argv); 75 | 76 | return RUN_ALL_TESTS(); 77 | } 78 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_main_unittest.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2006, Google Inc. 2 | // All rights reserved.
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | #include "gtest/gtest.h" 32 | 33 | // Tests that we don't have to define main() when we link to 34 | // gtest_main instead of gtest. 35 | 36 | namespace { 37 | // Intentionally empty: it only has to be found and run by the main() that gtest_main provides. 38 | TEST(GTestMainTest, ShouldSucceed) { 39 | } 40 | 41 | } // namespace 42 | 43 | // We are using the main() function defined in gtest_main.cc, so we 44 | // don't define it here.
45 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_no_test_unittest.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2006, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | // Tests that a Google Test program that has no test defined can run 31 | // successfully. 32 | 33 | #include "gtest/gtest.h" 34 | 35 | int main(int argc, char **argv) { 36 | testing::InitGoogleTest(&argc, argv); 37 | 38 | // An ad-hoc assertion outside of all tests. 39 | // 40 | // This serves three purposes: 41 | // 42 | // 1. It verifies that an ad-hoc assertion can be executed even if 43 | // no test is defined. 44 | // 2. It verifies that a failed ad-hoc assertion causes the test 45 | // program to fail. 46 | // 3. We had a bug where the XML output won't be generated if an 47 | // assertion is executed before RUN_ALL_TESTS() is called, even 48 | // though --gtest_output=xml is specified. This makes sure the 49 | // bug is fixed and doesn't regress. 50 | EXPECT_EQ(1, 2); 51 | 52 | // The above EXPECT_EQ() should cause RUN_ALL_TESTS() to return non-zero; the result is deliberately inverted so that this expected failure exits with 0 and an unexpected success exits with 1. 53 | return RUN_ALL_TESTS() ? 0 : 1; 54 | } 55 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_prod_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2006, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission.
17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // 31 | // Unit test for gtest_prod.h. 32 | 33 | #include "production.h" 34 | #include "gtest/gtest.h" 35 | 36 | // Tests that private members can be accessed from a TEST declared as 37 | // a friend of the class. 38 | TEST(PrivateCodeTest, CanAccessPrivateMembers) { 39 | PrivateCode a; 40 | EXPECT_EQ(0, a.x_); // Direct read of the private x_; this compiles only because the TEST is a friend of PrivateCode. 41 | 42 | a.set_x(1); 43 | EXPECT_EQ(1, a.x_); 44 | } 45 | 46 | typedef testing::Test PrivateCodeFixtureTest; // Fixture name used by the TEST_F variant below. 47 | 48 | // Tests that private members can be accessed from a TEST_F declared 49 | // as a friend of the class. 50 | TEST_F(PrivateCodeFixtureTest, CanAccessPrivateMembers) { 51 | PrivateCode a; 52 | EXPECT_EQ(0, a.x_); 53 | 54 | a.set_x(2); 55 | EXPECT_EQ(2, a.x_); 56 | } 57 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_skip_check_output_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2019 Google LLC. All Rights Reserved.
4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following disclaimer 13 | # in the documentation and/or other materials provided with the 14 | # distribution. 15 | # * Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from 17 | # this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | """Tests the output of Google Test's GTEST_SKIP() in tests and fixtures. 31 | 32 | This script invokes gtest_skip_test and verifies its 33 | output.
34 | """ 35 | 36 | import re 37 | 38 | import gtest_test_utils 39 | 40 | # Path to the gtest_skip_test binary 41 | EXE_PATH = gtest_test_utils.GetTestExecutablePath('gtest_skip_test') 42 | 43 | OUTPUT = gtest_test_utils.Subprocess([EXE_PATH]).output 44 | 45 | 46 | # Checks the skip messages of the single skipped test and of the fixture whose SetUp() skips. 47 | class SkipEntireEnvironmentTest(gtest_test_utils.TestCase): 48 | # NOTE(review): class/method names look copied from the environment-skip check; kept because they are test IDs. 49 | def testSkipEntireEnvironmentTest(self): 50 | self.assertIn('Skipped\nskipping single test\n', OUTPUT) 51 | skip_fixture = 'Skipped\nskipping all tests for this fixture\n' 52 | self.assertIsNotNone( # The fixture message must appear at least twice (the binary has two Fixture tests). 53 | re.search(skip_fixture + '.*' + skip_fixture, OUTPUT, flags=re.DOTALL), 54 | repr(OUTPUT)) 55 | self.assertNotIn('FAILED', OUTPUT) 56 | 57 | 58 | if __name__ == '__main__': 59 | gtest_test_utils.Main() 60 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_skip_environment_check_output_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2019 Google LLC. All Rights Reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following disclaimer 13 | # in the documentation and/or other materials provided with the 14 | # distribution. 15 | # * Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from 17 | # this software without specific prior written permission.
18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | """Tests Google Test's gtest skip in environment setup behavior. 31 | 32 | This script invokes gtest_skip_in_environment_setup_test and verifies its 33 | output. 34 | """ 35 | 36 | import gtest_test_utils 37 | 38 | # Path to the gtest_skip_in_environment_setup_test binary 39 | EXE_PATH = gtest_test_utils.GetTestExecutablePath( 40 | 'gtest_skip_in_environment_setup_test') 41 | 42 | OUTPUT = gtest_test_utils.Subprocess([EXE_PATH]).output 43 | 44 | 45 | # The binary's only test always fails, so output without FAILED shows that the skip in the global environment's SetUp() took effect. 46 | class SkipEntireEnvironmentTest(gtest_test_utils.TestCase): 47 | 48 | def testSkipEntireEnvironmentTest(self): 49 | self.assertIn('Skipping the entire environment', OUTPUT) 50 | self.assertNotIn('FAILED', OUTPUT) 51 | 52 | 53 | if __name__ == '__main__': 54 | gtest_test_utils.Main() 55 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_skip_in_environment_setup_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2019, Google LLC. 2 | // All rights reserved.
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google LLC. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // This test verifies that skipping in the environment results in the 31 | // testcases being skipped. 
32 | 33 | #include <iostream> 34 | #include "gtest/gtest.h" 35 | // Global environment whose SetUp() skips, which is expected to skip every test. 36 | class SetupEnvironment : public testing::Environment { 37 | public: 38 | void SetUp() override { GTEST_SKIP() << "Skipping the entire environment"; } 39 | }; 40 | // Fails if it ever runs; the Python checker requires that no FAILED appears. 41 | TEST(Test, AlwaysFails) { EXPECT_EQ(true, false); } 42 | 43 | int main(int argc, char **argv) { 44 | testing::InitGoogleTest(&argc, argv); 45 | 46 | testing::AddGlobalTestEnvironment(new SetupEnvironment()); 47 | 48 | return RUN_ALL_TESTS(); 49 | } 50 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_skip_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008 Google Inc. 2 | // All Rights Reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // Author: arseny.aprelev@gmail.com (Arseny Aprelev) 31 | // 32 | 33 | #include "gtest/gtest.h" 34 | 35 | using ::testing::Test; 36 | 37 | TEST(SkipTest, DoesSkip) { 38 | GTEST_SKIP() << "skipping single test"; 39 | EXPECT_EQ(0, 1); // Would fail if reached; gtest_skip_check_output_test.py requires that no FAILED appears. 40 | } 41 | // Skipping in SetUp() is expected to skip every test of this fixture. 42 | class Fixture : public Test { 43 | protected: 44 | void SetUp() override { 45 | GTEST_SKIP() << "skipping all tests for this fixture"; 46 | } 47 | }; 48 | 49 | TEST_F(Fixture, SkipsOneTest) { 50 | EXPECT_EQ(5, 7); // Would fail if reached. 51 | } 52 | 53 | TEST_F(Fixture, SkipsAnotherTest) { 54 | EXPECT_EQ(99, 100); // Would fail if reached. 55 | } 56 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_sole_header_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution.
14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // 31 | // This test verifies that it's possible to use Google Test by including 32 | // the gtest.h header file alone. 
33 | 34 | #include "gtest/gtest.h" 35 | 36 | namespace { 37 | // Helper containing only a passing, non-fatal assertion. 38 | void Subroutine() { 39 | EXPECT_EQ(42, 42); 40 | } 41 | // Both tests pass an empty statement, a single assertion, a helper call and a braced block. 42 | TEST(NoFatalFailureTest, ExpectNoFatalFailure) { 43 | EXPECT_NO_FATAL_FAILURE(;); 44 | EXPECT_NO_FATAL_FAILURE(SUCCEED()); 45 | EXPECT_NO_FATAL_FAILURE(Subroutine()); 46 | EXPECT_NO_FATAL_FAILURE({ SUCCEED(); }); 47 | } 48 | 49 | TEST(NoFatalFailureTest, AssertNoFatalFailure) { 50 | ASSERT_NO_FATAL_FAILURE(;); 51 | ASSERT_NO_FATAL_FAILURE(SUCCEED()); 52 | ASSERT_NO_FATAL_FAILURE(Subroutine()); 53 | ASSERT_NO_FATAL_FAILURE({ SUCCEED(); }); 54 | } 55 | 56 | } // namespace 57 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_testbridge_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2018 Google LLC. All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following disclaimer 13 | # in the documentation and/or other materials provided with the 14 | # distribution. 15 | # * Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from 17 | # this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | """Verifies that Google Test uses filter provided via testbridge.""" 31 | 32 | import os 33 | 34 | import gtest_test_utils 35 | 36 | binary_name = 'gtest_testbridge_test_' 37 | COMMAND = gtest_test_utils.GetTestExecutablePath(binary_name) 38 | TESTBRIDGE_NAME = 'TESTBRIDGE_TEST_ONLY' 39 | 40 | 41 | def Assert(condition): 42 | if not condition: 43 | raise AssertionError 44 | 45 | 46 | class GTestTestFilterTest(gtest_test_utils.TestCase): 47 | 48 | def testTestExecutionIsFiltered(self): 49 | """Tests that the test filter is picked up from the testbridge env var.""" 50 | subprocess_env = os.environ.copy() 51 | 52 | subprocess_env[TESTBRIDGE_NAME] = '*.TestThatSucceeds' 53 | p = gtest_test_utils.Subprocess(COMMAND, env=subprocess_env) 54 | 55 | self.assertEquals(0, p.exit_code) 56 | 57 | Assert('filter = *.TestThatSucceeds' in p.output) 58 | Assert('[ OK ] TestFilterTest.TestThatSucceeds' in p.output) 59 | Assert('[ PASSED ] 1 test.' in p.output) 60 | 61 | 62 | if __name__ == '__main__': 63 | gtest_test_utils.Main() 64 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_testbridge_test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018, Google LLC. 2 | // All rights reserved. 
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // This program is meant to be run by gtest_test_filter_test.py. Do not run 32 | // it directly. 33 | 34 | #include "gtest/gtest.h" 35 | 36 | // These tests are used to detect if filtering is working. Only 37 | // 'TestThatSucceeds' should ever run. 
// Selected by the '*.TestThatSucceeds' pattern that gtest_testbridge_test.py
// places in TESTBRIDGE_TEST_ONLY; this is the only test expected to run.
TEST(TestFilterTest, TestThatSucceeds) {}

// Must be excluded by the filter. If it ever runs, ASSERT_TRUE(false) fails,
// which makes the driver's exit-code check (expects 0) fail.
TEST(TestFilterTest, TestThatFails) {
  ASSERT_TRUE(false) << "This test should never be run.";
}
--------------------------------------------------------------------------------
/thirdparty/googletest/googletest/test/gtest_throw_on_failure_ex_test.cc:
--------------------------------------------------------------------------------
// Copyright 2009, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | // Tests Google Test's throw-on-failure mode with exceptions enabled. 32 | 33 | #include "gtest/gtest.h" 34 | 35 | #include 36 | #include 37 | #include 38 | #include 39 | 40 | // Prints the given failure message and exits the program with 41 | // non-zero. We use this instead of a Google Test assertion to 42 | // indicate a failure, as the latter is been tested and cannot be 43 | // relied on. 44 | void Fail(const char* msg) { 45 | printf("FAILURE: %s\n", msg); 46 | fflush(stdout); 47 | exit(1); 48 | } 49 | 50 | // Tests that an assertion failure throws a subclass of 51 | // std::runtime_error. 52 | void TestFailureThrowsRuntimeError() { 53 | GTEST_FLAG_SET(throw_on_failure, true); 54 | 55 | // A successful assertion shouldn't throw. 56 | try { 57 | EXPECT_EQ(3, 3); 58 | } catch(...) { 59 | Fail("A successful assertion wrongfully threw."); 60 | } 61 | 62 | // A failed assertion should throw a subclass of std::runtime_error. 63 | try { 64 | EXPECT_EQ(2, 3) << "Expected failure"; 65 | } catch(const std::runtime_error& e) { 66 | if (strstr(e.what(), "Expected failure") != nullptr) return; 67 | 68 | printf("%s", 69 | "A failed assertion did throw an exception of the right type, " 70 | "but the message is incorrect. Instead of containing \"Expected " 71 | "failure\", it is:\n"); 72 | Fail(e.what()); 73 | } catch(...) 
{ 74 | Fail("A failed assertion threw the wrong type of exception."); 75 | } 76 | Fail("A failed assertion should've thrown but didn't."); 77 | } 78 | 79 | int main(int argc, char** argv) { 80 | testing::InitGoogleTest(&argc, argv); 81 | 82 | // We want to ensure that people can use Google Test assertions in 83 | // other testing frameworks, as long as they initialize Google Test 84 | // properly and set the thrown-on-failure mode. Therefore, we don't 85 | // use Google Test's constructs for defining and running tests 86 | // (e.g. TEST and RUN_ALL_TESTS) here. 87 | 88 | TestFailureThrowsRuntimeError(); 89 | return 0; 90 | } 91 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_xml_outfile1_test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // gtest_xml_outfile1_test_ writes some xml via TestProperty used by 31 | // gtest_xml_outfiles_test.py 32 | 33 | #include "gtest/gtest.h" 34 | 35 | class PropertyOne : public testing::Test { 36 | protected: 37 | void SetUp() override { RecordProperty("SetUpProp", 1); } 38 | void TearDown() override { RecordProperty("TearDownProp", 1); } 39 | }; 40 | 41 | TEST_F(PropertyOne, TestSomeProperties) { 42 | RecordProperty("TestSomeProperty", 1); 43 | } 44 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/gtest_xml_outfile2_test_.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2008, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. 
nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | // 30 | // gtest_xml_outfile2_test_ writes some xml via TestProperty used by 31 | // gtest_xml_outfiles_test.py 32 | 33 | #include "gtest/gtest.h" 34 | 35 | class PropertyTwo : public testing::Test { 36 | protected: 37 | void SetUp() override { RecordProperty("SetUpProp", 2); } 38 | void TearDown() override { RecordProperty("TearDownProp", 2); } 39 | }; 40 | 41 | TEST_F(PropertyTwo, TestSomeProperties) { 42 | RecordProperty("TestSomeProperty", 2); 43 | } 44 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/production.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2006, Google Inc. 2 | // All rights reserved. 
3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | // 31 | // This is part of the unit test for gtest_prod.h. 
32 | 33 | #include "production.h" 34 | 35 | PrivateCode::PrivateCode() : x_(0) {} 36 | -------------------------------------------------------------------------------- /thirdparty/googletest/googletest/test/production.h: -------------------------------------------------------------------------------- 1 | // Copyright 2006, Google Inc. 2 | // All rights reserved. 3 | // 4 | // Redistribution and use in source and binary forms, with or without 5 | // modification, are permitted provided that the following conditions are 6 | // met: 7 | // 8 | // * Redistributions of source code must retain the above copyright 9 | // notice, this list of conditions and the following disclaimer. 10 | // * Redistributions in binary form must reproduce the above 11 | // copyright notice, this list of conditions and the following disclaimer 12 | // in the documentation and/or other materials provided with the 13 | // distribution. 14 | // * Neither the name of Google Inc. nor the names of its 15 | // contributors may be used to endorse or promote products derived from 16 | // this software without specific prior written permission. 17 | // 18 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | // 31 | // This is part of the unit test for gtest_prod.h. 32 | 33 | #ifndef GOOGLETEST_TEST_PRODUCTION_H_ 34 | #define GOOGLETEST_TEST_PRODUCTION_H_ 35 | 36 | #include "gtest/gtest_prod.h" 37 | 38 | class PrivateCode { 39 | public: 40 | // Declares a friend test that does not use a fixture. 41 | FRIEND_TEST(PrivateCodeTest, CanAccessPrivateMembers); 42 | 43 | // Declares a friend test that uses a fixture. 44 | FRIEND_TEST(PrivateCodeFixtureTest, CanAccessPrivateMembers); 45 | 46 | PrivateCode(); 47 | 48 | int x() const { return x_; } 49 | private: 50 | void set_x(int an_x) { x_ = an_x; } 51 | int x_; 52 | }; 53 | 54 | #endif // GOOGLETEST_TEST_PRODUCTION_H_ 55 | --------------------------------------------------------------------------------