├── .travis.yml ├── LICENSE ├── Makefile ├── README.md ├── examples ├── .gitignore ├── README.md ├── build.sh ├── example_blink_led │ ├── README.md │ ├── src │ │ ├── led.cpp │ │ ├── led.h │ │ ├── main.cpp │ │ ├── stimer.cpp │ │ └── stimer.h │ └── tests │ │ ├── mutest_led.cpp │ │ ├── test_led.cpp │ │ ├── test_main.cpp │ │ └── test_stimer.cpp ├── example_serial_comm │ ├── src │ │ ├── serial.cpp │ │ └── serial.h │ └── tests │ │ └── test_serial.cpp ├── test_fail_pinwrite │ ├── src │ │ └── main.cpp │ └── tests │ │ └── test_main.cpp ├── test_ok_extra_fake │ ├── src │ │ ├── main.cpp │ │ └── moduleT.h │ └── tests │ │ ├── common_globals.cpp │ │ └── test_main.cpp └── test_ok_mock_manual │ ├── src │ ├── drive.cpp │ ├── moduleT.h │ └── moduleX.h │ └── tests │ ├── mocks_man │ ├── mock_moduleT.cpp │ └── mock_moduleT.h │ └── test_drive.cpp ├── run_tests.sh ├── src ├── .gitignore ├── SConscript ├── build │ └── .gitkeep ├── fakes │ ├── Arduino.cpp │ ├── Arduino.h │ ├── Serial.cpp │ ├── Servo.cpp │ ├── Servo.h │ ├── avr │ │ └── eeprom.h │ ├── crc16.cpp │ ├── eeprom.cpp │ └── util │ │ └── crc16.h ├── mocks_gen │ └── .gitkeep ├── mocks_man │ └── .gitkeep ├── tmain.cpp └── tools │ ├── Makefile │ ├── SConstruct │ ├── __init__.py │ ├── catch.hpp │ ├── cpp │ ├── __init__.py │ ├── ast.py │ ├── find_warnings.py │ ├── headers.py │ ├── keywords.py │ ├── metrics.py │ ├── nonvirtual_dtors.py │ ├── static_data.py │ ├── symbols.py │ ├── tokenize.py │ └── utils.py │ ├── fff.h │ ├── fffmock.py │ ├── run_all.sh │ └── run_coverage.sh └── website ├── screen_coverage.png └── screen_debug.png /.travis.yml: -------------------------------------------------------------------------------- 1 | 2 | language: cpp 3 | sudo: required 4 | dist: trusty 5 | 6 | before_install: 7 | - sudo apt-get install scons lcov 8 | 9 | compiler: 10 | - gcc 11 | 12 | os: 13 | - linux 14 | 15 | script: ./run_tests.sh 16 | 17 | notifications: 18 | email: false 19 | 20 | env: 21 | - LANG="en_US.UTF-8" 22 | 23 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Pauli Salmenrinne 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | run_test: 3 | ./run_tests.sh 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | What it is 3 | ==================== 4 | Unit-testing 'framework' for Arduino projects; it currently works with platform.io projects. It generates host-PC binaries that contain the tests and can be run and debugged, with no flashing required. Super fast.
I call it a framework since it provides building, mock generation and testing in a single package, and it requires you to copy a folder into your project. 5 | 6 | 7 | Why use it instead of other testing frameworks 8 | ==================== 9 | 10 | * Automatic generation of mock/stub functions from the source files' headers, and automatic compiling of test binaries with the specified units as real implementations. 11 | * Fakes for the basic Arduino functions (so that compiling works; this is not an emulator). 12 | * Compiles to a host-PC runnable binary that you can easily debug with your preferred IDE or plain gdb. 13 | * Batteries included: a test framework and mocking, both picked from lightweight alternatives (Catch and FFF). 14 | 15 | Examples of why to unit test your Arduino code 16 | ==================== 17 | * A timer class that has to deal with overflow. Manual testing is impossible, and testing on real hardware is doable, but you somehow have to make millis() return the proper test values. 18 | * Create a separate My_Timer.h module that is the unit under test; other modules (including the millis() function) are mocked, and you can define their return values as you please. 19 | * Serial communication with another device, where you want to make sure the output looks proper. 20 | * Create a separate My_Device.h module that is the unit under test, make Serial.read() return the input you want, and check that the Serial.write() output is proper. 21 | * Stepper motor driving with a specified ramp: assume you are controlling a stepper motor that speeds up with a smooth ramp and travels a certain number of steps. You want to make sure that the ramp has no glitches and looks proper. 22 | * Create a My_Driver.h module that contains the driving functions and a My_Stepper.h module that contains the actual stepper-controlling functions. Now, when testing module My_Driver, you can make MyStepper.set_steps() write its output to a file, and analyze the file with, say, Python to find possible spikes.
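The timer case above comes down to making the clock controllable from the test. As a minimal sketch (not the framework's STimer class, and assuming a 32-bit millisecond counter like millis() on AVR boards), the check below relies on wrap-around unsigned subtraction and reads the time through a hypothetical `fake_millis()` that the test sets directly, much like `millis_fake.return_val` in the examples:

```cpp
#include <cassert>  // assert() for the usage checks
#include <cstdint>

// Hypothetical stand-in for a mocked millis(): the test sets the value directly.
static uint32_t fake_millis_value = 0;
static uint32_t fake_millis() { return fake_millis_value; }

// Hypothetical timer that takes its clock as a function pointer so tests can inject it.
class SketchTimer {
public:
    explicit SketchTimer(uint32_t (*clock)()) : clock_(clock), start_(clock()) {}

    void reset() { start_ = clock_(); }

    // Unsigned subtraction wraps modulo 2^32, so the elapsed time stays correct
    // even when the counter overflows between reset() and check().
    bool check(uint32_t timeout_ms) const {
        return static_cast<uint32_t>(clock_() - start_) >= timeout_ms;
    }

private:
    uint32_t (*clock_)();
    uint32_t start_;
};
```

Starting the timer at `UINT32_MAX - 5` and then setting the fake clock to `4` gives an elapsed time of 10 ms, so `check(10)` returns true even though the raw counter value went down. Function-pointer injection is only one option; with this framework the generated `millis_fake` plays that role without touching the production class.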
23 | 24 | Usage 25 | ==================== 26 | 27 | First you might want to look at the examples folder: the [blink led](examples/example_blink_led) example is currently the best way to get a quick overview of what this is about. 28 | 29 | 30 | ### Set up the framework for your project: 31 | 32 | Currently you need to copy this project's 'src' folder into your project. This is because SCons is not very happy with out-of-tree builds. Sorry. 33 | 34 | 1. Copy the 'src' folder as a 'tests' folder into your project. 35 | 2. Add building of the tests to your build manager: it works by executing ```scons -Y tests/tools/``` in the project's base path. 36 | * There is also an example Makefile in the tests/tools/ directory. 37 | 38 | ### Start writing unit tests 39 | 40 | These tests are linked so that only the unit under test uses its real source file; all other modules are mocked. See below for more info. 41 | 42 | 1. Create a test file called ```tests/test_.cpp``` 43 | 44 | That is, if you have ```src/foo.cpp``` and ```src/bar.cpp```, the test file for foo should 45 | be called ```tests/test_foo.cpp```, and a binary called ```tests/build/bin/test_foo``` will be created after compiling is done. 46 | 47 | The test binary will have the real implementation from ```src/foo.cpp```, but module Bar will be implemented as a mock, from ```tests/mocks_gen/mock_bar.cpp```. 48 | The Arduino-specific stuff is mostly mocked away, but some extra checking is available. See [src/fakes/Arduino.h](src/fakes/Arduino.h) and 'Fakes provided' later in this document. 49 | 50 | 51 | ### Do some multi-unit tests 52 | 53 | These binaries will be called ```tests/build/bin/mutest_xxx```, and they will contain the real implementations from the files listed. Other modules will be mocked away. 54 | 55 | 1. Create a test file called ```tests/mutest_.cpp``` 56 | 2. In that file, create a line containing '__UNITTEST__SOURCES_ = '. For example, if you have ```src/foo.cpp``` and ```src/bar.cpp```, this line should be ```__UNITTEST__SOURCES_ = foo.cpp, bar.cpp```.
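The actual source discovery is done by the SCons scripts under tests/tools/, which are not shown here. Purely as a hypothetical C++ illustration of the rule in step 2 (the function name and the whitespace trimming are assumptions, not the framework's code), pulling the list of real sources out of a mutest file could look like this:

```cpp
#include <cassert>  // assert() for the usage checks
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: returns the comma-separated file names that follow
// "__UNITTEST__SOURCES_ =" on the same line, or an empty list if the marker is absent.
std::vector<std::string> parse_unittest_sources(const std::string& text) {
    const std::string marker = "__UNITTEST__SOURCES_";
    std::vector<std::string> sources;

    std::size_t pos = text.find(marker);
    if (pos == std::string::npos) return sources;
    pos = text.find('=', pos + marker.size());
    if (pos == std::string::npos) return sources;

    // Take everything after '=' up to the end of that line.
    const std::size_t end = text.find('\n', pos);
    const std::string list = text.substr(
        pos + 1, end == std::string::npos ? std::string::npos : end - pos - 1);

    std::stringstream ss(list);
    std::string item;
    while (std::getline(ss, item, ',')) {
        // Trim surrounding whitespace; skip empty entries.
        const std::size_t first = item.find_first_not_of(" \t\r");
        if (first == std::string::npos) continue;
        const std::size_t last = item.find_last_not_of(" \t\r");
        sources.push_back(item.substr(first, last - first + 1));
    }
    return sources;
}
```

Fed the comment line from mutest_led.cpp (` * __UNITTEST__SOURCES_ = led.cpp, stimer.cpp`), this returns `led.cpp` and `stimer.cpp`, the two modules that binary links as real units.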
57 | 58 | The minimal test file (and this is enough! It compiles to a test binary!) would be something like: 59 | ``` 60 | #include "catch.hpp" 61 | #include "Arduino.h" 62 | 63 | #include "led.h" 64 | #include "mock_stimer.h" // generated timer mock, so the test can control STimer::check() 65 | 66 | TEST_CASE( "Led blinking works", "[led]" ) 67 | { 68 | StatusToLed led; 69 | led.setup(1); 70 | REQUIRE( digitalWrite_fake.call_count == 1); 71 | STimer__check_fake.return_val = true; // pretend the blink interval has elapsed 72 | led.loop(); 73 | REQUIRE( digitalWrite_fake.call_count == 2); 74 | } 75 | ``` 76 | ### Extra: Define common test modules 77 | 78 | These common files will be included in all tests. They can be used to hold definitions of global variables (the counterparts of ```extern Foo XXX``` declarations) or just common test code (```initialize_output_csv()```). 79 | 1. Create files ```tests/common_.cpp``` as needed. 80 | 81 | ### Extra: Define manual mocks 82 | 83 | For manually mocked modules no 'automatic' mocks are generated. To define a module as manually mocked: 84 | 1. Create ```tests/mocks_man/mock_.cpp```; that is, if you have ```src/foo.cpp```, create the file ```tests/mocks_man/mock_foo.cpp``` that contains the mocked functions. 85 | 86 | 87 | Fakes provided 88 | ==================== 89 | 90 | * Arduino pin functionality. It's defined in ```src/fakes/Arduino.h``` and implemented in the corresponding .cpp. It provides simple checks (can be disabled) that no uninitialized pin gets read or written. 91 | * Serial functionality that stores the written lines in a buffer and serves received input from a buffer set by the test. 92 | 93 | 94 | Examples 95 | ==================== 96 | * See the [examples](examples) directory, which also works as tests for this test framework.
97 | * See my [aquarium feeder](https://github.com/susundberg/arduino-aquarium-feeder) project for a full platformio example. 98 | 99 | #### Screenshot: Debugger on Arduino code 100 | ![screenshot debugger](https://rawgit.com/susundberg/arduino-simple-unittest/master/website/screen_debug.png) 101 | 102 | #### Screenshot: Coverage of the tests in Firefox (generates HTML) 103 | ![screenshot coverage](https://rawgit.com/susundberg/arduino-simple-unittest/master/website/screen_coverage.png) 104 | 105 | 106 | Testing of the test framework 107 | ==================== 108 | * Tested with Travis CI on Ubuntu Trusty ![build status](https://travis-ci.org/susundberg/arduino-simple-unittest.svg?branch=master) 109 | 110 | What does the heavy lifting 111 | ==================== 112 | * [FFF](https://github.com/meekrosoft/fff) to make mock functions 113 | * [CPPclean](https://github.com/myint/cppclean/) to parse the sources to find which functions to mock 114 | * [Catch](https://github.com/philsquared/Catch) to run tests 115 | * [Scons](http://scons.org/) to run the build scripts 116 | * [Lcov](http://ltp.sourceforge.net/coverage/lcov.php) to generate the coverage report. 117 | 118 | 119 | TODO 120 | ==================== 121 | 122 | ### Structure: build tree 123 | I would like to have the tests/ folder out of the source tree, but currently that would require some hacking to make SCons support out-of-tree builds, or a change of build system. 124 | 125 | ### Structure: separate platforms 126 | I am working on ESP8266 support, and it would be neat to have it somewhat separated from the basic Arduino tree, with the platform used configured in tests/test_config.ini or similar.
127 | 128 | ### Mock generation: C++ 129 | * C++: Support for static member functions 130 | * C++: Support for overloaded functions 131 | * C++: Support for references in parameters 132 | * C++: Support for types that are defined inside classes/structures 133 | 134 | ### Arduino IDE 135 | As said before, this currently works only with platformio projects, so it does not work with Arduino IDE code, but I welcome pull requests to make it happen. With the Arduino IDE one still needs the source files to be split into multiple files (for mock generation). 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | 3 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | Examples 2 | ================== 3 | 4 | 5 | ### example_blink_led 6 | 7 | Features: 8 | * Multi-unit test example: see mutest_led.cpp. It uses two modules as real implementations and mocks the others. See test_led.cpp for the same test with all other modules mocked. 9 | * ARDUINO_TEST hookup, providing a check that we do not write to input or uninitialized pins. 10 | * Class mock generation: mocks are generated for both classes. The LED class (StatusToLed) mock is used in test_main.cpp. 11 | 12 | 13 | ### example_serial_comm 14 | Features: 15 | * Uses the Serial fake object to test the serial communication 16 | * Uses variable arguments 17 | 18 | ### test_fail_pinwrite 19 | * Tests and demonstrates that a pin write to an uninitialized pin fails (been there, done that: debugging hardware for a software bug). 20 | 21 | ### test_ok_extra_fake 22 | * Demonstrates how to provide common sources for all tests.
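The pin checks behind `ARDUINO_TEST.hookup()` live in src/fakes/Arduino.cpp, which is not reproduced here. As a hypothetical sketch of the idea only (the class and method names are invented, and it merely counts problems, whereas the real hookup makes the test binary fail, as build.sh's `assume_fail` run expects), a fake can remember each pin's mode and flag writes to pins that were never set as outputs:

```cpp
#include <cassert>  // assert() for the usage checks
#include <map>

// Hypothetical pin modes; the real fake uses the INPUT/OUTPUT constants from Arduino.h.
enum class SketchPinMode { Input, Output };

// Hypothetical fake that remembers each pin's mode and counts invalid writes.
class SketchPinChecker {
public:
    void pin_mode(int pin, SketchPinMode mode) { modes_[pin] = mode; }

    // A write to a pin that was never configured, or that is configured as an
    // input, is recorded as an error instead of silently succeeding.
    void digital_write(int pin, int /*value*/) {
        auto it = modes_.find(pin);
        if (it == modes_.end() || it->second != SketchPinMode::Output) {
            ++invalid_writes_;
        }
    }

    int invalid_writes() const { return invalid_writes_; }

private:
    std::map<int, SketchPinMode> modes_;
    int invalid_writes_ = 0;
};
```

Replaying the calls from test_fail_pinwrite (pin 12 configured as output, pin 10 written twice) records two invalid writes, while a later write to pin 12 is accepted.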
23 | 24 | ### test_ok_mock_manual 25 | * Demonstrates how to provide mocked methods manually, for example if you want to do some extra simulation. 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /examples/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | set -e 5 | 6 | usage_exit() 7 | { 8 | echo "$0 " 9 | exit 1 10 | } 11 | 12 | 13 | if [ -z "$1" ]; then 14 | usage_exit; 15 | fi 16 | 17 | 18 | 19 | SOURCE_DIR=$1 20 | TARGET_DIR="build/$1" 21 | BASE_PATH=$(pwd) 22 | 23 | 24 | if [[ $2 != *"keep"* ]]; then 25 | echo "Build clean" 26 | rm -rf "$TARGET_DIR" 27 | fi 28 | 29 | 30 | mkdir -p "$TARGET_DIR" 31 | cp -a "../src/" "$TARGET_DIR/tests/" 32 | cp -a "$SOURCE_DIR/src" "$TARGET_DIR" 33 | cp -a "$SOURCE_DIR/tests" "$TARGET_DIR" 34 | cd "$TARGET_DIR" 35 | PATCH_FILE="../../$SOURCE_DIR.patch" 36 | if [ -e "$PATCH_FILE" ]; then 37 | patch -p0 < "$PATCH_FILE" 38 | fi 39 | 40 | if [[ $2 == *"assume_fail"* ]]; then 41 | echo "Fail run, assume build ok, running fails" 42 | make -f tests/tools/Makefile test_build 43 | for tb in tests/build/bin/*; do 44 | set +e 45 | ./$tb > /dev/null # Hide the output, as it is hideous since the test fails 46 | ret=$? 47 | set -e 48 | if [[ $ret == 0 ]];then 49 | echo "The test '$tb' did not fail, though it should!" 50 | exit 1 # 'return' is only valid inside a function or a sourced script 51 | fi; 52 | done; 53 | else 54 | echo "Normal run - assume all ok" 55 | make -f tests/tools/Makefile test_build test_run test_coverage 56 | fi 57 | 58 | echo "ALL DONE, BYE BYE" 59 | -------------------------------------------------------------------------------- /examples/example_blink_led/README.md: -------------------------------------------------------------------------------- 1 | This is a simple example of mocking usage. 2 | 3 | There are some sources in the 'src' folder -- these would be the real implementations that provide the final image going to your microcontroller.
4 | 5 | In the tests folder, on the other hand, are the unit test sources -- these are the tests that you wrote to verify that the modules work as you wanted, and that they work together as you wanted. 6 | 7 | Unfortunately the example is quite small, so you don't really see the true gain of unit testing; you need a bigger and more complex project for that, and such projects are not good examples of the test framework. 8 | 9 | -------------------------------------------------------------------------------- /examples/example_blink_led/src/led.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "led.h" 3 | 4 | 5 | 6 | 7 | void StatusToLed::setup(int pin) 8 | { 9 | this->blink_loop = 0; 10 | this->blinks = 0; 11 | this->pin = pin; 12 | pinMode( pin, OUTPUT ); 13 | digitalWrite( pin, 1 ); 14 | this->timer.reset(); 15 | } 16 | 17 | void StatusToLed::loop() 18 | { 19 | int interval ; 20 | if ( this->blink_loop == 0 ) 21 | { 22 | interval = INTERVAL_LONG; 23 | } 24 | else 25 | { 26 | interval = INTERVAL_SHORT; 27 | } 28 | 29 | if ( this->timer.check( interval ) == false ) 30 | return; 31 | 32 | 33 | this->timer.reset(); 34 | this->blink_loop += 1; 35 | 36 | if ( this->blink_loop >= 2*blinks ) // accept equal, since with one blink we have long-short, long-short 37 | this->blink_loop = 0; 38 | 39 | digitalWrite( pin, (this->blink_loop&0x01) ); 40 | 41 | } 42 | 43 | void StatusToLed::set_status(int blinks) 44 | { 45 | if ( this->blinks == blinks ) 46 | return; 47 | 48 | this->blink_loop = 0; 49 | this->blinks = blinks; 50 | } 51 | -------------------------------------------------------------------------------- /examples/example_blink_led/src/led.h: -------------------------------------------------------------------------------- 1 | #ifndef SUPA_LED_H 2 | #define SUPA_LED_H 3 | 4 | #include "stimer.h" 5 | 6 | /** Simple class that blinks the LED in a cycle of long-short*N, where N is the given status */ 7 | class
StatusToLed 8 | { 9 | constexpr static const int INTERVAL_SHORT = 200; // ms 10 | constexpr static const int INTERVAL_LONG = 700; 11 | 12 | public: 13 | void setup( int pin ); /// What @param pin to use for output 14 | void loop(); 15 | void set_status( int blinks ); 16 | 17 | protected: 18 | STimer timer; 19 | int pin; 20 | int blinks; 21 | int blink_loop; 22 | }; 23 | 24 | #endif 25 | -------------------------------------------------------------------------------- /examples/example_blink_led/src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "led.h" 4 | 5 | StatusToLed LED; 6 | 7 | void setup() 8 | { 9 | LED.setup(13); 10 | LED.set_status( 2 ); 11 | } 12 | 13 | void loop() 14 | { 15 | LED.loop(); 16 | } -------------------------------------------------------------------------------- /examples/example_blink_led/src/stimer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "stimer.h" 4 | 5 | STimer::STimer() 6 | { 7 | this->reset_ms = 0xFFFFFFFF; 8 | } 9 | 10 | void STimer::reset() 11 | { 12 | this->reset_ms = millis(); 13 | } 14 | 15 | void STimer::reset_with_carry( unsigned long timeout_ms ) 16 | { 17 | // this may overflow, but that is ok! 18 | this->reset_ms = reset_ms + timeout_ms ; 19 | } 20 | 21 | bool STimer::check(unsigned long timeout) 22 | { 23 | unsigned long target_time = this->reset_ms + timeout; 24 | unsigned long current_time = millis(); 25 | 26 | // has the current time overflown: 27 | if ( current_time < this->reset_ms ) 28 | { 29 | // did the target time overflowed 30 | if ( this->reset_ms < target_time ) 31 | { // no, it did not -> we are way over. 32 | return true; 33 | } 34 | else 35 | { // yes its overflown as well, normal functionality. 36 | return ( current_time >= target_time ); 37 | } 38 | } 39 | else 40 | { // timer has not overflown, how about the target? 
41 | if ( this->reset_ms < target_time ) 42 | { // no overflow here either. Normal business 43 | return ( current_time >= target_time ); 44 | } 45 | else 46 | { // the target has overflown but the current time has not, so we are not there yet. 47 | return false; 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /examples/example_blink_led/src/stimer.h: -------------------------------------------------------------------------------- 1 | #ifndef SUPA_STIMER_H 2 | #define SUPA_STIMER_H 3 | 4 | /** Simple timer class that handles overflow properly. All units are milliseconds.*/ 5 | class STimer 6 | { 7 | public: 8 | STimer(); 9 | void reset(); // Use all time 10 | void reset_with_carry( unsigned long timeout_ms ); // Use given amount of time 11 | bool check( unsigned long timeout_ms ); /// check if @param timeout_ms milliseconds have passed since the last reset. 12 | 13 | protected: 14 | unsigned long reset_ms; 15 | }; 16 | 17 | 18 | 19 | #endif -------------------------------------------------------------------------------- /examples/example_blink_led/tests/mutest_led.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Unit test file for led.h. Since the functionality requires the timer, let's not fake it but rather 3 | * use it as a real unit 4 | * 5 | * __UNITTEST__SOURCES_ = led.cpp, stimer.cpp 6 | * 7 | */ 8 | 9 | #include "catch.hpp" 10 | #include "Arduino.h" 11 | 12 | #include "led.h" // include the unit under test 13 | #include "mock_stimer.h" // not needed for the fakes here: stimer.cpp is a real unit (see __UNITTEST__SOURCES_), so time is driven through millis_fake instead 14 | 15 | 16 | void run_loop( StatusToLed* led ) 17 | { 18 | for ( int loop = 0; loop < 100; loop ++ ) 19 | { 20 | led->loop(); 21 | } 22 | } 23 | 24 | 25 | TEST_CASE( "Led blinking works", "[led]" ) 26 | { 27 | StatusToLed led; 28 | 29 | ARDUINO_TEST.hookup(); 30 | led.setup( 10 ); 31 | led.set_status( 2 ); 32 | 33 | SECTION("It runs") 34 | { 35 | run_loop(&led); 36 | REQUIRE(
digitalWrite_fake.call_count == 1); 37 | 38 | for ( int loop = 0; loop < 10; loop ++ ) 39 | { 40 | digitalWrite_fake.call_count = 0; 41 | millis_fake.return_val += 1000; // advance the faked clock so that the STimer class behaves as if 1 s has passed since the last call. 42 | run_loop(&led); 43 | REQUIRE( digitalWrite_fake.call_count == 1); 44 | REQUIRE( digitalWrite_fake.arg1_val == ((loop+1)%2) ); 45 | } 46 | } 47 | } 48 | 49 | -------------------------------------------------------------------------------- /examples/example_blink_led/tests/test_led.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Unit test file for led. In this example the timer module is faked as well. 3 | * 4 | */ 5 | 6 | 7 | #include "catch.hpp" 8 | #include "Arduino.h" 9 | 10 | #include "led.h" // include the unit under test 11 | #include "mock_stimer.h" // include the faked module (so we can set the return values) 12 | 13 | 14 | 15 | void run_loop( StatusToLed* led ) 16 | { 17 | for ( int loop = 0; loop < 100; loop ++ ) 18 | { 19 | led->loop(); 20 | } 21 | } 22 | 23 | 24 | TEST_CASE( "Led blinking works", "[led]" ) 25 | { 26 | StatusToLed led; 27 | 28 | ARDUINO_TEST.hookup(); 29 | led.setup( 10 ); 30 | STimer__check_fake.return_val = false; 31 | 32 | SECTION("It runs") 33 | { 34 | run_loop(&led); 35 | REQUIRE( digitalWrite_fake.call_count == 1); 36 | 37 | digitalWrite_fake.call_count = 0; 38 | STimer__check_fake.return_val = true; // from now on the module sees every interval as already elapsed 39 | run_loop(&led); 40 | REQUIRE( digitalWrite_fake.call_count == 100); 41 | } 42 | } -------------------------------------------------------------------------------- /examples/example_blink_led/tests/test_main.cpp: -------------------------------------------------------------------------------- 1 | #include "catch.hpp" 2 | 3 | 4 | #include "mock_led.h" // get the led _fake objects 5 | 6 | 7 | extern void loop(); 8 | extern void setup();
9 | 10 | TEST_CASE( "Main functionality", "[main]" ) 11 | { 12 | 13 | RESET_FAKE( StatusToLed__setup ); 14 | RESET_FAKE( StatusToLed__loop ); 15 | RESET_FAKE( StatusToLed__set_status ); 16 | setup(); 17 | loop(); 18 | loop(); 19 | loop(); 20 | REQUIRE( StatusToLed__loop_fake.call_count == 3); 21 | 22 | } -------------------------------------------------------------------------------- /examples/example_blink_led/tests/test_stimer.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Example test case for the timer module; checks that the overflow behavior is what it is expected to be. 3 | * 4 | * All other modules are mocked. 5 | * 6 | */ 7 | 8 | #include 9 | #include "catch.hpp" 10 | #include "Arduino.h" 11 | 12 | 13 | #include "stimer.h" 14 | 15 | TEST_CASE( "Basic timer", "[timer]" ) 16 | { 17 | STimer timer; 18 | 19 | SECTION("Normal functionality") 20 | { 21 | REQUIRE( timer.check( 10 ) == true ); 22 | timer.reset(); 23 | REQUIRE( timer.check( 10 ) == false ); 24 | millis_fake.return_val = 1 << 10; 25 | REQUIRE( timer.check( 10 ) == true ); 26 | } 27 | SECTION("Overflow") 28 | { 29 | 30 | STimer timer; 31 | 32 | millis_fake.return_val = ULLONG_MAX - 1; 33 | timer.reset(); 34 | millis_fake.return_val = ULLONG_MAX ; 35 | REQUIRE( timer.check( 10 ) == false ); 36 | 37 | millis_fake.return_val = 10 ; 38 | REQUIRE( timer.check( 10 ) == true ); 39 | 40 | // Then check the case where the clock wraps all the way around between checks.
41 | timer.reset(); 42 | REQUIRE( timer.check( 10 ) == false ); 43 | millis_fake.return_val = 0 ; // and the timer went overflow without any check in between 44 | REQUIRE( timer.check( 10 ) == true ); 45 | } 46 | 47 | SECTION("Reset with carry normal and overflow") 48 | { 49 | const unsigned long start_value[] = { 0, ULLONG_MAX - 500 }; 50 | for ( int start_loop = 0; start_loop < 2; start_loop ++ ) 51 | { 52 | millis_fake.return_val = start_value[start_loop]; 53 | timer.reset(); 54 | REQUIRE( timer.check( 500 ) == false); 55 | millis_fake.return_val += 1000; 56 | REQUIRE( timer.check( 500 ) == true ); 57 | timer.reset_with_carry(450); 58 | REQUIRE( timer.check( 500 ) == true ); 59 | } 60 | } 61 | 62 | }; 63 | 64 | -------------------------------------------------------------------------------- /examples/example_serial_comm/src/serial.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include 5 | #include "serial.h" 6 | 7 | void serial_setup( const char* name) 8 | { 9 | Serial.begin(9600); 10 | Serial.write("***********************************\n"); 11 | serial_print("Welcome to %s!\n", name ); 12 | Serial.write("***********************************\n"); 13 | } 14 | 15 | 16 | bool serial_process_number( const char* buffer, int buffer_n, int* convert ) 17 | { 18 | char* end_ptr; 19 | int value = strtol( (const char*)buffer, &end_ptr, 10 ); 20 | if ( end_ptr != buffer + buffer_n ) 21 | { 22 | Serial.write(" Invalid: '" ); 23 | Serial.write( (const char*)buffer); 24 | Serial.write("'\n"); 25 | return false; 26 | } 27 | 28 | *convert = value; 29 | return true; 30 | } 31 | 32 | 33 | 34 | char* serial_receive( int* buffer_len) 35 | { 36 | static char buffer[64]; 37 | static int loop = 0; 38 | 39 | // if there's any serial available, read it: 40 | while (Serial.available() > 0) 41 | { 42 | // do it again: 43 | buffer[ loop ] = Serial.read(); 44 | 45 | if (buffer[loop] == '\n') 46 | { 47 | buffer[loop] = 0x00; 48 | 
*buffer_len = loop; 49 | loop = 0; 50 | return buffer; 51 | } 52 | 53 | loop += 1; 54 | if ( loop >= 64 ) 55 | { 56 | Serial.write("E: Too long line\n"); 57 | loop = 0; 58 | } 59 | 60 | } 61 | return NULL; 62 | } 63 | 64 | void serial_print( const char* format, ... ) 65 | { 66 | va_list arg_list; 67 | va_start(arg_list, format); 68 | 69 | int full_len = strlen( format ); 70 | int loop_offset = 0; 71 | int loop; 72 | 73 | for ( loop = 0; loop < full_len; loop ++ ) 74 | { 75 | if (format[loop] != '%') 76 | continue; 77 | 78 | Serial.write( format + loop_offset, loop - loop_offset ); 79 | 80 | if ( loop == full_len - 1 ) // '%' is the last character: no specifier follows 81 | { 82 | Serial.write("\nERROR_INVALID_FORMAT\n"); 83 | va_end(arg_list); return; 84 | } 85 | 86 | char output_format = format[loop + 1]; 87 | if (output_format == 'd' ) 88 | { 89 | int value = va_arg( arg_list, int ); 90 | Serial.print( value ); 91 | } 92 | else if (output_format == 'f' ) 93 | { 94 | double value = va_arg( arg_list, double ); 95 | Serial.print( value ); 96 | } 97 | else if (output_format == 's' ) 98 | { 99 | const char* value = (char*)va_arg( arg_list, void* ); 100 | Serial.write( value ); 101 | } 102 | else 103 | { 104 | Serial.write("\nERROR_INVALID_FORMAT\n"); 105 | va_end(arg_list); return; 106 | } 107 | // format[loop] and format[loop+1] were the '%x' specifier; continue after it 108 | loop_offset = loop + 2; 109 | loop = loop + 1; 110 | } 111 | va_end(arg_list); 112 | if ( loop_offset < full_len ) 113 | { 114 | Serial.write( format + loop_offset, full_len - loop_offset ); 115 | } 116 | 117 | } 118 | 119 | int serial_receive_number(int min_value, int max_value) 120 | { 121 | int buffer_n = 0; 122 | int number; 123 | bool print_prompt = true; 124 | 125 | while ( true ) 126 | { 127 | if ( print_prompt ) 128 | { 129 | Serial.write(">"); 130 | print_prompt = false; 131 | } 132 | char* buffer = serial_receive( &buffer_n ); 133 | if ( buffer == NULL ) 134 | continue; 135 | Serial.write(buffer); 136 | Serial.write("\n"); 137 | 138 | print_prompt = true; 139 | 140 | if ( serial_process_number( buffer, buffer_n , &number ) ==
false ) 141 | continue; 142 | 143 | if ( number < min_value ) 144 | { 145 | serial_print("Too small number. Min %d\n", min_value ); 146 | } 147 | else if ( number > max_value) 148 | { 149 | serial_print("Too large number. Max %d\n", max_value ); 150 | } 151 | else 152 | { 153 | return number; 154 | } 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /examples/example_serial_comm/src/serial.h: -------------------------------------------------------------------------------- 1 | #ifndef SUPA_SERIAL_H 2 | #define SUPA_SERIAL_H 3 | 4 | #include 5 | 6 | void serial_setup(const char* name); 7 | 8 | char* serial_receive( int* buffer_len ); 9 | int serial_receive_number( int min_value, int max_value ); 10 | void serial_print( const char* format, ... ); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /examples/example_serial_comm/tests/test_serial.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "catch.hpp" 4 | #include "serial.h" 5 | #include 6 | #include 7 | 8 | #include "Arduino.h" 9 | 10 | TEST_CASE( "input works", "[serial]" ) 11 | { 12 | Serial._test_clear(); 13 | 14 | SECTION( "plain string" ) 15 | { 16 | int len = 0; 17 | char* buffer; 18 | buffer = serial_receive( &len ); 19 | REQUIRE( buffer == NULL ); 20 | Serial._test_set_input("setup"); 21 | buffer = serial_receive( &len ); 22 | REQUIRE( buffer == NULL ); 23 | Serial._test_set_input("\n"); 24 | buffer = serial_receive( &len ); 25 | REQUIRE( buffer != NULL ); 26 | REQUIRE( len == 5 ); 27 | REQUIRE( std::string(buffer) == "setup" ); 28 | } 29 | 30 | SECTION( "input number" ) 31 | { 32 | Serial._test_set_input("Invalid\n-3\n10\n0\n"); 33 | int ret = serial_receive_number( -1, 1); 34 | REQUIRE( ret == 0 ); 35 | REQUIRE( Serial._test_input_buffer.size() == 0 ); 36 | } 37 | } 38 | 39 | TEST_CASE( "prints some format", "[serial]" ) 40 | { 41 | 
Serial._test_clear(); 42 | REQUIRE( Serial._test_output_buffer.size() == 0); 43 | 44 | 45 | SECTION( "strings" ) 46 | { 47 | serial_print("HELLO %s\n", "WORLD"); 48 | REQUIRE( Serial._test_output_buffer.size() == 1); 49 | REQUIRE( Serial._test_output_buffer.front() == "HELLO WORLD" ); 50 | } 51 | 52 | SECTION("setup") 53 | { 54 | serial_setup( "TESTFOO" ); 55 | } 56 | 57 | SECTION( "ints and floats" ) 58 | { 59 | 60 | serial_print("FOO %d %d\nMORE", 10, 20 ); 61 | REQUIRE( Serial._test_output_buffer.size() == 1); 62 | REQUIRE( Serial._test_output_buffer.front() == "FOO 10 20" ); 63 | Serial._test_output_buffer.pop(); 64 | serial_print(" %f TOCOME\n", 0.15 ); 65 | REQUIRE( Serial._test_output_buffer.size() == 1); 66 | REQUIRE( Serial._test_output_buffer.front() == "MORE 0.15 TOCOME" ); 67 | } 68 | 69 | } 70 | -------------------------------------------------------------------------------- /examples/test_fail_pinwrite/src/main.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Simple test to check that invalid pinwrite does fail 3 | */ 4 | 5 | #include 6 | 7 | 8 | void setup() 9 | { 10 | pinMode( 12, OUTPUT ); 11 | } 12 | 13 | void loop() 14 | { 15 | digitalWrite( 10, 1 ); 16 | delay(1000); 17 | digitalWrite( 10, 0 ); 18 | } 19 | 20 | -------------------------------------------------------------------------------- /examples/test_fail_pinwrite/tests/test_main.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include "catch.hpp" 5 | 6 | #include "Arduino.h" 7 | 8 | extern void loop(); 9 | extern void setup(); 10 | 11 | TEST_CASE( "__ASSUME__FAIL__", "[main]" ) 12 | { 13 | ARDUINO_TEST.hookup(); 14 | setup(); 15 | loop(); 16 | 17 | } 18 | 19 | -------------------------------------------------------------------------------- /examples/test_ok_extra_fake/src/main.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 
"moduleT.h" 3 | 4 | void loop() 5 | { 6 | MODULE_T.hello_world(); 7 | } 8 | -------------------------------------------------------------------------------- /examples/test_ok_extra_fake/src/moduleT.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | class ModuleT 5 | { 6 | public: 7 | void hello_world(); 8 | }; 9 | 10 | extern ModuleT MODULE_T; -------------------------------------------------------------------------------- /examples/test_ok_extra_fake/tests/common_globals.cpp: -------------------------------------------------------------------------------- 1 | /** Here we define global variables that would otherwise be left 2 | * as linker errors (undefined references), as the generated fakes do not contain these global variables. 3 | * 4 | */ 5 | 6 | #include "moduleT.h" 7 | 8 | 9 | ModuleT MODULE_T; -------------------------------------------------------------------------------- /examples/test_ok_extra_fake/tests/test_main.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include "catch.hpp" 5 | 6 | extern void loop(); 7 | 8 | TEST_CASE( "__ASSUME__OK__", "[main]" ) 9 | { 10 | loop(); 11 | } 12 | 13 | -------------------------------------------------------------------------------- /examples/test_ok_mock_manual/src/drive.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "moduleT.h" 3 | #include "moduleX.h" 4 | 5 | void loop() 6 | { 7 | ModuleT modt; 8 | ModuleX modx; 9 | modx.hello_world(); 10 | 11 | for ( int step = 0; step < 1000; step ++ ) 12 | { 13 | modt.drive_step_motor( step ); 14 | } 15 | 16 | } 17 | -------------------------------------------------------------------------------- /examples/test_ok_mock_manual/src/moduleT.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | class ModuleT 5 | { 6 | public: 7 | void drive_step_motor( int steps ); 8 | }; 9 |
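The `test_ok_extra_fake` example above relies on `tests/common_globals.cpp` to provide the storage for `MODULE_T`. The generated fakes stub out the methods of `ModuleT`, but according to that file's comment they do not define the global object. The single-file sketch below puts the three pieces into one translation unit so the pattern can be seen in isolation: the header declaration, a faked method, and the global definition. The counter `hello_world_calls` is a hypothetical stand-in for a generated fake and is not part of the repository.

```cpp
#include <cassert>

// Interface, as in examples/test_ok_extra_fake/src/moduleT.h.
class ModuleT
{
public:
    void hello_world();
};

// The header only declares the global object; it provides no storage.
extern ModuleT MODULE_T;

// Hypothetical stand-in for the generated fake of ModuleT::hello_world:
// it just counts calls so the effect of loop() can be checked.
static int hello_world_calls = 0;
void ModuleT::hello_world() { ++hello_world_calls; }

// The role played by tests/common_globals.cpp: this is the one definition
// that the extern declaration above refers to.
ModuleT MODULE_T;

// Code under test, same shape as examples/test_ok_extra_fake/src/main.cpp.
void loop() { MODULE_T.hello_world(); }
```

Calling `loop()` twice leaves `hello_world_calls` at 2. In the repository layout, leaving out `common_globals.cpp` means `MODULE_T` has no definition at link time.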
-------------------------------------------------------------------------------- /examples/test_ok_mock_manual/src/moduleX.h: -------------------------------------------------------------------------------- 1 | /** 2 | * We assume this file is mocked in the normal way (by a generated mock). 3 | * 4 | */ 5 | class ModuleX 6 | { 7 | public: 8 | void hello_world(); 9 | }; 10 | -------------------------------------------------------------------------------- /examples/test_ok_mock_manual/tests/mocks_man/mock_moduleT.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include "moduleT.h" 4 | 5 | 6 | static FILE* OUTPUT_FID = NULL; 7 | 8 | int _test_moduleT_init( const char* filename ) 9 | { 10 | OUTPUT_FID = fopen( filename, "wb" ); 11 | if ( OUTPUT_FID == NULL ) 12 | return -1; 13 | return 100; 14 | } 15 | 16 | void _test_moduleT_close( ) 17 | { 18 | if ( OUTPUT_FID != NULL ) fclose( OUTPUT_FID ); 19 | OUTPUT_FID = NULL; 20 | } 21 | 22 | 23 | void ModuleT::drive_step_motor( int steps ) 24 | { 25 | fprintf( OUTPUT_FID, "%d %d\n", (int)millis_fake.return_val, steps ); 26 | } -------------------------------------------------------------------------------- /examples/test_ok_mock_manual/tests/mocks_man/mock_moduleT.h: -------------------------------------------------------------------------------- 1 | #include "moduleT.h" 2 | 3 | int _test_moduleT_init( const char* outfile ); 4 | void _test_moduleT_close(); 5 | 6 | 7 | -------------------------------------------------------------------------------- /examples/test_ok_mock_manual/tests/test_drive.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | #include "catch.hpp" 5 | 6 | #include "mock_moduleT.h" 7 | 8 | extern void loop(); 9 | 10 | TEST_CASE( "Manual mock functions are usable", "[main]" ) 11 | { 12 | REQUIRE( _test_moduleT_init("./output_steps.txt") == 100 ); 13 | loop(); 14 | _test_moduleT_close(); 15 | } 16 | 17 |
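The manual mock in `test_ok_mock_manual` links a hand-written definition of `ModuleT::drive_step_motor` in place of the real one. The SConscript's `build_mock_lib` skips mock generation for any header that has a `mocks_man/mock_*.cpp` counterpart. The sketch below shows the same idea in one file. It assumes a hypothetical in-memory log, `recorded_steps`, instead of the file that `_test_moduleT_init()` opens. It illustrates the technique and is not the repository's mock.

```cpp
#include <cassert>
#include <vector>

// Interface, as in examples/test_ok_mock_manual/src/moduleT.h.
class ModuleT
{
public:
    void drive_step_motor(int steps);
};

// Hypothetical in-memory log; the repository's mock_moduleT.cpp instead writes
// "<millis> <steps>" lines to the file opened by _test_moduleT_init().
static std::vector<int> recorded_steps;

// Hand-written mock: this definition is linked instead of the real one,
// so every call made by the code under test is recorded.
void ModuleT::drive_step_motor(int steps) { recorded_steps.push_back(steps); }

// Code under test, same loop as src/drive.cpp (the ModuleX call is omitted).
void loop()
{
    ModuleT modt;
    for (int step = 0; step < 1000; step++)
    {
        modt.drive_step_motor(step);
    }
}
```

Recording to memory keeps the checks inside the test itself. The file-based mock in the repository instead leaves a trace on disk that can be inspected after the run.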
-------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | cd examples 5 | rm -rf build 6 | 7 | for fn in example_*; do 8 | if [ -d "$fn" ]; then 9 | ./build.sh "$fn" normal 10 | fi 11 | done 12 | 13 | for fn in test_fail_*; do 14 | if [ -d "$fn" ]; then 15 | ./build.sh "$fn" assume_fail 16 | fi 17 | done 18 | 19 | for fn in test_ok_*; do 20 | if [ -d "$fn" ]; then 21 | ./build.sh "$fn" normal 22 | fi 23 | done 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | **/*.pyc 3 | .scons* 4 | -------------------------------------------------------------------------------- /src/SConscript: -------------------------------------------------------------------------------- 1 | # 2 | # This is the build script for the unit tests.
3 | # 4 | # 5 | 6 | # What libraries and flags you want 7 | libraries = ['gcov'] 8 | cppFlags = ['-Wall', "-O0", "-g", "-fprofile-arcs","-ftest-coverage", "-std=c++11" ]#, '-Werror'] 9 | 10 | 11 | ENV = Environment() 12 | ENV.Append(LIBS = libraries) 13 | ENV.Append(CPPFLAGS = cppFlags) 14 | ENV.Append(CPPPATH = ["tools/", "../src/","fakes"] ) 15 | ENV.Append(LINKFLAGS="--coverage") 16 | 17 | ENV.VariantDir('build/src', '../src', duplicate=0) 18 | ENV.VariantDir('build/mocks_gen', 'mocks_gen', duplicate=0) 19 | ENV.VariantDir('build/mocks_man', 'mocks_man', duplicate=0) 20 | ENV.VariantDir('build/fakes', 'fakes', duplicate=0) 21 | ENV.VariantDir('build/root', './', duplicate=0) 22 | 23 | ENV_TESTBIN = ENV.Clone() 24 | ENV_TESTBIN.Append( CPPPATH = ("mocks_gen", "mocks_man", "fakes", ), ) 25 | 26 | 27 | # Scan for files 28 | import os 29 | import re 30 | 31 | 32 | 33 | 34 | def fn_replace( what, with_what, fns ): 35 | ret = [] 36 | for fn in sorted(fns): 37 | fn = fn.replace(what, with_what) 38 | fn = fn.split(".")[0] + ".cpp" 39 | ret.append(fn) 40 | return ret 41 | 42 | 43 | def build_mock_lib( env ): 44 | 45 | # Manually mocked modules: no mocks are generated for these 46 | mocks_manual = Glob( "./mocks_man/mock_*.cpp" ) 47 | # Common sources (e.g. global variable definitions): these are always included 48 | mocks_extra = Glob( "./common_*.cpp" ) 49 | # All headers that could get generated mocks; the manually mocked modules are subtracted from these below 50 | all_headers = Glob( "../src/*.h" ) 51 | 52 | def convert_filename( fn, old_prefix, new_prefix, old_postfix, new_postfix ): 53 | fn = os.path.basename( fn ) 54 | if not fn.startswith( old_prefix ): 55 | raise Exception("Trying to convert file '%s' that does not start with '%s' as assumed." % (fn, old_prefix ) ) 56 | 57 | if not fn.endswith( old_postfix): 58 | raise Exception("Trying to convert file '%s' that does not end with '%s' as assumed."
% (fn, old_postfix ) ) 59 | 60 | def rreplace(s, old, new, occurrence): 61 | li = s.rsplit(old, occurrence) 62 | return new.join(li) 63 | 64 | fn = fn.replace( old_prefix, new_prefix, 1) 65 | fn = rreplace( fn, old_postfix, new_postfix, 1) 66 | return fn 67 | 68 | 69 | mocks_manual_sources = { convert_filename( str(x), "mock_", "", ".cpp", "" ) for x in mocks_manual } 70 | mocks_generated_sources = { convert_filename( str(x), "", "", ".h", "" ) for x in all_headers } 71 | # And then do not generate mocks for manually mocked files 72 | mocks_generated_sources -= mocks_manual_sources 73 | 74 | from tools.fffmock import generate_mocks 75 | 76 | def generate_mock( source, target, env ): 77 | assert( len(source) == 1) 78 | assert( len(target) == 2) 79 | source_fn = source[0].get_abspath() 80 | target_c_fn = target[0].get_abspath() 81 | target_h_fn = target[1].get_abspath() 82 | generate_mocks( source_fn, target_h_fn, target_c_fn ) 83 | 84 | 85 | for source in mocks_generated_sources: 86 | mock_name_c = "mocks_gen/mock_%s.cpp" % source 87 | mock_name_h = "mocks_gen/mock_%s.h" % source 88 | source_header = "../src/%s.h" % source 89 | t = env.Command(target=(mock_name_c, mock_name_h), source=source_header, action=generate_mock) 90 | env.Depends( t , "tools/fffmock.py" ) 91 | env.Depends( "build/" + mock_name_c , t) # SCons does not track this dependency on its own (at the time of writing); without it the mock is not regenerated.
92 | env.Depends( "build/" + mock_name_h , t) 93 | 94 | 95 | mocklib_sources = [ "build/mocks_gen/mock_%s.cpp" % x for x in mocks_generated_sources ] 96 | mocklib_sources += [ "build/mocks_man/mock_%s.cpp" % x for x in mocks_manual_sources ] 97 | mocklib_sources += [ "build/root/%s" % x for x in mocks_extra ] 98 | 99 | mock_lib = env.Library( "build/build/mocks", mocklib_sources, ) 100 | return mock_lib 101 | 102 | def build_fake_lib( env ): 103 | source_files = Glob('fakes/*.c') + Glob('fakes/*.cpp') 104 | source_files = [ str(x) for x in source_files ] 105 | source_files = fn_replace("fakes/", "build/fakes/", source_files ) 106 | return env.Library( "build/build/fakes", source_files ) 107 | 108 | 109 | def build_test_lib( env ): 110 | return env.Library( "build/build/tests", "build/root/tmain.cpp" ) 111 | 112 | 113 | def build_test_binaries_single( source ): 114 | real_file = "../src/" + source.replace("test_","") 115 | sources = ["build/root/" + source, "build/src/" + real_file ] 116 | return sources 117 | 118 | def build_test_binaries_multi( source ): 119 | sources = ["build/root/" + source ] 120 | with open( source ) as fid: 121 | content = fid.read() 122 | for line in content.split( '\n' ): 123 | if "__UNITTEST__SOURCES_" in line: 124 | m = re.match( r".*__UNITTEST__SOURCES_\s*=\s*(.*)", line ) 125 | if m == None: 126 | raise Exception("File '%s' contains magic keyword __UNITTEST__SOURCES_, but the line is not matching our regexp." % source ) 127 | for item in m.group(1).split(","): 128 | item = item.strip() 129 | sources.append( "build/src/" + item ) 130 | return sources 131 | raise Exception("File '%s' is named as multi-unit test but magic keyword '__UNITTEST__SOURCES_' is missing from the file content!" 
% source ) 132 | 133 | def build_test_binaries( env, libsources, prefix, process_unit ): 134 | progs = [] 135 | for source in [str(x) for x in Glob("%s_*.cpp" % prefix ) ]: 136 | sources = process_unit(source) 137 | basename = source.split(".")[0] 138 | unitname = basename.split("_",1)[1] 139 | bin_name = "build/bin/%s_%s" % (prefix,unitname) 140 | sources.extend( libsources ) 141 | prog = env.Program( bin_name, sources ) 142 | progs.append(prog) 143 | return progs 144 | 145 | 146 | 147 | MOCK_LIB = build_mock_lib(ENV) 148 | FAKE_LIB = build_fake_lib(ENV) 149 | TEST_LIB = build_test_lib(ENV) 150 | 151 | 152 | build_test_binaries( ENV_TESTBIN, ( TEST_LIB, MOCK_LIB, FAKE_LIB, ), "test", build_test_binaries_single ) 153 | build_test_binaries( ENV_TESTBIN, ( TEST_LIB, MOCK_LIB, FAKE_LIB, ), "mutest", build_test_binaries_multi ) 154 | 155 | 156 | -------------------------------------------------------------------------------- /src/build/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/susundberg/arduino-simple-unittest/ed72c17cd8eea76fb51ed22ceb8a5b7a29daa4a1/src/build/.gitkeep -------------------------------------------------------------------------------- /src/fakes/Arduino.cpp: -------------------------------------------------------------------------------- 1 | #include "Arduino.h" 2 | 3 | #include "catch.hpp" 4 | 5 | 6 | DEFINE_FAKE_VOID_FUNC( pinMode, uint8_t, uint8_t ); 7 | DEFINE_FAKE_VOID_FUNC( digitalWrite, uint8_t, uint8_t ); 8 | DEFINE_FAKE_VALUE_FUNC( int, digitalRead, uint8_t ); 9 | DEFINE_FAKE_VALUE_FUNC( int, analogRead, uint8_t ); 10 | DEFINE_FAKE_VOID_FUNC( analogWrite, uint8_t, int ); 11 | DEFINE_FAKE_VOID_FUNC( delay, unsigned long ); 12 | DEFINE_FAKE_VALUE_FUNC( unsigned long, millis ); 13 | 14 | 15 | 16 | Arduino_TEST ARDUINO_TEST; 17 | 18 | 19 | void _arduino_test_pinMode( uint8_t pin, uint8_t mode ) 20 | { 21 | ARDUINO_TEST.check_pin( pin ); 22 | ARDUINO_TEST.pin_mode[ pin 
] = mode; 23 | } 24 | 25 | void _arduino_test_digitalWrite( uint8_t pin, uint8_t value) 26 | { 27 | ARDUINO_TEST.check_write( pin ); 28 | ARDUINO_TEST.pin_value[ pin ] = (value != 0); 29 | } 30 | 31 | int _arduino_test_digitalRead( uint8_t pin ) 32 | { 33 | ARDUINO_TEST.check_read( pin ); 34 | return (ARDUINO_TEST.pin_value[ pin ] != 0); 35 | } 36 | 37 | int _arduino_test_analogRead( uint8_t pin ) 38 | { 39 | ARDUINO_TEST.check_read( pin ); 40 | return (ARDUINO_TEST.pin_value[ pin ]); 41 | } 42 | 43 | void _arduino_test_analogWrite( uint8_t pin, int value) 44 | { 45 | ARDUINO_TEST.check_write( pin ); 46 | ARDUINO_TEST.pin_value[ pin ] = value; 47 | } 48 | 49 | void Arduino_TEST::set_mode( Arduino_TEST::Check_mode mode ) 50 | { 51 | this->check_mode = mode; 52 | } 53 | 54 | void Arduino_TEST::check_pin( uint8_t pin ) 55 | { 56 | int max_pins = Arduino_TEST::MAX_PINS; 57 | REQUIRE( pin < max_pins ); 58 | } 59 | 60 | void Arduino_TEST::check_read( uint8_t pin ) 61 | { 62 | check_pin( pin ); 63 | 64 | if ( this->check_mode != Arduino_TEST::Check_mode::Full ) 65 | return; 66 | 67 | bool valid_input = (this->pin_mode[ pin ] == INPUT) || (this->pin_mode[ pin ] == INPUT_PULLUP ); 68 | REQUIRE( valid_input == true ); 69 | 70 | } 71 | 72 | void Arduino_TEST::check_write( uint8_t pin ) 73 | { 74 | check_pin( pin ); 75 | 76 | if ( this->check_mode != Arduino_TEST::Check_mode::Full ) 77 | return; 78 | 79 | REQUIRE( this->pin_mode[ pin ] == OUTPUT ); 80 | 81 | } 82 | 83 | 84 | 85 | void Arduino_TEST::hookup() 86 | { 87 | memset( ARDUINO_TEST.pin_value, 0x00, sizeof(ARDUINO_TEST.pin_value)); 88 | memset( ARDUINO_TEST.pin_mode, 0xFF, sizeof(ARDUINO_TEST.pin_mode)); 89 | this->check_mode = Arduino_TEST::Check_mode::Full; 90 | 91 | RESET_FAKE( pinMode ); 92 | RESET_FAKE( digitalWrite ); 93 | RESET_FAKE( digitalRead ); 94 | RESET_FAKE( analogRead ); 95 | RESET_FAKE( analogWrite ); 96 | 97 | pinMode_fake.custom_fake = _arduino_test_pinMode; 98 | digitalWrite_fake.custom_fake = 
_arduino_test_digitalWrite; 99 | digitalRead_fake.custom_fake = _arduino_test_digitalRead; 100 | analogRead_fake.custom_fake = _arduino_test_analogRead; 101 | analogWrite_fake.custom_fake = _arduino_test_analogWrite; 102 | 103 | } 104 | void Arduino_TEST::hookdown() 105 | { 106 | pinMode_fake.custom_fake = NULL; 107 | digitalWrite_fake.custom_fake = NULL; 108 | digitalRead_fake.custom_fake = NULL; 109 | analogRead_fake.custom_fake = NULL; 110 | analogWrite_fake.custom_fake = NULL; 111 | } 112 | -------------------------------------------------------------------------------- /src/fakes/Arduino.h: -------------------------------------------------------------------------------- 1 | #ifndef FAKE_Arduino_h 2 | #define FAKE_Arduino_h 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "fff.h" 11 | 12 | #define HIGH 0x1 13 | #define LOW 0x0 14 | 15 | #define INPUT 0x0 16 | #define OUTPUT 0x1 17 | #define INPUT_PULLUP 0x2 18 | 19 | #define PI 3.1415926535897932384626433832795 20 | #define HALF_PI 1.5707963267948966192313216916398 21 | #define TWO_PI 6.283185307179586476925286766559 22 | #define DEG_TO_RAD 0.017453292519943295769236907684886 23 | #define RAD_TO_DEG 57.295779513082320876798154814105 24 | #define EULER 2.718281828459045235360287471352 25 | 26 | #define SERIAL 0x0 27 | #define DISPLAY 0x1 28 | 29 | #define LSBFIRST 0 30 | #define MSBFIRST 1 31 | 32 | #define CHANGE 1 33 | #define FALLING 2 34 | #define RISING 3 35 | 36 | #define INTERNAL 3 37 | #define DEFAULT 1 38 | #define EXTERNAL 0 39 | 40 | 41 | #define constrain(amt,low,high) ((amt)<(low)?(low):((amt)>(high)?(high):(amt))) 42 | #define round(x) ((x)>=0?(long)((x)+0.5):(long)((x)-0.5)) 43 | #define radians(deg) ((deg)*DEG_TO_RAD) 44 | #define degrees(rad) ((rad)*RAD_TO_DEG) 45 | #define sq(x) ((x)*(x)) 46 | 47 | #define lowByte(w) ((uint8_t) ((w) & 0xff)) 48 | #define highByte(w) ((uint8_t) ((w) >> 8)) 49 | 50 | #define bitRead(value, bit) (((value) >> (bit)) & 0x01) 51 
| #define bitSet(value, bit) ((value) |= (1UL << (bit))) 52 | #define bitClear(value, bit) ((value) &= ~(1UL << (bit))) 53 | #define bitWrite(value, bit, bitvalue) (bitvalue ? bitSet(value, bit) : bitClear(value, bit)) 54 | #define bit(b) (1UL << (b)) 55 | 56 | typedef unsigned int word; 57 | typedef bool boolean; 58 | typedef uint8_t byte; 59 | 60 | 61 | void pinMode(uint8_t, uint8_t); 62 | void digitalWrite(uint8_t, uint8_t); 63 | int digitalRead(uint8_t); 64 | int analogRead(uint8_t); 65 | unsigned long millis(); 66 | void analogWrite(uint8_t, int); 67 | void delay(unsigned long); 68 | 69 | DECLARE_FAKE_VOID_FUNC( pinMode, uint8_t, uint8_t ); 70 | DECLARE_FAKE_VOID_FUNC( digitalWrite, uint8_t, uint8_t ); 71 | DECLARE_FAKE_VALUE_FUNC( int, digitalRead, uint8_t ); 72 | DECLARE_FAKE_VALUE_FUNC( unsigned long, millis ); 73 | DECLARE_FAKE_VALUE_FUNC( int, analogRead, uint8_t ); 74 | DECLARE_FAKE_VOID_FUNC( analogWrite, uint8_t, int ); 75 | DECLARE_FAKE_VOID_FUNC( delay, unsigned long ); 76 | 77 | 78 | 79 | #include 80 | #include 81 | #include 82 | 83 | 84 | class Serial_CLS 85 | { 86 | 87 | typedef std::queue Buffer; 88 | 89 | public: 90 | void write( const char* buffer, int buffer_n ); 91 | void write( const char* buffer ); 92 | void print( int value ); 93 | void print( double value ); 94 | void begin( int baudrate ); 95 | int available(); 96 | char read(); 97 | 98 | // each completed output line (the text before a '\n') is pushed to this queue 99 | Buffer _test_output_buffer; 100 | std::string _test_output_current; // current output line 101 | std::queue _test_input_buffer; 102 | 103 | void _test_clear(); // remove and reset everything from the buffers 104 | void _test_set_input( const char* what ); 105 | protected: 106 | void _test_output_string( std::string what ); 107 | }; 108 | 109 | 110 | class Arduino_TEST 111 | { 112 | public: 113 | enum class Check_mode { Full, None }; // Full: reads (digitalRead, analogRead) require an INPUT or INPUT_PULLUP pin and writes (digitalWrite, analogWrite) require an OUTPUT pin; None: only the pin number is range-checked. Defaults to Full. 114 |
constexpr static const int MAX_PINS = 128; 115 | 116 | void hookup(); // Reset the pin state, check mode and fakes, then hook the Arduino functions (digitalRead, digitalWrite, pinMode, analogRead, analogWrite) 117 | void hookdown(); // Remove the custom hooks installed by hookup(); pin values and modes are left untouched. 118 | void set_mode( Check_mode target ); 119 | void check_write(uint8_t pin); 120 | void check_read(uint8_t pin); 121 | void check_pin(uint8_t pin); 122 | int pin_value[ MAX_PINS ]; 123 | uint8_t pin_mode [ MAX_PINS ]; 124 | 125 | Check_mode check_mode; 126 | }; 127 | 128 | extern Arduino_TEST ARDUINO_TEST; 129 | 130 | extern Serial_CLS Serial; 131 | 132 | typedef std::string String ; 133 | 134 | static const int A0 = 100; 135 | 136 | #define LED_BUILTIN 13 137 | 138 | /** TODO 139 | // undefine stdlib's abs if encountered 140 | // #ifdef abs 141 | // #undef abs 142 | // #endif 143 | // 144 | // #ifdef max 145 | // #undef max 146 | // #endif 147 | // 148 | // #ifdef min 149 | // #undef min 150 | // #endif 151 | // 152 | // #define min(a,b) ((a)<(b)?(a):(b)) 153 | // #define max(a,b) ((a)>(b)?(a):(b)) 154 | // #define abs(x) ((x)>0?(x):-(x)) 155 | // These seem to cause trouble with the queue or string includes; one should figure out why, and whether we need these at all.
156 | 157 | void analogReference(uint8_t mode); 158 | 159 | unsigned long millis(void); 160 | unsigned long micros(void); 161 | void delayMicroseconds(unsigned int us); 162 | unsigned long pulseIn(uint8_t pin, uint8_t state, unsigned long timeout); 163 | unsigned long pulseInLong(uint8_t pin, uint8_t state, unsigned long timeout); 164 | 165 | void shiftOut(uint8_t dataPin, uint8_t clockPin, uint8_t bitOrder, uint8_t val); 166 | uint8_t shiftIn(uint8_t dataPin, uint8_t clockPin, uint8_t bitOrder); 167 | 168 | void attachInterrupt(uint8_t, void (*)(void), int mode); 169 | void detachInterrupt(uint8_t); 170 | 171 | 172 | // WMath prototypes 173 | long random(long); 174 | long random(long, long); 175 | void randomSeed(unsigned long); 176 | long map(long, long, long, long, long); 177 | 178 | */ 179 | 180 | 181 | #endif 182 | -------------------------------------------------------------------------------- /src/fakes/Serial.cpp: -------------------------------------------------------------------------------- 1 | #include "Arduino.h" 2 | #include 3 | #include 4 | 5 | Serial_CLS Serial; 6 | 7 | void Serial_CLS::write( const char* buffer, int buffer_n ) 8 | { 9 | 10 | this->_test_output_string( std::string( buffer, buffer_n ) ); 11 | } 12 | 13 | void Serial_CLS::_test_set_input( const char* what ) 14 | { 15 | this->_test_input_buffer = std::queue(); 16 | for ( int loop = 0; loop < (int)strlen(what); loop ++ ) 17 | { 18 | this->_test_input_buffer.push( what[loop] ); 19 | } 20 | } 21 | 22 | void Serial_CLS::_test_clear() 23 | { 24 | this->_test_output_current = std::string(); 25 | this->_test_output_buffer = std::queue(); 26 | this->_test_input_buffer = std::queue(); 27 | } 28 | 29 | void Serial_CLS::_test_output_string( std::string what ) 30 | { 31 | 32 | this->_test_output_current += what; 33 | 34 | std::size_t index = this->_test_output_current.find( '\n' ); 35 | if ( index == std::string::npos ) 36 | return; 37 | 38 | std::string tobuf = this->_test_output_current.substr( 
0, index ); 39 | if ( index + 1 >= this->_test_output_current.length() ) 40 | { 41 | this->_test_output_current = std::string(); 42 | } 43 | else 44 | { 45 | this->_test_output_current = this->_test_output_current.substr( index + 1, std::string::npos ); 46 | } 47 | this->_test_output_buffer.push( tobuf ); 48 | } 49 | 50 | 51 | void Serial_CLS::write( const char* buffer ) 52 | { 53 | this->write( buffer, strlen( buffer ) ); 54 | } 55 | 56 | template std::string tostr(const T& t) { 57 | std::ostringstream os; 58 | os<_test_output_string( tostr( value ) ); 65 | } 66 | 67 | void Serial_CLS::print( double value ) 68 | { 69 | this->_test_output_string( tostr( value ) ); 70 | } 71 | 72 | void Serial_CLS::begin( int baudrate ) 73 | { 74 | 75 | } 76 | 77 | int Serial_CLS::available() 78 | { 79 | return this->_test_input_buffer.empty() == false; 80 | } 81 | 82 | char Serial_CLS::read() 83 | { 84 | char ret = this->_test_input_buffer.front(); 85 | this->_test_input_buffer.pop(); 86 | return ret; 87 | } 88 | -------------------------------------------------------------------------------- /src/fakes/Servo.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "Servo.h" 4 | 5 | DEFINE_FAKE_VOID_FUNC( Servo__write, Servo*, int ); 6 | DEFINE_FAKE_VOID_FUNC( Servo__attach, Servo*, int ); 7 | 8 | -------------------------------------------------------------------------------- /src/fakes/Servo.h: -------------------------------------------------------------------------------- 1 | #ifndef FAKE_SERVO_H 2 | #define FAKE_SERVO_H 3 | 4 | #include "fff.h" 5 | 6 | 7 | class Servo; 8 | 9 | void Servo__write( Servo*, int ); 10 | DECLARE_FAKE_VOID_FUNC( Servo__write, Servo*, int ); 11 | 12 | void Servo__attach( Servo*, int ); 13 | DECLARE_FAKE_VOID_FUNC( Servo__attach, Servo*, int ); 14 | 15 | 16 | class Servo 17 | { 18 | public: 19 | void write( int angle ) { Servo__write(this, angle); }; 20 | void attach( int pin ) { Servo__attach(this, pin 
); } ; 21 | }; 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /src/fakes/avr/eeprom.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef FAKE_AVR_EEPROM_H 3 | #define FAKE_AVR_EEPROM_H 4 | 5 | #include 6 | #include 7 | 8 | void eeprom_update_block( const void * src, void * dst, size_t n ); 9 | void eeprom_read_block( void * dest, const void * source, size_t n ); 10 | 11 | void _test_eeprom_reset(); 12 | 13 | #endif -------------------------------------------------------------------------------- /src/fakes/crc16.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "util/crc16.h" 3 | 4 | // Function taken directly from: http://www.nongnu.org/avr-libc/user-manual/group__util__crc.html 5 | uint16_t _crc16_update(uint16_t crc, uint8_t a) 6 | { 7 | int i; 8 | crc ^= a; 9 | for (i = 0; i < 8; ++i) 10 | { 11 | if (crc & 1) 12 | crc = (crc >> 1) ^ 0xA001; 13 | else 14 | crc = (crc >> 1); 15 | } 16 | return crc; 17 | } -------------------------------------------------------------------------------- /src/fakes/eeprom.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "avr/eeprom.h" 5 | #define EEPROM_SIZE 1024 6 | static char EEPROM[EEPROM_SIZE]; 7 | 8 | 9 | void _test_eeprom_reset() 10 | { 11 | memset( EEPROM, 0xFE, sizeof(EEPROM )); 12 | } 13 | 14 | void eeprom_update_block( const void * src, void* dst, size_t n ) 15 | { 16 | char* target = (char*)EEPROM + (uintptr_t)(dst); 17 | assert( (uintptr_t)(dst) + n <= EEPROM_SIZE ); // a block may end exactly at the last EEPROM byte 18 | 19 | memcpy( target, src, n ); 20 | 21 | } 22 | 23 | void eeprom_read_block( void * dest, const void * source, size_t n ) 24 | { 25 | char* target = EEPROM + (uintptr_t)source; 26 | 27 | assert( (uintptr_t)source + n <= EEPROM_SIZE ); // a block may end exactly at the last EEPROM byte 28 | memcpy( dest, target, n ); 29 | }
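`_crc16_update` in `src/fakes/crc16.cpp` is the bitwise reflected CRC-16 (polynomial 0xA001) copied from the avr-libc page it links. One way to confirm that a host-side copy behaves as expected is to check it against the standard CRC-16 check values for the ASCII string "123456789". With initial value 0xFFFF (the CRC-16/MODBUS parameters) the result is 0x4B37. With initial value 0x0000 (CRC-16/ARC) it is 0xBB3D. The sketch below copies the algorithm under a hypothetical name, `crc16_update_sketch`, and adds a hypothetical buffer helper and self-check. None of these are part of the fake library.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Same bitwise update as src/fakes/crc16.cpp: reflected polynomial 0xA001
// (x^16 + x^15 + x^2 + 1), one input byte per call, no final XOR.
std::uint16_t crc16_update_sketch(std::uint16_t crc, std::uint8_t a)
{
    crc ^= a;
    for (int i = 0; i < 8; ++i)
    {
        if (crc & 1)
            crc = (crc >> 1) ^ 0xA001;
        else
            crc = (crc >> 1);
    }
    return crc;
}

// Hypothetical helper: fold a whole buffer through the update function,
// starting from the caller-chosen initial value.
std::uint16_t crc16_buffer(const std::uint8_t* data, std::size_t len, std::uint16_t init)
{
    std::uint16_t crc = init;
    for (std::size_t i = 0; i < len; ++i)
        crc = crc16_update_sketch(crc, data[i]);
    return crc;
}

// Hypothetical self-check against the standard check values for "123456789".
void crc16_self_check()
{
    const std::uint8_t msg[] = {'1', '2', '3', '4', '5', '6', '7', '8', '9'};
    assert(crc16_buffer(msg, sizeof(msg), 0xFFFF) == 0x4B37); // CRC-16/MODBUS parameters
    assert(crc16_buffer(msg, sizeof(msg), 0x0000) == 0xBB3D); // CRC-16/ARC parameters
}
```

The initial value is the caller's choice, so the firmware and its tests need to agree on it when stored CRCs (for example, over an EEPROM block) are compared.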
-------------------------------------------------------------------------------- /src/fakes/util/crc16.h: -------------------------------------------------------------------------------- 1 | #ifndef _FAKE_CRC16_H 2 | #define _FAKE_CRC16_H 3 | 4 | #include 5 | 6 | // Function taken directly from: http://www.nongnu.org/avr-libc/user-manual/group__util__crc.html 7 | uint16_t _crc16_update(uint16_t crc, uint8_t a); 8 | 9 | 10 | #endif -------------------------------------------------------------------------------- /src/mocks_gen/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/susundberg/arduino-simple-unittest/ed72c17cd8eea76fb51ed22ceb8a5b7a29daa4a1/src/mocks_gen/.gitkeep -------------------------------------------------------------------------------- /src/mocks_man/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/susundberg/arduino-simple-unittest/ed72c17cd8eea76fb51ed22ceb8a5b7a29daa4a1/src/mocks_man/.gitkeep -------------------------------------------------------------------------------- /src/tmain.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * 3 | * This is the test main program that is linked with all the tests. Do not add tests to this file, unless 4 | * you want every binary to run those.
5 | */ 6 | 7 | #define CATCH_CONFIG_MAIN 8 | #include "fff.h" 9 | 10 | DEFINE_FFF_GLOBALS; 11 | 12 | #include "catch.hpp" 13 | 14 | 15 | -------------------------------------------------------------------------------- /src/tools/Makefile: -------------------------------------------------------------------------------- 1 | 2 | test_build: 3 | scons -Y tests/tools/ 4 | 5 | test_run: 6 | ./tests/tools/run_all.sh 7 | 8 | test_coverage: 9 | ./tests/tools/run_coverage.sh 10 | 11 | 12 | -------------------------------------------------------------------------------- /src/tools/SConstruct: -------------------------------------------------------------------------------- 1 | SConscript('tests/SConscript') 2 | 3 | -------------------------------------------------------------------------------- /src/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/susundberg/arduino-simple-unittest/ed72c17cd8eea76fb51ed22ceb8a5b7a29daa4a1/src/tools/__init__.py -------------------------------------------------------------------------------- /src/tools/cpp/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.12' 2 | -------------------------------------------------------------------------------- /src/tools/cpp/ast.py: -------------------------------------------------------------------------------- 1 | # Copyright 2007 Neal Norwitz 2 | # Portions Copyright 2007 Google Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Generate an Abstract Syntax Tree (AST) for C++.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import print_function 20 | from __future__ import unicode_literals 21 | 22 | from . import keywords 23 | from . import tokenize 24 | 25 | 26 | try: 27 | unicode 28 | except NameError: 29 | unicode = str 30 | 31 | 32 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 33 | 34 | 35 | # TODO: 36 | # * Tokens should never be exported, need to convert to Nodes 37 | # (return types, parameters, etc.) 38 | # * Handle static class data for templatized classes 39 | # * Handle casts (both C++ and C-style) 40 | # * Handle conditions and loops (if/else, switch, for, while/do) 41 | # 42 | # TODO much, much later: 43 | # * Handle #define 44 | # * exceptions 45 | 46 | 47 | FUNCTION_NONE = 0x00 48 | FUNCTION_SPECIFIER = 0x01 49 | FUNCTION_VIRTUAL = 0x02 50 | FUNCTION_PURE_VIRTUAL = 0x04 51 | FUNCTION_CTOR = 0x08 52 | FUNCTION_DTOR = 0x10 53 | FUNCTION_ATTRIBUTE = 0x20 54 | FUNCTION_UNKNOWN_ANNOTATION = 0x40 55 | FUNCTION_THROW = 0x80 56 | 57 | 58 | class ParseError(Exception): 59 | 60 | """Raise exception on parsing problems.""" 61 | 62 | 63 | # TODO(nnorwitz): move AST nodes into a separate module. 
64 | class Node(object): 65 | 66 | """Base AST node.""" 67 | 68 | def __init__(self, start, end): 69 | self.start = start 70 | self.end = end 71 | 72 | def is_declaration(self): 73 | """Returns bool if this node is a declaration.""" 74 | return False 75 | 76 | def is_definition(self): 77 | """Returns bool if this node is a definition.""" 78 | return False 79 | 80 | def is_exportable(self): 81 | """Returns bool if this node exportable from a header file.""" 82 | return False 83 | 84 | def _string_helper(self, name, suffix): 85 | return '%s(%d, %d, %s)' % (name, self.start, self.end, suffix) 86 | 87 | def __repr__(self): 88 | return unicode(self) 89 | 90 | 91 | class Define(Node): 92 | 93 | def __init__(self, start, end, name, definition): 94 | Node.__init__(self, start, end) 95 | self.name = name 96 | self.definition = definition 97 | 98 | def __str__(self): 99 | value = '%s %s' % (self.name, self.definition) 100 | return self._string_helper(self.__class__.__name__, value) 101 | 102 | 103 | class Include(Node): 104 | 105 | def __init__(self, start, end, filename, system): 106 | Node.__init__(self, start, end) 107 | self.filename = filename 108 | self.system = system 109 | 110 | def __str__(self): 111 | fmt = '"%s"' 112 | if self.system: 113 | fmt = '<%s>' 114 | return self._string_helper(self.__class__.__name__, 115 | fmt % self.filename) 116 | 117 | 118 | class Expr(Node): 119 | 120 | def __init__(self, start, end, expr): 121 | Node.__init__(self, start, end) 122 | self.expr = expr 123 | 124 | def __str__(self): 125 | return self._string_helper(self.__class__.__name__, unicode(self.expr)) 126 | 127 | 128 | class Friend(Expr): 129 | 130 | def __init__(self, start, end, expr, namespace): 131 | Expr.__init__(self, start, end, expr) 132 | self.namespace = namespace[:] 133 | 134 | 135 | class Using(Node): 136 | 137 | def __init__(self, start, end, names): 138 | Node.__init__(self, start, end) 139 | self.names = names 140 | 141 | def __str__(self): 142 | return 
self._string_helper(self.__class__.__name__, 143 | unicode(self.names)) 144 | 145 | 146 | class Parameter(Node): 147 | 148 | def __init__(self, start, end, name, parameter_type, default): 149 | Node.__init__(self, start, end) 150 | self.name = name 151 | self.type = parameter_type 152 | self.default = default 153 | 154 | def __str__(self): 155 | name = unicode(self.type) 156 | suffix = '%s %s' % (name, self.name) 157 | if self.default: 158 | suffix += ' = ' + ''.join([d.name for d in self.default]) 159 | return self._string_helper(self.__class__.__name__, suffix) 160 | 161 | 162 | class _GenericDeclaration(Node): 163 | 164 | def __init__(self, start, end, name, namespace): 165 | Node.__init__(self, start, end) 166 | self.name = name 167 | self.namespace = namespace[:] 168 | 169 | def full_name(self): 170 | prefix = '' 171 | names = [n for n in self.namespace if n is not None] 172 | if names: 173 | prefix = '::'.join(names) + '::' 174 | return prefix + self.name 175 | 176 | def _type_string_helper(self, suffix): 177 | if self.namespace: 178 | names = [n or '' for n in self.namespace] 179 | suffix += ' in ' + '::'.join(names) 180 | return self._string_helper(self.__class__.__name__, suffix) 181 | 182 | 183 | # TODO(nnorwitz): merge with Parameter in some way? 
184 | class VariableDeclaration(_GenericDeclaration): 185 | 186 | def __init__(self, start, end, name, var_type, initial_value, namespace): 187 | _GenericDeclaration.__init__(self, start, end, name, namespace) 188 | self.type = var_type 189 | self.initial_value = initial_value 190 | 191 | def to_string(self): 192 | """Return a string that tries to reconstitute the variable decl.""" 193 | suffix = '%s %s' % (self.type, self.name) 194 | if self.initial_value: 195 | suffix += ' = ' + self.initial_value 196 | return suffix 197 | 198 | def __str__(self): 199 | return self._string_helper(self.__class__.__name__, self.to_string()) 200 | 201 | 202 | class Typedef(_GenericDeclaration): 203 | 204 | def __init__(self, start, end, name, alias, namespace): 205 | _GenericDeclaration.__init__(self, start, end, name, namespace) 206 | self.alias = alias 207 | 208 | def is_definition(self): 209 | return True 210 | 211 | def is_exportable(self): 212 | return True 213 | 214 | def __str__(self): 215 | suffix = '%s, %s' % (self.name, self.alias) 216 | return self._type_string_helper(suffix) 217 | 218 | 219 | class Enum(_GenericDeclaration): 220 | 221 | def __init__(self, start, end, name, fields, namespace): 222 | _GenericDeclaration.__init__(self, start, end, name, namespace) 223 | self.fields = fields 224 | 225 | def is_definition(self): 226 | return True 227 | 228 | def is_exportable(self): 229 | return True 230 | 231 | def __str__(self): 232 | suffix = '%s, {%s}' % (self.name, self.fields) 233 | return self._type_string_helper(suffix) 234 | 235 | 236 | class Class(_GenericDeclaration): 237 | 238 | def __init__(self, start, end, name, 239 | bases, templated_types, body, namespace): 240 | _GenericDeclaration.__init__(self, start, end, name, namespace) 241 | self.bases = bases 242 | self.body = body 243 | self.templated_types = templated_types 244 | 245 | def is_declaration(self): 246 | return self.bases is None and self.body is None 247 | 248 | def is_definition(self): 249 | return 
not self.is_declaration() 250 | 251 | def is_exportable(self): 252 | return not self.is_declaration() 253 | 254 | def __str__(self): 255 | name = self.name 256 | if self.templated_types: 257 | types = ','.join([t for t in self.templated_types]) 258 | name += '<%s>' % types 259 | suffix = '%s, %s, %s' % (name, self.bases, self.body) 260 | return self._type_string_helper(suffix) 261 | 262 | 263 | class Struct(Class): 264 | pass 265 | 266 | 267 | class Union(Class): 268 | pass 269 | 270 | 271 | class Function(_GenericDeclaration): 272 | 273 | def __init__(self, start, end, name, return_type, parameters, 274 | specializations, modifiers, templated_types, body, namespace): 275 | _GenericDeclaration.__init__(self, start, end, name, namespace) 276 | converter = TypeConverter(namespace) 277 | self.return_type = converter.create_return_type(return_type) 278 | self.parameters = converter.to_parameters(parameters) 279 | self.specializations = converter.to_type(specializations) 280 | self.modifiers = modifiers 281 | self.body = body 282 | self.templated_types = templated_types 283 | 284 | def is_declaration(self): 285 | return self.body is None 286 | 287 | def is_definition(self): 288 | return self.body is not None 289 | 290 | def is_exportable(self): 291 | if self.return_type: 292 | if ( 293 | 'static' in self.return_type.modifiers or 294 | 'constexpr' in self.return_type.modifiers 295 | ): 296 | return False 297 | return None not in self.namespace 298 | 299 | def __str__(self): 300 | # TODO(nnorwitz): add templated_types. 
301 | suffix = ('%s %s(%s), 0x%02x, %s' % 302 | (self.return_type, self.name, self.parameters, 303 | self.modifiers, self.body)) 304 | return self._type_string_helper(suffix) 305 | 306 | 307 | class Method(Function): 308 | 309 | def __init__(self, start, end, name, in_class, return_type, parameters, 310 | specializations, modifiers, templated_types, body, namespace): 311 | Function.__init__(self, start, end, name, return_type, parameters, 312 | specializations, modifiers, templated_types, 313 | body, namespace) 314 | # TODO(nnorwitz): in_class could also be a namespace which can 315 | # mess up finding functions properly. 316 | self.in_class = in_class 317 | 318 | 319 | class Type(_GenericDeclaration): 320 | 321 | """Type used for any variable (eg class, primitive, struct, etc).""" 322 | 323 | def __init__(self, start, end, name, templated_types, modifiers, 324 | reference, pointer, array): 325 | """Args: 326 | 327 | name: str name of main type 328 | templated_types: [Class (Type?)] template type info between <> 329 | modifiers: [str] type modifiers (keywords) eg, const, mutable, etc. 330 | reference, pointer, array: bools 331 | 332 | """ 333 | _GenericDeclaration.__init__(self, start, end, name, []) 334 | self.templated_types = templated_types 335 | if not name and modifiers: 336 | self.name = modifiers.pop() 337 | self.modifiers = modifiers 338 | self.reference = reference 339 | self.pointer = pointer 340 | self.array = array 341 | 342 | def __str__(self): 343 | prefix = '' 344 | if self.modifiers: 345 | prefix = ' '.join(self.modifiers) + ' ' 346 | name = unicode(self.name) 347 | if self.templated_types: 348 | name += '<%s>' % self.templated_types 349 | suffix = prefix + name 350 | if self.reference: 351 | suffix += '&' 352 | if self.pointer: 353 | suffix += '*' 354 | if self.array: 355 | suffix += '[]' 356 | return self._type_string_helper(suffix) 357 | 358 | # By definition, Is* are always False. 
A Type can only exist in 359 | # some sort of variable declaration, parameter, or return value. 360 | def is_declaration(self): 361 | return False 362 | 363 | def is_definition(self): 364 | return False 365 | 366 | def is_exportable(self): 367 | return False 368 | 369 | 370 | class TypeConverter(object): 371 | 372 | def __init__(self, namespace_stack): 373 | self.namespace_stack = namespace_stack 374 | 375 | def _get_template_end(self, tokens, start): 376 | count = 1 377 | end = start 378 | while count and end < len(tokens): 379 | token = tokens[end] 380 | end += 1 381 | if token.name == '<': 382 | count += 1 383 | elif token.name == '>': 384 | count -= 1 385 | return tokens[start:end - 1], end 386 | 387 | def to_type(self, tokens): 388 | """Convert [Token,...] to [Type(...), ...]; useful for base classes. 389 | 390 | For example, code like class Foo : public Bar { ... }; 391 | the "Bar" portion gets converted to an AST. 392 | 393 | Returns: 394 | [Type(...), ...] 395 | 396 | """ 397 | result = [] 398 | name_tokens = [] 399 | reference = pointer = array = False 400 | inside_array = False 401 | empty_array = True 402 | templated_tokens = [] 403 | 404 | def add_type(): 405 | if not name_tokens: 406 | return 407 | 408 | # Partition tokens into name and modifier tokens.
409 | names = [] 410 | modifiers = [] 411 | for t in name_tokens: 412 | if keywords.is_keyword(t.name): 413 | modifiers.append(t.name) 414 | else: 415 | names.append(t.name) 416 | name = ''.join(names) 417 | 418 | templated_types = self.to_type(templated_tokens) 419 | result.append(Type(name_tokens[0].start, name_tokens[-1].end, 420 | name, templated_types, modifiers, 421 | reference, pointer, array)) 422 | del name_tokens[:] 423 | del templated_tokens[:] 424 | 425 | i = 0 426 | end = len(tokens) 427 | while i < end: 428 | token = tokens[i] 429 | if token.name == ']': 430 | inside_array = False 431 | if empty_array: 432 | pointer = True 433 | else: 434 | array = True 435 | elif inside_array: 436 | empty_array = False 437 | elif token.name == '<': 438 | templated_tokens, i = self._get_template_end(tokens, i + 1) 439 | continue 440 | elif token.name == ',' or token.name == '(': 441 | add_type() 442 | reference = pointer = array = False 443 | empty_array = True 444 | elif token.name == '*': 445 | pointer = True 446 | elif token.name == '&': 447 | reference = True 448 | elif token.name == '[': 449 | inside_array = True 450 | elif token.name != ')': 451 | name_tokens.append(token) 452 | i += 1 453 | 454 | add_type() 455 | return result 456 | 457 | def declaration_to_parts(self, parts, needs_name_removed): 458 | arrayBegin = 0 459 | arrayEnd = 0 460 | default = [] 461 | other_tokens = [] 462 | 463 | # Handle default (initial) values properly. 
464 | for i, t in enumerate(parts): 465 | if t.name == '[' and arrayBegin == 0: 466 | arrayBegin = i 467 | other_tokens.append(t) 468 | elif t.name == ']': 469 | arrayEnd = i 470 | other_tokens.append(t) 471 | elif t.name == '=': 472 | default = parts[i + 1:] 473 | parts = parts[:i] 474 | break 475 | 476 | if arrayBegin < arrayEnd: 477 | parts = parts[:arrayBegin] + parts[arrayEnd + 1:] 478 | 479 | modifiers = [] 480 | type_name = [''] 481 | last_type = tokenize.SYNTAX 482 | templated_types = [] 483 | i = 0 484 | end = len(parts) 485 | while i < end: 486 | p = parts[i] 487 | if keywords.is_builtin_modifiers(p.name): 488 | modifiers.append(p.name) 489 | elif p.name == '<': 490 | templated_tokens, new_end = self._get_template_end( 491 | parts, i + 1) 492 | templated_types = self.to_type(templated_tokens) 493 | i = new_end - 1 494 | elif p.name not in ('*', '&'): 495 | if ( 496 | last_type == tokenize.NAME and 497 | p.token_type == tokenize.NAME 498 | ): 499 | type_name.append('') 500 | type_name[-1] += p.name 501 | last_type = p.token_type 502 | else: 503 | other_tokens.append(p) 504 | i += 1 505 | 506 | name = None 507 | if len(type_name) == 1 or keywords.is_builtin_type(type_name[-1]): 508 | needs_name_removed = False 509 | if needs_name_removed: 510 | name = type_name.pop() 511 | 512 | return (name, 513 | ' '.join([t for t in type_name]), 514 | templated_types, 515 | modifiers, 516 | default, 517 | other_tokens) 518 | 519 | def to_parameters(self, tokens): 520 | if not tokens: 521 | return [] 522 | 523 | result = [] 524 | type_modifiers = [] 525 | pointer = reference = False 526 | first_token = None 527 | default = [] 528 | 529 | def add_parameter(): 530 | if not type_modifiers: 531 | return 532 | if default: 533 | del default[0] # Remove flag. 
534 | end = type_modifiers[-1].end 535 | 536 | (name, type_name, templated_types, modifiers, 537 | _, __) = self.declaration_to_parts(type_modifiers, 538 | True) 539 | 540 | if type_name: 541 | parameter_type = Type(first_token.start, first_token.end, 542 | type_name, templated_types, modifiers, 543 | reference, pointer, False) 544 | p = Parameter(first_token.start, end, name, 545 | parameter_type, default) 546 | result.append(p) 547 | 548 | template_count = 0 549 | for s in tokens: 550 | if not first_token: 551 | first_token = s 552 | if s.name == '<': 553 | template_count += 1 554 | elif s.name == '>': 555 | template_count -= 1 556 | if template_count > 0: 557 | if default: 558 | default.append(s) 559 | else: 560 | type_modifiers.append(s) 561 | continue 562 | 563 | if s.name == ',': 564 | add_parameter() 565 | type_modifiers = [] 566 | pointer = reference = False 567 | first_token = None 568 | default = [] 569 | elif default: 570 | default.append(s) 571 | elif s.name == '*': 572 | pointer = True 573 | elif s.name == '&': 574 | reference = True 575 | elif s.name == '[': 576 | pointer = True 577 | elif s.name == ']': 578 | pass # Just don't add to type_modifiers. 579 | elif s.name == '=': 580 | # Got a default value. Add any value (None) as a flag. 
581 | default.append(None) 582 | else: 583 | type_modifiers.append(s) 584 | add_parameter() 585 | return result 586 | 587 | def create_return_type(self, return_type_seq): 588 | if not return_type_seq: 589 | return None 590 | start = return_type_seq[0].start 591 | end = return_type_seq[-1].end 592 | 593 | _, name, templated_types, modifiers, __, other_tokens = ( 594 | self.declaration_to_parts(return_type_seq, False)) 595 | 596 | names = [n.name for n in other_tokens] 597 | reference = '&' in names 598 | pointer = '*' in names 599 | array = '[' in names 600 | return Type(start, end, name, templated_types, modifiers, 601 | reference, pointer, array) 602 | 603 | def get_template_indices(self, names): 604 | # names is a list of strings. 605 | start = names.index('<') 606 | end = len(names) - 1 607 | while end > 0: 608 | if names[end] == '>': 609 | break 610 | end -= 1 611 | return start, end + 1 612 | 613 | 614 | class ASTBuilder(object): 615 | 616 | def __init__(self, token_stream, filename, in_class=None, 617 | namespace_stack=None, quiet=False): 618 | if namespace_stack is None: 619 | namespace_stack = [] 620 | 621 | self.tokens = token_stream 622 | self.filename = filename 623 | self.token_queue = [] 624 | self.namespace_stack = namespace_stack[:] 625 | self.namespaces = [] 626 | self.define = set() 627 | self.quiet = quiet 628 | self.in_class = in_class 629 | if in_class: 630 | self.namespaces.append(False) 631 | # Keep the state whether we are currently handling a typedef or not. 
632 | self._handling_typedef = False 633 | self._handling_const = False 634 | self.converter = TypeConverter(self.namespace_stack) 635 | 636 | def generate(self): 637 | while True: 638 | try: 639 | token = self._get_next_token() 640 | except StopIteration: 641 | break 642 | 643 | if token.name == '{': 644 | self.namespaces.append(False) 645 | continue 646 | if token.name == '}': 647 | if self.namespaces.pop(): 648 | self.namespace_stack.pop() 649 | continue 650 | 651 | result = self._generate_one(token) 652 | if result: 653 | yield result 654 | 655 | def _create_variable(self, pos_token, name, type_name, type_modifiers, 656 | ref_pointer_name_seq, templated_types=None, value=''): 657 | if templated_types is None: 658 | templated_types = [] 659 | 660 | reference = '&' in ref_pointer_name_seq 661 | pointer = '*' in ref_pointer_name_seq 662 | array = '[' in ref_pointer_name_seq 663 | var_type = Type(pos_token.start, pos_token.end, type_name, 664 | templated_types, type_modifiers, 665 | reference, pointer, array) 666 | return VariableDeclaration(pos_token.start, pos_token.end, 667 | name, var_type, value, self.namespace_stack) 668 | 669 | def _generate_one(self, token): 670 | if token.token_type == tokenize.NAME: 671 | if (keywords.is_keyword(token.name) and 672 | not keywords.is_builtin_type(token.name)): 673 | method = getattr(self, 'handle_' + token.name, None) 674 | assert_parse(method, 'unexpected token: {}'.format(token)) 675 | return method() 676 | 677 | # Handle data or function declaration/definition. 
678 | temp_tokens, last_token = \ 679 | self._get_var_tokens_up_to(True, '(', ';', '{') 680 | 681 | temp_tokens.insert(0, token) 682 | if last_token.name == '(' or last_token.name == '{': 683 | # Ignore static_assert 684 | if temp_tokens[-1].name == 'static_assert': 685 | self._ignore_up_to(';') 686 | return None 687 | 688 | # Ignore __declspec 689 | if temp_tokens[-1].name == '__declspec': 690 | list(self._get_parameters()) 691 | return None 692 | 693 | # Ignore __attribute__ 694 | if temp_tokens[-1].name == '__attribute__': 695 | list(self._get_parameters()) 696 | new_temp, last_token = \ 697 | self._get_var_tokens_up_to(True, '(', ';', '{') 698 | del temp_tokens[-1] 699 | temp_tokens.extend(new_temp) 700 | 701 | # If there is an assignment before the paren, 702 | # this is an expression, not a method. 703 | for i, elt in reversed(list(enumerate(temp_tokens))): 704 | if ( 705 | elt.name == '=' and 706 | temp_tokens[i - 1].name != 'operator' 707 | ): 708 | temp_tokens.append(last_token) 709 | new_temp, last_token = \ 710 | self._get_var_tokens_up_to(False, ';') 711 | temp_tokens.extend(new_temp) 712 | break 713 | 714 | if last_token.name == ';': 715 | return self._get_variable(temp_tokens) 716 | if last_token.name == '{': 717 | assert_parse(temp_tokens, 'not enough tokens') 718 | 719 | self._add_back_token(last_token) 720 | self._add_back_tokens(temp_tokens[1:]) 721 | method_name = temp_tokens[0].name 722 | method = getattr(self, 'handle_' + method_name, None) 723 | if not method: 724 | return None 725 | return method() 726 | return self._get_method(temp_tokens, 0, None, False) 727 | elif token.token_type == tokenize.SYNTAX: 728 | if token.name == '~' and self.in_class: 729 | # Must be a dtor (probably not in method body). 730 | token = self._get_next_token() 731 | return self._get_method([token], FUNCTION_DTOR, None, True) 732 | # TODO(nnorwitz): handle a lot more syntax. 
733 | elif token.token_type == tokenize.PREPROCESSOR: 734 | # TODO(nnorwitz): handle more preprocessor directives. 735 | # token starts with a #, so remove it and strip whitespace. 736 | name = token.name[1:].lstrip() 737 | if name.startswith('include'): 738 | # Remove "include". 739 | name = name[7:].strip() 740 | assert name 741 | # Handle #include \ "header-on-second-line.h". 742 | if name.startswith('\\'): 743 | name = name[1:].strip() 744 | 745 | system = True 746 | filename = name 747 | if name[0] in '<"': 748 | assert_parse(name[-1] in '>"', token) 749 | 750 | system = name[0] == '<' 751 | filename = name[1:-1] 752 | return Include(token.start, token.end, filename, system) 753 | if name.startswith('define'): 754 | # Remove "define". 755 | name = name[6:].strip() 756 | assert name 757 | # Handle #define \ MACRO. 758 | if name.startswith('\\'): 759 | name = name[1:].strip() 760 | value = '' 761 | paren = 0 762 | 763 | for i, c in enumerate(name): 764 | if not paren and c.isspace(): 765 | value = name[i:].lstrip() 766 | name = name[:i] 767 | break 768 | if c == ')': 769 | value = name[i + 1:].lstrip() 770 | name = name[:paren] 771 | self.define.add(name) 772 | break 773 | if c == '(': 774 | paren = i 775 | if value.startswith('\\'): 776 | value = value[1:].strip() 777 | return Define(token.start, token.end, name, value) 778 | if name.startswith('undef'): 779 | # Remove "undef". 
780 | name = name[5:].strip() 781 | assert name 782 | self.define.discard(name) 783 | return None 784 | 785 | def _get_tokens_up_to(self, expected_token): 786 | return self._get_var_tokens_up_to(False, 787 | expected_token)[0] 788 | 789 | def _get_var_tokens_up_to(self, skip_bracket_content, *expected_tokens): 790 | last_token = self._get_next_token() 791 | tokens = [] 792 | count1 = 0 793 | count2 = 0 794 | while (count1 != 0 or 795 | count2 != 0 or 796 | last_token.token_type != tokenize.SYNTAX or 797 | last_token.name not in expected_tokens): 798 | if last_token.name == '[': 799 | count1 += 1 800 | elif last_token.name == ']': 801 | count1 -= 1 802 | if skip_bracket_content and count1 == 0: 803 | if last_token.name == 'operator': 804 | skip_bracket_content = False 805 | elif last_token.name == '<': 806 | count2 += 1 807 | elif last_token.name == '>': 808 | count2 -= 1 809 | if last_token.token_type != tokenize.PREPROCESSOR: 810 | tokens.append(last_token) 811 | temp_token = self._get_next_token() 812 | if temp_token.name == '(' and last_token.name in self.define: 813 | # TODO: for now just ignore the tokens inside the parenthesis 814 | list(self._get_parameters()) 815 | temp_token = self._get_next_token() 816 | last_token = temp_token 817 | return tokens, last_token 818 | 819 | def _ignore_up_to(self, token): 820 | self._get_tokens_up_to(token) 821 | 822 | def _get_matching_char(self, open_paren, close_paren, get_next_token=None): 823 | if get_next_token is None: 824 | get_next_token = self._get_next_token 825 | # Assumes the current token is open_paren and we will consume 826 | # and return up to the close_paren. 
827 | count = 1 828 | while count != 0: 829 | token = get_next_token() 830 | if token.token_type == tokenize.SYNTAX: 831 | if token.name == open_paren: 832 | count += 1 833 | elif token.name == close_paren: 834 | count -= 1 835 | yield token 836 | 837 | def _get_parameters(self): 838 | return self._get_matching_char('(', ')') 839 | 840 | def get_scope(self): 841 | return self._get_matching_char('{', '}') 842 | 843 | def _get_next_token(self): 844 | if self.token_queue: 845 | return self.token_queue.pop() 846 | return next(self.tokens) 847 | 848 | def _add_back_token(self, token): 849 | self.token_queue.append(token) 850 | 851 | def _add_back_tokens(self, tokens): 852 | if tokens: 853 | self.token_queue.extend(reversed(tokens)) 854 | 855 | def get_name(self, seq=None): 856 | """Returns ([tokens], next_token_info).""" 857 | if seq is not None: 858 | it = iter(seq) 859 | 860 | def get_next_token(): 861 | return next(it) 862 | else: 863 | get_next_token = self._get_next_token 864 | 865 | next_token = get_next_token() 866 | tokens = [] 867 | last_token_was_name = False 868 | while (next_token.token_type == tokenize.NAME or 869 | (next_token.token_type == tokenize.SYNTAX and 870 | next_token.name in ('::', '<'))): 871 | # Two NAMEs in a row means the identifier should terminate. 872 | # It's probably some sort of variable declaration. 873 | if last_token_was_name and next_token.token_type == tokenize.NAME: 874 | break 875 | last_token_was_name = next_token.token_type == tokenize.NAME 876 | tokens.append(next_token) 877 | # Handle templated names. 
878 | if next_token.name == '<': 879 | tokens.extend(self._get_matching_char('<', '>', 880 | get_next_token)) 881 | last_token_was_name = True 882 | next_token = get_next_token() 883 | return tokens, next_token 884 | 885 | def get_method(self, modifiers, templated_types): 886 | return_type_and_name = self._get_tokens_up_to('(') 887 | assert len(return_type_and_name) >= 1 888 | return self._get_method( 889 | return_type_and_name, modifiers, templated_types, 890 | False) 891 | 892 | def _get_method(self, return_type_and_name, modifiers, templated_types, 893 | get_paren): 894 | specializations = [] 895 | if get_paren: 896 | token = self._get_next_token() 897 | assert_parse(token.token_type == tokenize.SYNTAX, token) 898 | if token.name == '<': 899 | # Handle templatized dtors. 900 | specializations = list(self._get_matching_char('<', '>')) 901 | del specializations[-1] 902 | token = self._get_next_token() 903 | assert_parse(token.token_type == tokenize.SYNTAX, token) 904 | assert_parse(token.name == '(', token) 905 | 906 | name = return_type_and_name.pop() 907 | if (len(return_type_and_name) > 2 and 908 | return_type_and_name[-1].name == '>' and 909 | return_type_and_name[-2].name == 'operator' and 910 | (name.name == '>=' or name.name == '>')): 911 | n = return_type_and_name.pop() 912 | name = tokenize.Token(tokenize.SYNTAX, 913 | n.name + name.name, 914 | n.start, name.end) 915 | 916 | if (len(return_type_and_name) > 1 and 917 | (return_type_and_name[-1].name == 'operator' or 918 | return_type_and_name[-1].name == '~')): 919 | op = return_type_and_name.pop() 920 | name = tokenize.Token(tokenize.NAME, op.name + name.name, 921 | op.start, name.end) 922 | # Handle templatized ctors. 
923 | elif name.name == '>': 924 | count = 1 925 | index = len(return_type_and_name) 926 | while count and index > 0: 927 | index -= 1 928 | tok = return_type_and_name[index] 929 | if tok.name == '<': 930 | count -= 1 931 | elif tok.name == '>': 932 | count += 1 933 | specializations = return_type_and_name[index + 1:] 934 | del return_type_and_name[index:] 935 | name = return_type_and_name.pop() 936 | elif name.name == ']': 937 | name_seq = return_type_and_name[-2] 938 | del return_type_and_name[-2:] 939 | name = tokenize.Token(tokenize.NAME, name_seq.name + '[]', 940 | name_seq.start, name.end) 941 | 942 | return_type = return_type_and_name 943 | indices = name 944 | if return_type: 945 | indices = return_type[0] 946 | 947 | # Force ctor for templatized ctors. 948 | if name.name == self.in_class and not modifiers: 949 | modifiers |= FUNCTION_CTOR 950 | parameters = list(self._get_parameters()) 951 | assert_parse(parameters, 'missing closing parenthesis') 952 | last_token = parameters.pop() # Remove trailing ')'. 953 | 954 | # Handling operator() is especially weird. 955 | if name.name == 'operator' and not parameters: 956 | token = self._get_next_token() 957 | assert_parse(token.name == '(', token) 958 | name = tokenize.Token(tokenize.NAME, 'operator()', 959 | name.start, last_token.end) 960 | parameters = list(self._get_parameters()) 961 | del parameters[-1] # Remove trailing ')'. 
962 | 963 | try: 964 | token = self._get_next_token() 965 | except StopIteration: 966 | token = tokenize.Token(tokenize.SYNTAX, ';', 0, 0) 967 | 968 | while ( 969 | token.token_type == tokenize.NAME or 970 | token.token_type == tokenize.PREPROCESSOR 971 | ): 972 | if ( 973 | token.name == 'const' or 974 | token.name == 'override' or 975 | token.name == 'final' 976 | ): 977 | modifiers |= FUNCTION_SPECIFIER 978 | token = self._get_next_token() 979 | elif token.name == 'noexcept': 980 | modifiers |= FUNCTION_SPECIFIER 981 | token = self._get_next_token() 982 | if token.name == '(': 983 | # Consume everything between the parens. 984 | list(self._get_matching_char('(', ')')) 985 | token = self._get_next_token() 986 | elif token.name == '__attribute__': 987 | # TODO(nnorwitz): handle more __attribute__ details. 988 | modifiers |= FUNCTION_ATTRIBUTE 989 | token = self._get_next_token() 990 | assert_parse(token.name == '(', token) 991 | # Consume everything between the parens. 992 | list(self._get_matching_char('(', ')')) 993 | token = self._get_next_token() 994 | elif token.name == 'throw': 995 | modifiers |= FUNCTION_THROW 996 | token = self._get_next_token() 997 | assert_parse(token.name == '(', token) 998 | # Consume everything between the parens. 999 | list(self._get_matching_char('(', ')')) 1000 | token = self._get_next_token() 1001 | elif token.name == token.name.upper(): 1002 | # Assume that all upper-case names are some macro. 1003 | modifiers |= FUNCTION_UNKNOWN_ANNOTATION 1004 | token = self._get_next_token() 1005 | if token.name == '(': 1006 | # Consume everything between the parens. 1007 | list(self._get_matching_char('(', ')')) 1008 | token = self._get_next_token() 1009 | elif token.token_type == tokenize.PREPROCESSOR: 1010 | token = self._get_next_token() 1011 | else: 1012 | self._add_back_token(token) 1013 | token = tokenize.Token(tokenize.SYNTAX, ';', 0, 0) 1014 | 1015 | # Handle ref-qualifiers. 
1016 | if token.name == '&' or token.name == '&&': 1017 | token = self._get_next_token() 1018 | 1019 | # Handle trailing return types. 1020 | if token.name == '->': 1021 | return_type, token = self._get_var_tokens_up_to(False, '{', ';') 1022 | 1023 | if token.name == '}' or token.token_type == tokenize.PREPROCESSOR: 1024 | self._add_back_token(token) 1025 | token = tokenize.Token(tokenize.SYNTAX, ';', 0, 0) 1026 | 1027 | assert_parse(token.token_type == tokenize.SYNTAX, token) 1028 | 1029 | # Handle ctor initializers. 1030 | if token.name == ':': 1031 | while token.name != ';' and token.name != '{': 1032 | _, token = self.get_name() 1033 | if token.name == '(': 1034 | list(self._get_matching_char('(', ')')) 1035 | elif token.name == '{': 1036 | list(self._get_matching_char('{', '}')) 1037 | token = self._get_next_token() 1038 | 1039 | # Handle pointer to functions. 1040 | if token.name == '(': 1041 | # name contains the return type. 1042 | return_type.append(name) 1043 | while parameters[-1].name in '()': 1044 | parameters.pop() 1045 | name = parameters[-1] 1046 | # Already at the ( to open the parameter list. 1047 | parameters = list(self._get_matching_char('(', ')')) 1048 | del parameters[-1] # Remove trailing ')'. 1049 | # TODO(nnorwitz): store the function_parameters. 1050 | token = self._get_next_token() 1051 | 1052 | if token.name != '{': 1053 | default = [] 1054 | if token.name == '=': 1055 | default.extend(self._get_tokens_up_to(';')) 1056 | 1057 | return self._create_variable( 1058 | indices, 1059 | name.name, 1060 | indices.name, 1061 | [], 1062 | [t.name for t in return_type], 1063 | None, 1064 | ''.join([t.name for t in default])) 1065 | 1066 | if token.name == '{': 1067 | body = list(self.get_scope()) 1068 | del body[-1] # Remove trailing '}'. 
1069 | else: 1070 | body = None 1071 | if token.name == '=': 1072 | token = self._get_next_token() 1073 | if token.name == '0': 1074 | modifiers |= FUNCTION_PURE_VIRTUAL 1075 | token = self._get_next_token() 1076 | 1077 | if token.name == '[': 1078 | # TODO(nnorwitz): store tokens and improve parsing. 1079 | # template char (&ASH(T (&seq)[N]))[N]; 1080 | list(self._get_matching_char('[', ']')) 1081 | token = self._get_next_token() 1082 | 1083 | if token.name in '*&': 1084 | tokens, last = self._get_var_tokens_up_to(False, '(', ';') 1085 | tokens.insert(0, token) 1086 | tokens = parameters + tokens 1087 | if last.name == '(': 1088 | return self._get_method(tokens, 0, None, False) 1089 | return self._get_variable(tokens) 1090 | 1091 | assert_parse(token.name == ';', 1092 | (token, return_type_and_name, parameters)) 1093 | 1094 | # Looks like we got a method, not a function. 1095 | if len(return_type) > 1 and return_type[-1].name == '::': 1096 | return_type, in_class = \ 1097 | self._get_return_type_and_class_name(return_type) 1098 | return Method(indices.start, indices.end, name.name, in_class, 1099 | return_type, parameters, specializations, modifiers, 1100 | templated_types, body, self.namespace_stack) 1101 | return Function(indices.start, indices.end, name.name, return_type, 1102 | parameters, specializations, modifiers, 1103 | templated_types, body, self.namespace_stack) 1104 | 1105 | def _get_variable(self, tokens): 1106 | name, type_name, templated_types, modifiers, default, _ = \ 1107 | self.converter.declaration_to_parts(tokens, True) 1108 | 1109 | assert_parse(tokens, 'not enough tokens') 1110 | 1111 | t0 = tokens[0] 1112 | names = [t.name for t in tokens] 1113 | if templated_types: 1114 | start, end = self.converter.get_template_indices(names) 1115 | names = names[:start] + names[end:] 1116 | default = ''.join([t.name for t in default]) 1117 | return self._create_variable(t0, name, type_name, modifiers, 1118 | names, templated_types, default) 1119 | 1120 | 
def _get_return_type_and_class_name(self, token_seq): 1121 | # Splitting the return type from the class name in a method 1122 | # can be tricky. For example, Return::Type::Is::Hard::To::Find(). 1123 | # Where is the return type and where is the class name? 1124 | # The heuristic used is to pull the last name as the class name. 1125 | # This includes all the templated type info. 1126 | # TODO(nnorwitz): if there is only One name like in the 1127 | # example above, punt and assume the last bit is the class name. 1128 | 1129 | i = 0 1130 | end = len(token_seq) - 1 1131 | 1132 | # Make a copy of the sequence so we can append a sentinel 1133 | # value. This is required because get_name has to have some 1134 | # terminating condition beyond the last name. 1135 | seq_copy = token_seq[i:end] 1136 | seq_copy.append(tokenize.Token(tokenize.SYNTAX, '', 0, 0)) 1137 | names = [] 1138 | while i < end: 1139 | # Iterate through the sequence parsing out each name. 1140 | new_name, next_item = self.get_name(seq_copy[i:]) 1141 | # We got a pointer or ref. Add it to the name. 1142 | if next_item and next_item.token_type == tokenize.SYNTAX: 1143 | new_name.append(next_item) 1144 | names.append(new_name) 1145 | i += len(new_name) 1146 | 1147 | # Remove the sentinel value. 1148 | names[-1].pop() 1149 | # Flatten the token sequence for the return type. 1150 | return_type = [e for seq in names[:-1] for e in seq] 1151 | # The class name is the last name.
1152 | class_name = names[-1] 1153 | return return_type, class_name 1154 | 1155 | def _handle_class_and_struct(self, class_type): 1156 | if self._handling_typedef: 1157 | return self._get_class(class_type, None) 1158 | 1159 | name_tokens, var_token = self.get_name() 1160 | if var_token.token_type == tokenize.NAME or var_token.name in '*&': 1161 | tokens, last = self._get_var_tokens_up_to(False, '(', ';', '{') 1162 | tokens.insert(0, var_token) 1163 | tokens = name_tokens + tokens 1164 | if last.name == '{': 1165 | self._add_back_token(last) 1166 | self._add_back_tokens(tokens) 1167 | return self._get_class(class_type, None) 1168 | if last.name == '(': 1169 | return self._get_method(tokens, 0, None, False) 1170 | return self._get_variable(tokens) 1171 | 1172 | self._add_back_token(var_token) 1173 | self._add_back_tokens(name_tokens) 1174 | return self._get_class(class_type, None) 1175 | 1176 | def handle_class(self): 1177 | return self._handle_class_and_struct(Class) 1178 | 1179 | def handle_struct(self): 1180 | return self._handle_class_and_struct(Struct) 1181 | 1182 | def handle_union(self): 1183 | return self._handle_class_and_struct(Union) 1184 | 1185 | def handle_enum(self): 1186 | # Handle strongly typed enumerations. 1187 | token = self._get_next_token() 1188 | if token.name != 'class': 1189 | self._add_back_token(token) 1190 | 1191 | name = None 1192 | name_tokens, token = self.get_name() 1193 | if name_tokens: 1194 | name = ''.join([t.name for t in name_tokens]) 1195 | 1196 | if token.token_type == tokenize.NAME: 1197 | if self._handling_typedef: 1198 | self._add_back_token(token) 1199 | return Enum(token.start, token.end, name, None, 1200 | self.namespace_stack) 1201 | 1202 | next_token = self._get_next_token() 1203 | if next_token.name != '(': 1204 | self._add_back_token(next_token) 1205 | else: 1206 | name_tokens.append(token) 1207 | return self._get_method(name_tokens, 0, None, False) 1208 | 1209 | # Handle underlying type. 
1210 | if token.token_type == tokenize.SYNTAX and token.name == ':': 1211 | _, token = self._get_var_tokens_up_to(False, '{', ';') 1212 | 1213 | # Handle forward declarations. 1214 | if token.token_type == tokenize.SYNTAX and token.name == ';': 1215 | return Enum(token.start, token.end, name, None, 1216 | self.namespace_stack) 1217 | 1218 | # Must be the type declaration. 1219 | if token.token_type == tokenize.SYNTAX and token.name == '{': 1220 | fields = list(self._get_matching_char('{', '}')) 1221 | del fields[-1] # Remove trailing '}'. 1222 | next_item = self._get_next_token() 1223 | new_type = Enum(token.start, token.end, name, fields, 1224 | self.namespace_stack) 1225 | # A name means this is an anonymous type and the name 1226 | # is the variable declaration. 1227 | if next_item.token_type != tokenize.NAME: 1228 | return new_type 1229 | name = new_type 1230 | token = next_item 1231 | 1232 | # Must be variable declaration using the type prefixed with keyword. 1233 | assert_parse(token.token_type == tokenize.NAME, token) 1234 | return self._create_variable(token, token.name, name, [], '') 1235 | 1236 | def handle_const(self): 1237 | self._handling_const = True 1238 | token = self._get_next_token() 1239 | result = self._generate_one(token) 1240 | self._handling_const = False 1241 | return result 1242 | 1243 | def handle_inline(self): 1244 | pass 1245 | 1246 | def handle_extern(self): 1247 | pass 1248 | 1249 | def handle_virtual(self): 1250 | # What follows must be a method. 
1251 | token = self._get_next_token() 1252 | if token.name == 'inline': 1253 | token = self._get_next_token() 1254 | if token.token_type == tokenize.SYNTAX and token.name == '~': 1255 | return self.get_method(FUNCTION_VIRTUAL + FUNCTION_DTOR, None) 1256 | return_type_and_name = self._get_tokens_up_to('(') 1257 | return_type_and_name.insert(0, token) 1258 | return self._get_method(return_type_and_name, FUNCTION_VIRTUAL, 1259 | None, False) 1260 | 1261 | def handle_public(self): 1262 | assert_parse(self.in_class, 'expected to be in a class') 1263 | 1264 | def handle_protected(self): 1265 | assert_parse(self.in_class, 'expected to be in a class') 1266 | 1267 | def handle_private(self): 1268 | assert_parse(self.in_class, 'expected to be in a class') 1269 | 1270 | def handle_friend(self): 1271 | tokens, last = self._get_var_tokens_up_to(False, '(', ';') 1272 | if last.name == '(': 1273 | tokens.append(last) 1274 | self._add_back_tokens(tokens) 1275 | token = self._get_next_token() 1276 | while token.name in ('inline', 'typename', '::'): 1277 | token = self._get_next_token() 1278 | result = self._generate_one(token) 1279 | else: 1280 | if tokens[0].name == 'class': 1281 | tokens = tokens[1:] 1282 | result = self.converter.to_type(tokens)[0] 1283 | 1284 | assert result 1285 | return Friend(result.start, result.end, result, self.namespace_stack) 1286 | 1287 | def handle_typedef(self): 1288 | token = self._get_next_token() 1289 | if (token.token_type == tokenize.NAME and 1290 | keywords.is_builtin_other_modifiers(token.name)): 1291 | method = getattr(self, 'handle_' + token.name) 1292 | self._handling_typedef = True 1293 | tokens = [method()] 1294 | self._handling_typedef = False 1295 | else: 1296 | tokens = [token] 1297 | 1298 | # Get the remainder of the typedef up to the semi-colon. 
1299 | tokens.extend(self._get_tokens_up_to(';')) 1300 | 1301 | name = tokens.pop() 1302 | if name.name == ')': 1303 | tokens.append(name) 1304 | end = len(tokens) - 2 1305 | count = 1 1306 | while count: 1307 | if tokens[end].name == '(': 1308 | count -= 1 1309 | elif tokens[end].name == ')': 1310 | count += 1 1311 | end -= 1 1312 | start = end 1313 | if tokens[start].name == ')': 1314 | name = tokens[start - 1] 1315 | while tokens[start].name != '(': 1316 | start -= 1 1317 | else: 1318 | name = tokens[start] 1319 | del tokens[start:end + 1] 1320 | elif name.name == ']' and len(tokens) >= 2: 1321 | tokens.append(name) 1322 | name = tokens[1] 1323 | del tokens[1] 1324 | new_type = tokens 1325 | if tokens and isinstance(tokens[0], tokenize.Token): 1326 | new_type = self.converter.to_type(tokens) 1327 | return Typedef(name.start, name.end, name.name, 1328 | new_type, self.namespace_stack) 1329 | 1330 | def handle_typename(self): 1331 | pass # Not needed yet. 1332 | 1333 | def _get_templated_types(self): 1334 | result = {} 1335 | tokens = list(self._get_matching_char('<', '>')) 1336 | len_tokens = len(tokens) - 1 # Ignore trailing '>'. 1337 | i = 0 1338 | while i < len_tokens: 1339 | key = tokens[i].name 1340 | i += 1 1341 | if keywords.is_keyword(key) or key == ',' or key == '.': 1342 | continue 1343 | type_name = default = None 1344 | if i < len_tokens: 1345 | i += 1 1346 | if tokens[i - 1].name == '=': 1347 | assert_parse(i < len_tokens, '%s %s' % (i, tokens)) 1348 | default, _ = self.get_name(tokens[i:]) 1349 | i += len(default) 1350 | elif tokens[i - 1].name != ',': 1351 | # We got something like: Type variable. 1352 | # Re-adjust the key (variable) and type_name (Type). 
1353 | key = tokens[i - 1].name 1354 | type_name = tokens[i - 2] 1355 | 1356 | result[key] = (type_name, default) 1357 | return result 1358 | 1359 | def handle_template(self): 1360 | token = self._get_next_token() 1361 | 1362 | templated_types = None 1363 | if token.token_type == tokenize.SYNTAX and token.name == '<': 1364 | templated_types = self._get_templated_types() 1365 | token = self._get_next_token() 1366 | while token.token_type == tokenize.PREPROCESSOR: 1367 | token = self._get_next_token() 1368 | 1369 | if token.token_type == tokenize.NAME: 1370 | if token.name == 'class': 1371 | return self._get_class(Class, 1372 | templated_types) 1373 | elif token.name == 'struct': 1374 | return self._get_class(Struct, 1375 | templated_types) 1376 | elif token.name == 'union': 1377 | return self._get_class(Union, 1378 | templated_types) 1379 | elif token.name == 'friend': 1380 | return self.handle_friend() 1381 | elif token.name == 'template': 1382 | return self.handle_template() 1383 | self._add_back_token(token) 1384 | tokens, last = self._get_var_tokens_up_to(False, '(', ';') 1385 | tokens.append(last) 1386 | self._add_back_tokens(tokens) 1387 | if last.name == '(': 1388 | return self.get_method(FUNCTION_NONE, templated_types) 1389 | # Must be a variable definition. 1390 | return None 1391 | 1392 | def _get_bases(self): 1393 | # Get base classes. 
1394 | bases = [] 1395 | specifier = ('public', 'protected', 'private', 'virtual') 1396 | while True: 1397 | token = self._get_next_token() 1398 | if ( 1399 | token.name in specifier or 1400 | token.token_type == tokenize.PREPROCESSOR 1401 | ): 1402 | continue 1403 | self._add_back_token(token) 1404 | 1405 | base, next_token = self.get_name() 1406 | if ( 1407 | len(base) > 2 and 1408 | base[-2].name == '::' and 1409 | next_token.token_type == tokenize.NAME and 1410 | next_token.name not in specifier 1411 | ): 1412 | self._add_back_token(next_token) 1413 | base2, next_token = self.get_name() 1414 | base.pop() 1415 | base.extend(base2) 1416 | bases_ast = self.converter.to_type(base) 1417 | if len(bases_ast) == 1: 1418 | bases.append(bases_ast[0]) 1419 | if next_token.name == ')': 1420 | next_token = self._get_next_token() 1421 | while next_token.token_type == tokenize.PREPROCESSOR: 1422 | next_token = self._get_next_token() 1423 | if next_token.name == '{': 1424 | token = next_token 1425 | break 1426 | return bases, token 1427 | 1428 | def _get_class(self, class_type, templated_types): 1429 | class_name = None 1430 | class_token = self._get_next_token() 1431 | if class_token.token_type != tokenize.NAME: 1432 | assert_parse(class_token.token_type == tokenize.SYNTAX, 1433 | class_token) 1434 | token = class_token 1435 | else: 1436 | self._add_back_token(class_token) 1437 | name_tokens, token = self.get_name() 1438 | 1439 | if self._handling_typedef: 1440 | # Handle typedef to pointer. 1441 | if token.name in '*&': 1442 | name_tokens.append(token) 1443 | token = self._get_next_token() 1444 | # Handle attribute. 
1445 | elif token.token_type == tokenize.NAME: 1446 | self._add_back_token(token) 1447 | attribute, token = self.get_name() 1448 | if len(attribute) > 1 or attribute[0].name != 'final': 1449 | name_tokens = attribute 1450 | class_name = self.converter.to_type(name_tokens)[0].name 1451 | assert_parse(class_name, class_token) 1452 | 1453 | bases = None 1454 | if token.token_type == tokenize.PREPROCESSOR: 1455 | token = self._get_next_token() 1456 | if token.token_type == tokenize.SYNTAX: 1457 | if token.name == ';': 1458 | # Forward declaration. 1459 | return class_type(class_token.start, class_token.end, 1460 | class_name, None, templated_types, None, 1461 | self.namespace_stack) 1462 | if token.name in '*&': 1463 | # Inline forward declaration. Could be method or data. 1464 | name_token = self._get_next_token() 1465 | next_token = self._get_next_token() 1466 | if next_token.name == ';': 1467 | # Handle data 1468 | modifiers = ['class'] 1469 | return self._create_variable(class_token, name_token.name, 1470 | class_name, 1471 | modifiers, token.name) 1472 | else: 1473 | # Assume this is a method. 
1474 | tokens = (class_token, token, name_token, next_token) 1475 | self._add_back_tokens(tokens) 1476 | return self.get_method(FUNCTION_NONE, None) 1477 | if token.name == ':': 1478 | bases, token = self._get_bases() 1479 | 1480 | body = None 1481 | if token.token_type == tokenize.SYNTAX and token.name == '{': 1482 | name = class_name or '__unamed__' 1483 | ast = ASTBuilder(self.get_scope(), self.filename, name, 1484 | self.namespace_stack, 1485 | quiet=self.quiet) 1486 | body = list(ast.generate()) 1487 | 1488 | if not self._handling_typedef: 1489 | token = self._get_next_token() 1490 | if token.token_type != tokenize.NAME: 1491 | assert_parse(token.token_type == tokenize.SYNTAX, token) 1492 | assert_parse(token.name == ';', token) 1493 | else: 1494 | if keywords.is_builtin_type(token.name): 1495 | token = self._get_next_token() 1496 | self._ignore_up_to(';') 1497 | new_class = class_type(class_token.start, class_token.end, 1498 | class_name, bases, None, 1499 | body, self.namespace_stack) 1500 | 1501 | modifiers = ['const'] if self._handling_const else [] 1502 | return self._create_variable(class_token, 1503 | token.name, new_class, 1504 | modifiers, token.name) 1505 | else: 1506 | if not self._handling_typedef: 1507 | name_tokens = [class_token] + name_tokens 1508 | return self._get_method(name_tokens, 0, None, False) 1509 | self._add_back_token(token) 1510 | 1511 | return class_type(class_token.start, class_token.end, class_name, 1512 | bases, templated_types, body, self.namespace_stack) 1513 | 1514 | def handle_namespace(self): 1515 | token = self._get_next_token() 1516 | # Support anonymous namespaces. 1517 | name = None 1518 | if token.token_type == tokenize.NAME: 1519 | name = token.name 1520 | token = self._get_next_token() 1521 | assert_parse(token.token_type == tokenize.SYNTAX, token) 1522 | 1523 | if token.name == '=': 1524 | # TODO(nnorwitz): handle aliasing namespaces. 
1525 | name, next_token = self.get_name() 1526 | assert_parse(next_token.name == ';', next_token) 1527 | else: 1528 | assert_parse(token.name == '{', token) 1529 | self.namespace_stack.append(name) 1530 | self.namespaces.append(True) 1531 | return None 1532 | 1533 | def handle_using(self): 1534 | tokens = self._get_tokens_up_to(';') 1535 | assert tokens 1536 | return Using(tokens[0].start, tokens[0].end, tokens) 1537 | 1538 | def handle_explicit(self): 1539 | assert self.in_class 1540 | # Nothing much to do. 1541 | # TODO(nnorwitz): maybe verify the method name == class name. 1542 | # This must be a ctor. 1543 | return self.get_method(FUNCTION_CTOR, None) 1544 | 1545 | def handle_operator(self): 1546 | # Pull off the next token(s?) and make that part of the method name. 1547 | pass 1548 | 1549 | 1550 | def builder_from_source(source, filename, quiet=False): 1551 | """Utility method that returns an ASTBuilder from source code. 1552 | 1553 | Args: 1554 | source: 'C++ source code' 1555 | filename: 'file1' 1556 | 1557 | Returns: 1558 | ASTBuilder 1559 | 1560 | """ 1561 | return ASTBuilder(tokenize.get_tokens(source), 1562 | filename, 1563 | quiet=quiet) 1564 | 1565 | 1566 | def assert_parse(value, message): 1567 | """Raise ParseError on token if value is False.""" 1568 | if not value: 1569 | raise ParseError(message) 1570 | -------------------------------------------------------------------------------- /src/tools/cpp/find_warnings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2007 Neal Norwitz 2 | # Portions Copyright 2007 Google Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Find warnings for C++ code. 17 | 18 | TODO(nnorwitz): provide a mechanism to configure which warnings should 19 | be generated and which should be suppressed. Currently, all possible 20 | warnings will always be displayed. There is no way to suppress any. 21 | There also needs to be a way to use annotations in the source code to 22 | suppress warnings. 23 | 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import print_function 28 | from __future__ import unicode_literals 29 | 30 | import os 31 | import sys 32 | 33 | from . import ast 34 | from . import headers 35 | from . import keywords 36 | from . import metrics 37 | from . import symbols 38 | from . import tokenize 39 | from . import utils 40 | 41 | 42 | try: 43 | basestring 44 | except NameError: 45 | basestring = str 46 | 47 | 48 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 49 | 50 | 51 | HEADER_EXTENSIONS = frozenset(['.h', '.hh', '.hpp', '.h++', '.hxx', '.cuh']) 52 | CPP_EXTENSIONS = frozenset(['.cc', '.cpp', '.c++', '.cxx', '.cu']) 53 | 54 | # These enumerations are used to determine how a symbol/#include file is used. 
55 | UNUSED = 0 56 | USES_REFERENCE = 1 57 | USES_DECLARATION = 2 58 | 59 | DECLARATION_TYPES = (ast.Class, ast.Struct, ast.Enum, ast.Union) 60 | 61 | 62 | class Module(object): 63 | 64 | """Data container representing a single source file.""" 65 | 66 | def __init__(self, filename, ast_list): 67 | self.filename = filename 68 | self.ast_list = ast_list 69 | self.public_symbols = self._get_exported_symbols() 70 | 71 | def _get_exported_symbols(self): 72 | if not self.ast_list: 73 | return {} 74 | return dict([(n.name, n) for n in self.ast_list if n.is_exportable()]) 75 | 76 | 77 | def is_header_file(filename): 78 | _, ext = os.path.splitext(filename) 79 | return ext.lower() in HEADER_EXTENSIONS 80 | 81 | 82 | def is_cpp_file(filename): 83 | _, ext = os.path.splitext(filename) 84 | return ext.lower() in CPP_EXTENSIONS 85 | 86 | 87 | class WarningHunter(object): 88 | 89 | # Cache filename: ast_list 90 | _module_cache = {} 91 | 92 | def __init__(self, filename, source, ast_list, include_paths, quiet=False): 93 | self.filename = filename 94 | self.source = source 95 | self.ast_list = ast_list 96 | self.include_paths = include_paths[:] 97 | self.quiet = quiet 98 | self.symbol_table = symbols.SymbolTable() 99 | 100 | self.metrics = metrics.Metrics(source) 101 | self.warnings = set() 102 | if filename not in self._module_cache: 103 | self._module_cache[filename] = Module(filename, ast_list) 104 | 105 | def _add_warning(self, msg, node, filename=None): 106 | if filename is not None: 107 | contents = utils.read_file(filename) 108 | src_metrics = metrics.Metrics(contents) 109 | else: 110 | filename = self.filename 111 | src_metrics = self.metrics 112 | line_number = get_line_number(src_metrics, node) 113 | self.warnings.add((filename, line_number, msg)) 114 | 115 | def show_warnings(self): 116 | for filename, line_num, msg in sorted(self.warnings): 117 | if line_num == 0: 118 | print('{}: {}'.format(filename, msg)) 119 | else: 120 | print('{}:{}: {}'.format(filename, line_num, 
msg)) 121 | 122 | def find_warnings(self): 123 | if is_header_file(self.filename): 124 | self._find_header_warnings() 125 | elif is_cpp_file(self.filename): 126 | self._find_source_warnings() 127 | 128 | def _update_symbol_table(self, module): 129 | for name, node in module.public_symbols.items(): 130 | self.symbol_table.add_symbol(name, node.namespace, node, module) 131 | 132 | def _get_module(self, node): 133 | include_paths = [os.path.dirname(self.filename)] + self.include_paths 134 | source, filename = headers.read_source(node.filename, include_paths) 135 | 136 | if source is None: 137 | module = Module(filename, None) 138 | msg = "unable to find '{}'".format(filename) 139 | self._add_warning(msg, node) 140 | elif filename in self._module_cache: 141 | # The cache survives across all instances, but the symbol table 142 | # is per instance, so we need to make sure the symbol table 143 | # is updated even if the module was in the cache. 144 | module = self._module_cache[filename] 145 | self._update_symbol_table(module) 146 | else: 147 | ast_list = None 148 | try: 149 | builder = ast.builder_from_source(source, filename, 150 | quiet=self.quiet) 151 | ast_list = [_f for _f in builder.generate() if _f] 152 | except tokenize.TokenError: 153 | pass 154 | except ast.ParseError as error: 155 | if not self.quiet: 156 | print( 157 | "Exception while processing '{}': {}".format( 158 | filename, 159 | error), 160 | file=sys.stderr) 161 | module = Module(filename, ast_list) 162 | self._module_cache[filename] = module 163 | self._update_symbol_table(module) 164 | return module 165 | 166 | def _read_and_parse_includes(self): 167 | # Map header-filename: (#include AST node, module). 168 | included_files = {} 169 | # Map declaration-name: AST node. 
170 | forward_declarations = {} 171 | files_seen = {} 172 | for node in self.ast_list: 173 | if isinstance(node, ast.Include): 174 | if node.system: 175 | filename = node.filename 176 | else: 177 | module = self._get_module(node) 178 | filename = module.filename 179 | _, ext = os.path.splitext(filename) 180 | if ext.lower() != '.hxx': 181 | included_files[filename] = node, module 182 | if is_cpp_file(filename): 183 | self._add_warning( 184 | "should not #include C++ source file '{}'".format( 185 | node.filename), 186 | node) 187 | if filename == self.filename: 188 | self._add_warning( 189 | "'{}' #includes itself".format(node.filename), 190 | node) 191 | if filename in files_seen: 192 | include_node = files_seen[filename] 193 | line_num = get_line_number(self.metrics, include_node) 194 | self._add_warning( 195 | "'{}' already #included on line {}".format( 196 | node.filename, 197 | line_num), 198 | node) 199 | else: 200 | files_seen[filename] = node 201 | if isinstance(node, DECLARATION_TYPES) and node.is_declaration(): 202 | forward_declarations[node.full_name()] = node 203 | 204 | return included_files, forward_declarations 205 | 206 | def _verify_include_files_used(self, file_uses, included_files): 207 | """Find all #include files that are unnecessary.""" 208 | for include_file, use in file_uses.items(): 209 | if not use & USES_DECLARATION: 210 | node, module = included_files[include_file] 211 | if module.ast_list is not None: 212 | msg = "'{}' does not need to be #included".format( 213 | node.filename) 214 | if use & USES_REFERENCE: 215 | msg += '; use a forward declaration instead' 216 | self._add_warning(msg, node) 217 | 218 | def _verify_forward_declarations_used(self, forward_declarations, 219 | decl_uses, file_uses): 220 | """Find all the forward declarations that are not used.""" 221 | for cls in forward_declarations: 222 | if cls in file_uses: 223 | if not decl_uses[cls] & USES_DECLARATION: 224 | node = forward_declarations[cls] 225 | msg = ("'{}' 
forward declared, " 226 | 'but needs to be #included'.format(cls)) 227 | self._add_warning(msg, node) 228 | else: 229 | if decl_uses[cls] == UNUSED: 230 | node = forward_declarations[cls] 231 | msg = "'{}' not used".format(cls) 232 | self._add_warning(msg, node) 233 | 234 | def _determine_uses(self, included_files, forward_declarations): 235 | """Set up the use type of each symbol.""" 236 | file_uses = dict.fromkeys(included_files, UNUSED) 237 | decl_uses = dict.fromkeys(forward_declarations, UNUSED) 238 | symbol_table = self.symbol_table 239 | 240 | for name, node in forward_declarations.items(): 241 | try: 242 | symbol_table.lookup_symbol(node.name, node.namespace) 243 | decl_uses[name] |= USES_REFERENCE 244 | except symbols.Error: 245 | module = Module(name, None) 246 | self.symbol_table.add_symbol(node.name, node.namespace, node, 247 | module) 248 | 249 | def _add_declaration(name, namespace): 250 | if not name: 251 | # Ignore anonymous struct. It is not standard, but we might as 252 | # well avoid crashing if it is easy. 253 | return 254 | 255 | names = [n for n in namespace if n is not None] 256 | if names: 257 | name = '::'.join(names) + '::' + name 258 | if name in decl_uses: 259 | decl_uses[name] |= USES_DECLARATION 260 | 261 | def _add_reference(name, namespace): 262 | try: 263 | file_use_node = symbol_table.lookup_symbol(name, namespace) 264 | except symbols.Error: 265 | return 266 | 267 | name = file_use_node[1].filename 268 | if file_use_node[1].ast_list is None: 269 | decl_uses[name] |= USES_REFERENCE 270 | elif name in file_uses: 271 | # enum and typedef can't be forward declared 272 | if ( 273 | isinstance(file_use_node[0], ast.Enum) or 274 | isinstance(file_use_node[0], ast.Typedef) 275 | ): 276 | file_uses[name] |= USES_DECLARATION 277 | else: 278 | file_uses[name] |= USES_REFERENCE 279 | 280 | def _add_use(name, namespace): 281 | if isinstance(name, list): 282 | # name contains a list of tokens. 
283 | name = '::'.join([n.name for n in name]) 284 | elif not isinstance(name, basestring): 285 | # Happens when variables are defined with inlined types, e.g.: 286 | # enum {...} variable; 287 | return 288 | try: 289 | file_use_node = symbol_table.lookup_symbol(name, namespace) 290 | except symbols.Error: 291 | return 292 | 293 | name = file_use_node[1].filename 294 | file_uses[name] = file_uses.get(name, 0) | USES_DECLARATION 295 | 296 | def _add_variable(node, namespace, reference=False): 297 | if node.reference or node.pointer or reference: 298 | _add_reference(node.name, namespace) 299 | else: 300 | _add_use(node.name, namespace) 301 | # This needs to recurse when the node is a templated type. 302 | _add_template_use(node.name, 303 | node.templated_types, 304 | namespace, 305 | reference) 306 | 307 | def _process_function(function, namespace): 308 | reference = function.body is None 309 | if function.return_type: 310 | return_type = function.return_type 311 | _add_variable(return_type, namespace, reference) 312 | 313 | for s in function.specializations: 314 | _add_variable(s, namespace, not function.body) 315 | 316 | templated_types = function.templated_types or () 317 | for p in function.parameters: 318 | node = p.type 319 | if node.name not in templated_types: 320 | if function.body and p.name: 321 | # Assume that if the function has a body and a name 322 | # the parameter type is really used. 323 | # NOTE(nnorwitz): this is over-aggressive. It would be 324 | # better to iterate through the body and determine 325 | # actual uses based on local vars and data members 326 | # used. 
327 | _add_use(node.name, namespace) 328 | elif ( 329 | p.default and 330 | p.default[0].name != '0' and 331 | p.default[0].name != 'NULL' and 332 | p.default[0].name != 'nullptr' 333 | ): 334 | _add_use(node.name, namespace) 335 | elif node.reference or node.pointer or reference: 336 | _add_reference(node.name, namespace) 337 | else: 338 | _add_use(node.name, namespace) 339 | _add_template_use(node.name, 340 | node.templated_types, 341 | namespace, 342 | reference) 343 | 344 | def _process_function_body(function, namespace): 345 | previous = None 346 | save = namespace[:] 347 | for t in function.body: 348 | if t.token_type == tokenize.NAME: 349 | previous = t 350 | if not keywords.is_keyword(t.name): 351 | # TODO(nnorwitz): handle static function calls. 352 | # TODO(nnorwitz): handle using statements in file. 353 | # TODO(nnorwitz): handle using statements in function. 354 | # TODO(nnorwitz): handle namespace assignment in file. 355 | _add_use(t.name, namespace) 356 | elif t.name == '::' and previous is not None: 357 | namespace.append(previous.name) 358 | elif t.name in (':', ';'): 359 | namespace = save[:] 360 | 361 | def _add_template_use(name, types, namespace, reference=False): 362 | for cls in types or (): 363 | if cls.pointer or cls.reference or reference: 364 | _add_reference(cls.name, namespace) 365 | elif name.endswith('_ptr'): 366 | # Special case templated classes that end w/_ptr. 367 | # These are things like auto_ptr which do 368 | # not require the class definition, only decl. 369 | _add_reference(cls.name, namespace) 370 | else: 371 | _add_use(cls.name, namespace) 372 | _add_template_use(cls.name, cls.templated_types, 373 | namespace, reference) 374 | 375 | def _process_types(nodes, namespace): 376 | for node in nodes: 377 | if isinstance(node, ast.Type): 378 | _add_variable(node, namespace) 379 | 380 | # Iterate through the source AST/tokens, marking each symbol's use.
381 | ast_seq = [self.ast_list] 382 | namespace_stack = [] 383 | while ast_seq: 384 | for node in ast_seq.pop(): 385 | if isinstance(node, ast.VariableDeclaration): 386 | namespace = namespace_stack + node.namespace 387 | _add_variable(node.type, namespace) 388 | elif isinstance(node, ast.Function): 389 | namespace = namespace_stack + node.namespace 390 | _process_function(node, namespace) 391 | if node.body: 392 | _process_function_body(node, namespace) 393 | elif isinstance(node, ast.Typedef): 394 | namespace = namespace_stack + node.namespace 395 | _process_types(node.alias, namespace) 396 | elif isinstance(node, ast.Friend): 397 | expr = node.expr 398 | namespace = namespace_stack + node.namespace 399 | if isinstance(expr, ast.Type): 400 | _add_reference(expr.name, namespace) 401 | elif isinstance(expr, ast.Function): 402 | _process_function(expr, namespace) 403 | elif isinstance(node, ast.Union) and node.body is not None: 404 | ast_seq.append(node.body) 405 | elif isinstance(node, ast.Class) and node.body is not None: 406 | _add_declaration(node.name, node.namespace) 407 | namespace = namespace_stack + node.namespace 408 | _add_template_use('', node.bases, namespace) 409 | ast_seq.append(node.body) 410 | elif isinstance(node, ast.Using): 411 | if node.names[0].name == 'namespace': 412 | namespace_stack.append(node.names[1].name) 413 | 414 | return file_uses, decl_uses 415 | 416 | def _find_unused_warnings(self, included_files, forward_declarations, 417 | primary_header=None): 418 | file_uses, decl_uses = self._determine_uses(included_files, 419 | forward_declarations) 420 | if primary_header and primary_header.filename in file_uses: 421 | file_uses[primary_header.filename] |= USES_DECLARATION 422 | self._verify_include_files_used(file_uses, included_files) 423 | self._verify_forward_declarations_used(forward_declarations, decl_uses, 424 | file_uses) 425 | for node in forward_declarations.values(): 426 | try: 427 | file_use_node = 
self.symbol_table.lookup_symbol(node.name, 428 | node.namespace) 429 | except symbols.Error: 430 | continue 431 | name = file_use_node[1].filename 432 | if ( 433 | file_use_node[1].ast_list is not None and 434 | name in file_uses and 435 | file_uses[name] & USES_DECLARATION 436 | ): 437 | msg = ("'{}' forward declared, " 438 | "but already #included in '{}'".format(node.name, name)) 439 | self._add_warning(msg, node) 440 | 441 | def _find_incorrect_case(self, included_files): 442 | for (filename, node_and_module) in included_files.items(): 443 | base_name = os.path.basename(filename) 444 | try: 445 | candidates = os.listdir(os.path.dirname(filename)) 446 | except OSError: 447 | continue 448 | 449 | correct_filename = get_correct_include_filename(base_name, 450 | candidates) 451 | if correct_filename: 452 | self._add_warning( 453 | "'{}' should be '{}'".format(base_name, correct_filename), 454 | node_and_module[0]) 455 | 456 | def _find_header_warnings(self): 457 | included_files, forward_declarations = self._read_and_parse_includes() 458 | self._find_unused_warnings(included_files, forward_declarations) 459 | self._find_incorrect_case(included_files) 460 | 461 | def _find_public_function_warnings(self, node, name, primary_header, 462 | all_headers): 463 | # Not found in the primary header, search all other headers. 464 | for _, header in all_headers.values(): 465 | if name in header.public_symbols: 466 | # If the primary.filename == header.filename, it probably 467 | # indicates an error elsewhere. It sucks to mask it, 468 | # but false positives are worse. 
469 | if primary_header: 470 | msg = ("expected to find '{}' in '{}', " 471 | "but found in '{}'".format(name, 472 | primary_header.filename, 473 | header.filename)) 474 | self._add_warning(msg, node) 475 | break 476 | else: 477 | where = 'in any directly #included header' 478 | if primary_header: 479 | where = ( 480 | "in expected header '{}' " 481 | 'or any other directly #included header'.format( 482 | primary_header.filename)) 483 | 484 | if name != 'main' and name != name.upper(): 485 | self._add_warning("'{}' not found {}".format(name, where), 486 | node) 487 | 488 | def _check_public_functions(self, primary_header, all_headers): 489 | """Verify all the public functions are also declared in a header 490 | file.""" 491 | public_symbols = {} 492 | declared_only_symbols = {} 493 | if primary_header: 494 | for name, symbol in primary_header.public_symbols.items(): 495 | if isinstance(symbol, ast.Function): 496 | public_symbols[name] = symbol 497 | declared_only_symbols = dict.fromkeys(public_symbols, True) 498 | 499 | for node in self.ast_list: 500 | # Make sure we have a function that should be exported. 501 | if not isinstance(node, ast.Function): 502 | continue 503 | if isinstance(node, ast.Method): 504 | # Ensure that for Foo::Bar, Foo is *not* a namespace. 505 | # If Foo is a namespace, we have a function and not a method. 506 | names = [n.name for n in node.in_class] 507 | if names != self.symbol_table.get_namespace(names): 508 | continue 509 | if not (node.is_definition() and node.is_exportable()): 510 | continue 511 | 512 | # This function should be declared in a header file. 
513 | name = node.name 514 | if name in public_symbols: 515 | declared_only_symbols[name] = False 516 | else: 517 | self._find_public_function_warnings(node, 518 | name, 519 | primary_header, 520 | all_headers) 521 | 522 | for name, declared_only in declared_only_symbols.items(): 523 | if declared_only: 524 | node = public_symbols[name] 525 | if node.templated_types is None: 526 | msg = "'{}' declared but not defined".format(name) 527 | self._add_warning(msg, node, primary_header.filename) 528 | 529 | def _get_primary_header(self, included_files): 530 | basename = os.path.basename(os.path.splitext(self.filename)[0]) 531 | include_paths = [os.path.dirname(self.filename)] + self.include_paths 532 | source, filename = headers.read_source(basename + '.h', include_paths) 533 | primary_header = included_files.get(filename) 534 | if primary_header: 535 | return primary_header[1] 536 | if source is not None: 537 | msg = "should #include header file '{}'".format(filename) 538 | self.warnings.add((self.filename, 0, msg)) 539 | return None 540 | 541 | def _find_source_warnings(self): 542 | included_files, forward_declarations = self._read_and_parse_includes() 543 | self._find_incorrect_case(included_files) 544 | 545 | for node in forward_declarations.values(): 546 | # TODO(nnorwitz): This really isn't a problem, but might 547 | # be something to warn against. I expect this will either 548 | # be configurable or removed in the future. But it's easy 549 | # to check for now. 550 | msg = ( 551 | "'{}' forward declaration not expected in source file".format( 552 | node.name)) 553 | self._add_warning(msg, node) 554 | 555 | # A primary header is optional. However, when looking up 556 | # defined methods in the source, always look in the 557 | # primary_header first. Expect that is the most likely location. 558 | # Use of primary_header is primarily an optimization. 
559 | primary_header = self._get_primary_header(included_files) 560 | 561 | self._check_public_functions(primary_header, included_files) 562 | if primary_header and primary_header.ast_list is not None: 563 | includes = [ 564 | node.filename 565 | for node in primary_header.ast_list 566 | if isinstance(node, ast.Include) 567 | ] 568 | for (node, _) in included_files.values(): 569 | if node.filename in includes: 570 | msg = "'{}' already #included in '{}'".format( 571 | node.filename, primary_header.filename) 572 | self._add_warning(msg, node) 573 | 574 | # TODO(nnorwitz): other warnings to add: 575 | # * unused forward decls for variables (globals)/classes 576 | # * Functions that are too large/complex 577 | # * Variables declared far from first use 578 | # * primitive member variables not initialized in ctor 579 | 580 | 581 | def get_line_number(metrics_instance, node): 582 | return metrics_instance.get_line_number(node.start) 583 | 584 | 585 | def get_correct_include_filename(filename, candidate_filenames): 586 | if filename not in candidate_filenames: 587 | for candidate in candidate_filenames: 588 | if filename.lower() == candidate.lower(): 589 | return candidate 590 | return None 591 | 592 | 593 | def run(filename, source, entire_ast, include_paths, quiet): 594 | hunter = WarningHunter(filename, source, entire_ast, 595 | include_paths=include_paths, 596 | quiet=quiet) 597 | hunter.find_warnings() 598 | hunter.show_warnings() 599 | return len(hunter.warnings) 600 | -------------------------------------------------------------------------------- /src/tools/cpp/headers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2007 Neal Norwitz 2 | # Portions Copyright 2007 Google Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Find and print the headers #include'd in a source file.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import print_function 20 | from __future__ import unicode_literals 21 | 22 | import os 23 | 24 | from . import utils 25 | 26 | 27 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 28 | 29 | 30 | def read_source(filename, include_paths): 31 | for path in include_paths: 32 | actual_filename = os.path.join(path, filename) 33 | source = utils.read_file(actual_filename, False) 34 | if source is not None: 35 | return source, actual_filename 36 | return None, filename 37 | -------------------------------------------------------------------------------- /src/tools/cpp/keywords.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2007 Neal Norwitz 3 | # Portions Copyright 2007 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """C++ keywords and helper utilities for determining keywords.""" 18 | 19 | from __future__ import unicode_literals 20 | 21 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 22 | 23 | 24 | TYPES = frozenset(['bool', 'char', 'int', 'long', 'short', 'double', 'float', 25 | 'void', 'wchar_t', 'unsigned', 'signed', 'size_t', 'auto', 26 | 'asm']) 27 | 28 | TYPE_MODIFIERS = frozenset(['register', 'const', 'constexpr', 'extern', 29 | 'static', 'volatile', 'mutable']) 30 | 31 | OTHER_MODIFIERS = frozenset(['class', 'struct', 'union', 'enum']) 32 | 33 | ACCESS = frozenset(['public', 'protected', 'private', 'friend']) 34 | 35 | CASTS = frozenset(['static_cast', 'const_cast', 'dynamic_cast', 36 | 'reinterpret_cast']) 37 | 38 | OTHERS = frozenset(['true', 'false', 'asm', 'namespace', 'using', 39 | 'explicit', 'this', 'operator', 'sizeof', 40 | 'new', 'delete', 'typedef', 'typeid', 41 | 'typename', 'template', 'virtual', 'inline']) 42 | 43 | CONTROL = frozenset(['case', 'switch', 'default', 'if', 'else', 'return', 44 | 'goto']) 45 | 46 | EXCEPTION = frozenset(['try', 'catch', 'throw']) 47 | 48 | LOOP = frozenset(['while', 'do', 'for', 'break', 'continue']) 49 | 50 | ALL = (TYPES | TYPE_MODIFIERS | OTHER_MODIFIERS | ACCESS | CASTS | 51 | OTHERS | CONTROL | EXCEPTION | LOOP) 52 | 53 | 54 | def is_keyword(token): 55 | return token in ALL 56 | 57 | 58 | def is_builtin_type(token): 59 | return token in TYPES or token in TYPE_MODIFIERS 60 | 61 | 62 | def is_builtin_modifiers(token): 63 | return token in TYPE_MODIFIERS or token in OTHER_MODIFIERS 64 | 65 | 66 | def is_builtin_other_modifiers(token): 67 | return token in OTHER_MODIFIERS 68 | -------------------------------------------------------------------------------- /src/tools/cpp/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2007 Neal Norwitz 2 | # Portions Copyright 2007 Google Inc. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Calculate metrics for C++ code.""" 17 | 18 | from __future__ import unicode_literals 19 | 20 | 21 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 22 | 23 | 24 | class Metrics(object): 25 | 26 | """Calculate various metrics on C++ source code.""" 27 | 28 | def __init__(self, source): 29 | self.source = source 30 | 31 | def get_line_number(self, index): 32 | """Return the line number in the source based on the index.""" 33 | return 1 + self.source.count('\n', 0, index) 34 | -------------------------------------------------------------------------------- /src/tools/cpp/nonvirtual_dtors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2008 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Print classes which have a virtual method and non-virtual destructor.""" 16 | 17 | from __future__ import print_function 18 | from __future__ import unicode_literals 19 | 20 | from . import ast 21 | from . import metrics 22 | 23 | 24 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 25 | 26 | 27 | def _find_warnings(filename, source, ast_list): 28 | count = 0 29 | for node in ast_list: 30 | if isinstance(node, ast.Class) and node.body: 31 | class_node = node 32 | has_virtuals = False 33 | for node in node.body: 34 | if isinstance(node, ast.Class) and node.body: 35 | _find_warnings(filename, source, [node]) 36 | elif (isinstance(node, ast.Function) and 37 | node.modifiers & ast.FUNCTION_VIRTUAL): 38 | has_virtuals = True 39 | if node.modifiers & ast.FUNCTION_DTOR: 40 | break 41 | else: 42 | if has_virtuals and not class_node.bases: 43 | lines = metrics.Metrics(source) 44 | print( 45 | '%s:%d' % ( 46 | filename, 47 | lines.get_line_number( 48 | class_node.start)), 49 | end=' ') 50 | print("'{}' has virtual methods without a virtual " 51 | 'dtor'.format(class_node.name)) 52 | count += 1 53 | 54 | return count 55 | 56 | 57 | def run(filename, source, entire_ast, include_paths, quiet): 58 | return _find_warnings(filename, source, entire_ast) 59 | -------------------------------------------------------------------------------- /src/tools/cpp/static_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2008 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Print classes, functions and modules which contain static data.""" 16 | 17 | from __future__ import print_function 18 | from __future__ import unicode_literals 19 | 20 | import collections 21 | 22 | from . import ast 23 | from . import metrics 24 | 25 | 26 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 27 | 28 | 29 | def _find_warnings(filename, lines, ast_list, static_is_optional): 30 | def print_warning(node): 31 | for name in node.name.split(','): 32 | print("{}:{}: static data '{}'".format( 33 | filename, 34 | lines.get_line_number(node.start), 35 | name)) 36 | 37 | def find_static(function_node): 38 | tokens = [] 39 | static_found = False 40 | for node in function_node.body: 41 | if node.name == 'static': 42 | static_found = True 43 | 44 | if static_found: 45 | tokens.append(node) 46 | if node.name == ';': 47 | body = list( 48 | ast.ASTBuilder(iter(tokens), filename).generate()) 49 | _find_warnings(filename, lines, body, False) 50 | tokens = [] 51 | static_found = False 52 | 53 | count = 0 54 | for node in ast_list: 55 | if isinstance(node, ast.VariableDeclaration): 56 | # Ignore 'static' at module scope so we can find globals too. 
57 | is_static = 'static' in node.type.modifiers 58 | is_not_const = ( 59 | 'const' not in node.type.modifiers and 60 | 'constexpr' not in node.type.modifiers 61 | ) 62 | 63 | if is_not_const and (static_is_optional or is_static): 64 | print_warning(node) 65 | count += 1 66 | elif isinstance(node, ast.Function) and node.body: 67 | find_static(node) 68 | elif isinstance(node, ast.Class) and node.body: 69 | _find_warnings(filename, lines, node.body, False) 70 | 71 | return count 72 | 73 | 74 | def _get_static_declarations(ast_list): 75 | for node in ast_list: 76 | if (isinstance(node, ast.VariableDeclaration) and 77 | 'static' in node.type.modifiers): 78 | for name in node.name.split(','): 79 | yield (name, node) 80 | 81 | 82 | def _find_unused_static_warnings(filename, lines, ast_list): 83 | """Warn about unused static variables.""" 84 | static_declarations = dict(_get_static_declarations(ast_list)) 85 | 86 | def find_variables_use(body): 87 | for child in body: 88 | if child.name in static_declarations: 89 | static_use_counts[child.name] += 1 90 | 91 | static_use_counts = collections.Counter() 92 | for node in ast_list: 93 | if isinstance(node, ast.Function) and node.body: 94 | find_variables_use(node.body) 95 | elif isinstance(node, ast.Class) and node.body: 96 | for child in node.body: 97 | if isinstance(child, ast.Function) and child.body: 98 | find_variables_use(child.body) 99 | 100 | count = 0 101 | for (name, _) in sorted(static_declarations.items(), 102 | key=lambda x: x[1].start): 103 | if not static_use_counts[name]: 104 | print("{}:{}: unused variable '{}'".format( 105 | filename, 106 | lines.get_line_number(static_declarations[name].start), 107 | name)) 108 | count += 1 109 | 110 | return count 111 | 112 | 113 | def run(filename, source, entire_ast, include_paths, quiet): 114 | lines = metrics.Metrics(source) 115 | 116 | return ( 117 | _find_warnings(filename, lines, entire_ast, True) + 118 | _find_unused_static_warnings(filename, lines, entire_ast) 119 
| ) 120 | -------------------------------------------------------------------------------- /src/tools/cpp/symbols.py: -------------------------------------------------------------------------------- 1 | # Copyright 2007 Neal Norwitz 2 | # Portions Copyright 2007 Google Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Symbol Table utility code.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import unicode_literals 20 | 21 | 22 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 23 | 24 | 25 | class Error(BaseException): 26 | 27 | """Exception raised when lookup fails.""" 28 | 29 | 30 | class Symbol(object): 31 | 32 | """Data container used internally.""" 33 | 34 | def __init__(self, name, parts, namespace_stack): 35 | self.name = name 36 | self.parts = parts 37 | self.namespace_stack = namespace_stack 38 | 39 | 40 | class SymbolTable(object): 41 | 42 | """Symbol table that can perform namespace operations.""" 43 | 44 | def __init__(self): 45 | # None is the global namespace. 46 | self.namespaces = {None: {}} 47 | 48 | def _lookup_namespace(self, symbol, namespace): 49 | """Helper for lookup_symbol that only looks up variables in a 50 | namespace. 
51 | 52 | Args: 53 | symbol: Symbol 54 | namespace: pointer into self.namespaces 55 | 56 | """ 57 | for namespace_part in symbol.parts: 58 | namespace = namespace.get(namespace_part) 59 | if namespace is None: 60 | break 61 | if not isinstance(namespace, dict): 62 | return namespace 63 | raise Error('%s not found' % symbol.name) 64 | 65 | def _lookup_global(self, symbol): 66 | """Helper for lookup_symbol that only looks up global variables. 67 | 68 | Args: 69 | symbol: Symbol 70 | 71 | """ 72 | assert symbol.parts 73 | namespace = self.namespaces 74 | if len(symbol.parts) == 1: 75 | # If there is only one part, look in globals. 76 | namespace = self.namespaces[None] 77 | try: 78 | # Try to do a normal, global namespace lookup. 79 | return self._lookup_namespace(symbol, namespace) 80 | except Error as orig_exc: 81 | try: 82 | # The normal lookup can fail if all of the parts aren't 83 | # namespaces. This happens with OuterClass::Inner. 84 | namespace = self.namespaces[None] 85 | return self._lookup_namespace(symbol, namespace) 86 | except Error: 87 | raise orig_exc 88 | 89 | def _lookup_in_all_namespaces(self, symbol): 90 | """Helper for lookup_symbol that looks for symbols in all namespaces. 91 | 92 | Args: 93 | symbol: Symbol 94 | 95 | """ 96 | namespace = self.namespaces 97 | # Create a stack of namespaces. 98 | namespace_stack = [] 99 | for current in symbol.namespace_stack: 100 | namespace = namespace.get(current) 101 | if namespace is None or not isinstance(namespace, dict): 102 | break 103 | namespace_stack.append(namespace) 104 | 105 | # Iterate through the stack in reverse order. Need to go from 106 | # innermost namespace to outermost. 107 | for namespace in reversed(namespace_stack): 108 | try: 109 | return self._lookup_namespace(symbol, namespace) 110 | except Error: 111 | pass 112 | return None 113 | 114 | def lookup_symbol(self, name, namespace_stack): 115 | """Returns AST node and module for symbol if found. 
116 | 117 | Args: 118 | name: 'name of the symbol to lookup' 119 | namespace_stack: None or ['namespaces', 'in', 'current', 'scope'] 120 | 121 | Returns: 122 | (ast.Node, module (ie, any object stored with symbol)) if found 123 | 124 | Raises: 125 | Error if the symbol cannot be found. 126 | 127 | """ 128 | # TODO(nnorwitz): a convenient API for this depends on the 129 | # representation of the name. e.g., does symbol_name contain 130 | # ::, is symbol_name a list of colon separated names, how are 131 | # names prefixed with :: handled. These have different lookup 132 | # semantics (if leading ::) or change the desirable API. 133 | 134 | # For now assume that the symbol_name contains :: and parse it. 135 | symbol = Symbol(name, name.split('::'), namespace_stack) 136 | assert symbol.parts 137 | if symbol.parts[0] == '': 138 | # Handle absolute (global) ::symbol_names. 139 | symbol.parts = symbol.parts[1:] 140 | elif namespace_stack is not None: 141 | result = self._lookup_in_all_namespaces(symbol) 142 | if result: 143 | return result 144 | 145 | return self._lookup_global(symbol) 146 | 147 | def _add(self, symbol_name, namespace, node, module): 148 | """Helper function for adding symbols. 149 | 150 | See add_symbol(). 151 | 152 | """ 153 | result = symbol_name in namespace 154 | namespace[symbol_name] = node, module 155 | return not result 156 | 157 | def add_symbol(self, symbol_name, namespace_stack, node, module): 158 | """Adds symbol_name defined in namespace_stack to the symbol table. 159 | 160 | Args: 161 | symbol_name: 'name of the symbol to lookup' 162 | namespace_stack: None or ['namespaces', 'symbol', 'defined', 'in'] 163 | node: ast.Node that defines this symbol 164 | module: module (any object) this symbol is defined in 165 | 166 | Returns: 167 | bool(if symbol was *not* already present) 168 | 169 | """ 170 | # TODO(nnorwitz): verify symbol_name doesn't contain :: ? 171 | if namespace_stack: 172 | # Handle non-global symbols (ie, in some namespace). 
173 | last_namespace = self.namespaces 174 | for namespace in namespace_stack: 175 | last_namespace = last_namespace.setdefault(namespace, {}) 176 | else: 177 | last_namespace = self.namespaces[None] 178 | return self._add(symbol_name, last_namespace, node, module) 179 | 180 | def get_namespace(self, name_seq): 181 | """Returns the prefix of names from name_seq that are known namespaces. 182 | 183 | Args: 184 | name_seq: ['names', 'of', 'possible', 'namespace', 'to', 'find'] 185 | 186 | Returns: 187 | ['names', 'that', 'are', 'namespaces', 'possibly', 'empty', 'list'] 188 | 189 | """ 190 | namespaces = self.namespaces 191 | result = [] 192 | for name in name_seq: 193 | namespaces = namespaces.get(name) 194 | if not namespaces: 195 | break 196 | result.append(name) 197 | return result 198 | -------------------------------------------------------------------------------- /src/tools/cpp/tokenize.py: -------------------------------------------------------------------------------- 1 | # Copyright 2007 Neal Norwitz 2 | # Portions Copyright 2007 Google Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tokenize C++ source code.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import print_function 20 | from __future__ import unicode_literals 21 | 22 | 23 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 24 | 25 | 26 | # Add $ as a valid identifier char since so much code uses it. 
27 | _letters = 'abcdefghijklmnopqrstuvwxyz' 28 | VALID_IDENTIFIER_CHARS = frozenset(_letters + 29 | _letters.upper() + 30 | '_0123456789$') 31 | HEX_DIGITS = frozenset('0123456789abcdefABCDEF') 32 | INT_OR_FLOAT_DIGITS = frozenset('01234567890eE-+') 33 | 34 | 35 | # C++0x string prefixes. 36 | _STR_PREFIXES = frozenset(('R', 'u8', 'u8R', 'u', 'uR', 'U', 'UR', 'L', 'LR')) 37 | 38 | 39 | # Token types. 40 | UNKNOWN = 'UNKNOWN' 41 | SYNTAX = 'SYNTAX' 42 | CONSTANT = 'CONSTANT' 43 | NAME = 'NAME' 44 | PREPROCESSOR = 'PREPROCESSOR' 45 | 46 | 47 | class TokenError(Exception): 48 | 49 | """Raised when tokenization fails.""" 50 | 51 | 52 | class Token(object): 53 | 54 | """Data container to represent a C++ token. 55 | 56 | Tokens can be identifiers, syntax char(s), constants, or 57 | pre-processor directives. 58 | 59 | start contains the index of the first char of the token in the source 60 | end contains the index of the last char of the token in the source 61 | 62 | """ 63 | 64 | def __init__(self, token_type, name, start, end): 65 | self.token_type = token_type 66 | self.name = name 67 | self.start = start 68 | self.end = end 69 | 70 | def __str__(self): 71 | return 'Token(%r, %s, %s)' % (self.name, self.start, self.end) 72 | 73 | __repr__ = __str__ 74 | 75 | 76 | def _get_string(source, i): 77 | i = source.find('"', i + 1) 78 | while source[i - 1] == '\\': 79 | # Count the trailing backslashes. 80 | backslash_count = 1 81 | j = i - 2 82 | while source[j] == '\\': 83 | backslash_count += 1 84 | j -= 1 85 | # When trailing backslashes are even, they escape each other. 86 | if (backslash_count % 2) == 0: 87 | break 88 | i = source.find('"', i + 1) 89 | return i + 1 90 | 91 | 92 | def _get_char(source, start, i): 93 | # NOTE(nnorwitz): may not be quite correct, should be good enough. 94 | i = source.find("'", i + 1) 95 | while i != -1 and source[i - 1] == '\\': 96 | # Need to special case '\\'. 
97 | if source[i - 2] == '\\': 98 | break 99 | i = source.find("'", i + 1) 100 | # Try to handle unterminated single quotes. 101 | return i + 1 if i != -1 else start + 1 102 | 103 | 104 | def get_tokens(source): 105 | """Returns a sequence of Tokens. 106 | 107 | Args: 108 | source: string of C++ source code. 109 | 110 | Yields: 111 | Token that represents the next token in the source. 112 | 113 | """ 114 | if not source.endswith('\n'): 115 | source += '\n' 116 | 117 | # Cache various valid character sets for speed. 118 | valid_identifier_chars = VALID_IDENTIFIER_CHARS 119 | hex_digits = HEX_DIGITS 120 | int_or_float_digits = INT_OR_FLOAT_DIGITS 121 | int_or_float_digits2 = int_or_float_digits | set('.') 122 | 123 | # Ignore tokens while in a #if 0 block. 124 | count_ifs = 0 125 | 126 | i = 0 127 | end = len(source) 128 | while i < end: 129 | # Skip whitespace. 130 | while i < end and source[i].isspace(): 131 | i += 1 132 | if i >= end: 133 | return 134 | 135 | token_type = UNKNOWN 136 | start = i 137 | c = source[i] 138 | if c.isalpha() or c == '_': # Find a string token. 139 | token_type = NAME 140 | while source[i] in valid_identifier_chars: 141 | i += 1 142 | # String and character constants can look like a name if 143 | # they are something like L"". 144 | if source[i] == "'" and source[start:i] in _STR_PREFIXES: 145 | token_type = CONSTANT 146 | i = _get_char(source, start, i) 147 | elif source[i] == '"' and source[start:i] in _STR_PREFIXES: 148 | token_type = CONSTANT 149 | i = _get_string(source, i) 150 | elif c == '/' and source[i + 1] == '/': # Find // comments. 151 | i = _find(source, '\n', i) 152 | continue 153 | elif c == '/' and source[i + 1] == '*': # Find /* comments. */ 154 | i = _find(source, '*/', i) + 2 155 | continue 156 | elif c in '<>': # Handle '<' and '>' tokens. 
157 | token_type = SYNTAX 158 | i += 1 159 | new_ch = source[i] 160 | # Do not merge '>>' or '>>=' into a single token 161 | if new_ch == c and c != '>': 162 | i += 1 163 | new_ch = source[i] 164 | if new_ch == '=': 165 | i += 1 166 | elif c in ':+-&|=': # Handle 'XX' and 'X=' tokens. 167 | token_type = SYNTAX 168 | i += 1 169 | new_ch = source[i] 170 | if new_ch == c: 171 | i += 1 172 | elif c == '-' and new_ch == '>': 173 | i += 1 174 | elif new_ch == '=': 175 | i += 1 176 | elif c in '!*^%/': # Handle 'X=' tokens. 177 | token_type = SYNTAX 178 | i += 1 179 | new_ch = source[i] 180 | if new_ch == '=': 181 | i += 1 182 | elif c in '()[]{}~?;.,': # Handle single char tokens. 183 | token_type = SYNTAX 184 | i += 1 185 | if c == '.' and source[i].isdigit(): 186 | token_type = CONSTANT 187 | i += 1 188 | while source[i] in int_or_float_digits: 189 | i += 1 190 | # Handle float suffixes. 191 | for suffix in ('l', 'f'): 192 | if suffix == source[i:i + 1].lower(): 193 | i += 1 194 | break 195 | elif c.isdigit(): # Find integer. 196 | token_type = CONSTANT 197 | if c == '0' and source[i + 1] in 'xX': 198 | # Handle hex digits. 199 | i += 2 200 | while source[i] in hex_digits: 201 | i += 1 202 | else: 203 | while source[i] in int_or_float_digits2: 204 | i += 1 205 | # Handle integer (and float) suffixes. 206 | if source[i].isalpha(): 207 | for suffix in ('ull', 'll', 'ul', 'l', 'f', 'u'): 208 | size = len(suffix) 209 | if suffix == source[i:i + size].lower(): 210 | i += size 211 | break 212 | elif c == '"': # Find string. 213 | token_type = CONSTANT 214 | i = _get_string(source, i) 215 | elif c == "'": # Find char. 216 | token_type = CONSTANT 217 | i = _get_char(source, start, i) 218 | elif c == '#': # Find pre-processor command. 
219 | token_type = PREPROCESSOR 220 | got_if = source[i:i + 3] == '#if' 221 | if count_ifs and source[i:i + 6] == '#endif': 222 | count_ifs -= 1 223 | if count_ifs == 0: 224 | source = source[:i].ljust(i + 6) + source[i + 6:] 225 | continue 226 | 227 | # Handle preprocessor statements (\ continuations). 228 | while True: 229 | i1 = source.find('\n', i) 230 | i2 = source.find('//', i) 231 | i3 = source.find('/*', i) 232 | i4 = source.find('"', i) 233 | # Get the first important symbol (newline, comment, EOF/end). 234 | i = min([x for x in (i1, i2, i3, i4, end) if x != -1]) 235 | 236 | # Handle comments in #define macros. 237 | if i == i3: 238 | i = _find(source, '*/', i) + 2 239 | source = source[:i3].ljust(i) + source[i:] 240 | continue 241 | 242 | # Handle #include "dir//foo.h" properly. 243 | if source[i] == '"': 244 | i = _find(source, '"', i + 1) + 1 245 | continue 246 | 247 | # Keep going if end of the line and the line ends with \. 248 | if i == i1 and source[i - 1] == '\\': 249 | i += 1 250 | continue 251 | 252 | if got_if: 253 | begin = source.find('(', start, i) 254 | if begin == -1: 255 | begin = source.find(' ', start) 256 | begin = begin + 1 257 | s1 = source.find(' ', begin) 258 | s2 = source.find(')', begin) 259 | s3 = source.find('\n', begin) 260 | s = min([x for x in (s1, s2, s3, end) if x != -1]) 261 | 262 | condition = source[begin:s] 263 | if ( 264 | count_ifs or 265 | condition == '0' or 266 | condition == '__OBJC__' 267 | ): 268 | count_ifs += 1 269 | break 270 | elif c == '\\': # Handle \ in code. 271 | # This is different from the pre-processor \ handling. 272 | i += 1 273 | continue 274 | elif count_ifs: 275 | # Ignore bogus code when we are inside an #if block. 
276 | i += 1 277 | continue 278 | else: 279 | raise TokenError("unexpected token '{0}'".format(c)) 280 | 281 | if count_ifs: 282 | continue 283 | 284 | assert i > 0 285 | yield Token(token_type, source[start:i], start, i) 286 | 287 | 288 | def _find(string, sub_string, start_index): 289 | """Return index of sub_string in string. 290 | 291 | Raise TokenError if sub_string is not found. 292 | 293 | """ 294 | result = string.find(sub_string, start_index) 295 | if result == -1: 296 | raise TokenError("expected '{0}'".format(sub_string)) 297 | return result 298 | -------------------------------------------------------------------------------- /src/tools/cpp/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2007 Neal Norwitz 2 | # Portions Copyright 2007 Google Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Generic utilities for C++ parsing.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import print_function 20 | from __future__ import unicode_literals 21 | 22 | import io 23 | import sys 24 | 25 | 26 | __author__ = 'nnorwitz@google.com (Neal Norwitz)' 27 | 28 | 29 | def read_file(filename, print_error=True): 30 | """Returns the contents of a file.""" 31 | try: 32 | for encoding in ['utf-8', 'latin1']: 33 | try: 34 | with io.open(filename, encoding=encoding) as fp: 35 | return fp.read() 36 | except UnicodeDecodeError: 37 | pass 38 | except IOError as exception: 39 | if print_error: 40 | print(exception, file=sys.stderr) 41 | return None 42 | -------------------------------------------------------------------------------- /src/tools/fffmock.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import os 5 | 6 | from cpp import ast 7 | from cpp import utils 8 | 9 | 10 | FILE_H_HEADER=""" 11 | /** 12 | This is an automatically generated mock file (from {source_name}). If you wish to edit it, move it to the mocks_man directory. 13 | */ 14 | #ifndef {define_name} 15 | #define {define_name} 16 | 17 | #include "fff.h" 18 | #include "{source_include}" 19 | 20 | """ 21 | 22 | 23 | FILE_H_FOOTER=""" 24 | #endif // {define_name} 25 | """ 26 | 27 | FILE_C_HEADER=""" 28 | /** 29 | This is an automatically generated mock file; see the header file for more details.
30 | */ 31 | 32 | 33 | #include "{header_file}" 34 | """ 35 | 36 | FILE_C_FOOTER="" 37 | 38 | 39 | 40 | class Mock_Generator: 41 | 42 | def __init__( self, fn_source, fn_target_h, fn_target_c ): 43 | self.fn_source = fn_source 44 | self.fn_target_h = fn_target_h 45 | self.fn_target_c = fn_target_c 46 | 47 | self.fid_h = open( fn_target_h, 'w' ) 48 | self.fid_c = open( fn_target_c, 'w' ) 49 | 50 | self.print_formatted( self.fid_h, FILE_H_HEADER) 51 | self.print_formatted( self.fid_c, FILE_C_HEADER ) 52 | 53 | 54 | def print_formatted( self, fid, string ): 55 | 56 | string_vars = { 'header_file' : os.path.basename( self.fn_target_h ), 57 | 'define_name' : "_AUTOMOCK_" + os.path.basename(self.fn_source).upper().replace(".","_"), 58 | 'source_name' : self.fn_source, 59 | 'source_include' : os.path.basename(self.fn_source) } 60 | 61 | fid.write( string.format( **string_vars ) ) 62 | 63 | def save(self): 64 | self.print_formatted( self.fid_h, FILE_H_FOOTER ) 65 | self.print_formatted( self.fid_c, FILE_C_FOOTER ) 66 | 67 | self.fid_h.close() 68 | self.fid_c.close() 69 | 70 | 71 | def _make_fff_fun( self, node_name, node, fun_par = None ): 72 | fun_ret = process_type( node.return_type ) 73 | 74 | if fun_par == None: 75 | fun_par = process_params( node ) 76 | 77 | fake_params = [ ] 78 | 79 | if len(fun_par) > 0 and fun_par[-1][0] == "...": 80 | base_end="_VARARG" 81 | else: 82 | base_end="" 83 | 84 | if fun_ret == "void": 85 | base_type = "VOID" 86 | else: 87 | base_type = "VALUE" 88 | fake_params.append( fun_ret ) 89 | 90 | fake_params.append( node_name ) 91 | for item in fun_par: 92 | fake_params.append( item[0] ) 93 | 94 | fake_params = ", ".join(fake_params) 95 | base = "FAKE_%s_FUNC%s( %s );" % ( base_type, base_end, fake_params ) 96 | self.fid_h.write( "DECLARE_" + base + "\n" ) 97 | self.fid_c.write( "DEFINE_" + base + "\n" ) 98 | 99 | def _make_class_fun( self, class_name, node ): 100 | fun_name = class_name + "__" + node.name 101 | 102 | fun_cls_name = node.name 
103 |         fun_cls_mod = ""
104 |         if node.modifiers == ast.FUNCTION_SPECIFIER: # FIXME: the cpp parser does not distinguish const, final, etc. modifiers; this will be a problem.
105 |             fun_cls_mod = "const"
106 |         elif node.modifiers == ast.FUNCTION_DTOR:
107 |             fun_name += "_DTOR"
108 |             fun_cls_name = "~" + fun_cls_name
109 |
110 |         fun_params = [ ("%s %s*" % (fun_cls_mod, class_name) ,"this"),] + process_params( node )
111 |
112 |         fun_fff_par_types = ", ".join( [ x[0] for x in fun_params ] )
113 |         fun_fff_par_names = ", ".join( [ x[1] for x in fun_params ] )
114 |         fun_fff_decl = "%s %s( %s )" % ( process_type( node.return_type ), fun_name , fun_fff_par_types )
115 |
116 |         fun_cls_ret = process_type( node.return_type )
117 |
118 |         if fun_cls_ret == "void":
119 |             if node.return_type is None:
120 |                 fun_cls_ret = ""
121 |             fun_cls_ret_str = ""
122 |         else:
123 |             fun_cls_ret_str = "return "
124 |
125 |         fun_cls_par_types = ", ".join( [ "%s %s" % (x[0],x[1]) for x in process_params( node ) ] )
126 |
127 |
128 |         self.fid_h.write( fun_fff_decl + ";\n" )
129 |
130 |         self._make_fff_fun( fun_name, node, fun_par = fun_params )
131 |         self.fid_c.write( "%s %s::%s(%s) %s " % ( fun_cls_ret, class_name, fun_cls_name, fun_cls_par_types, fun_cls_mod ))
132 |         self.fid_c.write( "{ %s%s( %s ); }\n" % ( fun_cls_ret_str, fun_name, fun_fff_par_names ) )
133 |
134 |
135 |     def run( self, ast_list ):
136 |         for node in ast_list:
137 |             if isinstance(node, ast.Function):
138 |                 self._make_fff_fun( node.name, node )
139 |             elif isinstance(node, ast.Class) and node.body:
140 |                 class_name = node.name
141 |                 for child_node in node.body:
142 |                     if isinstance(child_node, ast.Function):
143 |                         self._make_class_fun( class_name, child_node )
144 |
145 |
146 | def process_params( node ):
147 |     return [ ( process_type(par.type) , par.name ) for par in node.parameters ]
148 |
149 | def process_type( partype ):
150 |     if partype is None:
151 |         return "void"
152 |     ret = "%s %s%s" % (" ".join(partype.modifiers), partype.full_name(), "*" if partype.pointer else "", )
153 |     return ret.strip()
154 |
155 | def fun_full( prefix, node ):
156 |     params = process_params( node )
157 |     print( "%sF:%s:%s(%s)" % ( prefix, process_type(node.return_type) , node.name, params ) )
158 |
159 |
160 | def _GenerateMocks( ast_list, target_h, target_c ):
161 |     processed_class_names = set()
162 |     lines = []
163 |
164 |     for node in ast_list:
165 |         if isinstance(node, ast.Function):
166 |             fun_full("", node)
167 |         elif isinstance(node, ast.Class) and node.body:
168 |             class_name = node.name
169 |             parent_name = class_name
170 |
171 |             print( "C:%s:public:%s" % (class_name, parent_name ) )
172 |             for child_node in node.body:
173 |                 #print child_node
174 |                 if isinstance(child_node, ast.Function):
175 |                     fun_full(" ", child_node)
176 |     # import ipdb; ipdb.set_trace();
177 |     return lines
178 |
179 |
180 | def generate_mocks( filename_source, filename_h_target, filename_c_target ):
181 |     print( "Generating fff mocks from: " + filename_source + " -> " + filename_c_target )
182 |     try:
183 |         source = utils.read_file(filename_source)
184 |         if source is None:
185 |             return
186 |
187 |         builder = ast.builder_from_source( source, filename_source )
188 |         entire_ast = list([_f for _f in builder.generate() if _f])
189 |     except KeyboardInterrupt:
190 |         return
191 |
192 |     generator = Mock_Generator( filename_source, filename_h_target, filename_c_target )
193 |     generator.run( entire_ast )
194 |     generator.save()
195 |     #generator.run( entire_ast )
196 |
197 |
198 | if __name__ == "__main__":
199 |     generate_mocks( sys.argv[1], sys.argv[2], sys.argv[3],)
200 |
201 |
--------------------------------------------------------------------------------
/src/tools/run_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | for tbin in tests/build/bin/*test_*; do
6 |     echo "*** Running: $tbin ***"
7 |     ./$tbin
8 |     echo "*** DONE ***"
9 | done;
10 |
11 |
12 |
13 |
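fff can only fake free functions. To work within that limit, `fffmock.py` above emits one `FAKE_VOID_FUNC`/`FAKE_VALUE_FUNC` macro per free function. Each class method becomes a free function named `Class__method` that takes the `this` pointer as its first parameter, and the tool writes a method body that forwards to it. The snippet below is a minimal, parser-free sketch of that string building, not the tool itself. It assumes types arrive as plain strings rather than `cpp.ast` nodes, and it leaves destructors out. The helper names `fake_macro` and `class_method_shim` are hypothetical.

```python
# Hypothetical, parser-free sketch of the strings fffmock.py builds.
# Types are plain strings here; the real tool derives them from cpp.ast nodes.
from typing import List, Optional, Tuple

Param = Tuple[str, str]  # (C++ type, parameter name)


def fake_macro(name: str, ret: Optional[str], params: List[Param]) -> str:
    """Build the FAKE_*_FUNC* macro body, mirroring _make_fff_fun."""
    ret_type = ret if ret is not None else "void"  # process_type(None) gives "void"
    vararg = "_VARARG" if params and params[-1][0] == "..." else ""
    args: List[str] = []
    if ret_type == "void":
        kind = "VOID"
    else:
        kind = "VALUE"
        args.append(ret_type)  # VALUE fakes take the return type first
    args.append(name)
    args.extend(ptype for ptype, _ in params)  # fff needs only the types
    return "FAKE_%s_FUNC%s( %s );" % (kind, vararg, ", ".join(args))


def class_method_shim(cls: str, method: str, ret: Optional[str],
                      params: List[Param], const: bool = False
                      ) -> Tuple[List[str], List[str]]:
    """Return (header_lines, source_lines) for one method, like _make_class_fun.

    ret=None stands for a constructor (no return type at all).
    """
    fun_name = "%s__%s" % (cls, method)
    this_type = ("const %s*" % cls) if const else ("%s*" % cls)
    all_params = [(this_type, "this")] + params  # 'this' becomes the first argument
    ret_type = ret if ret is not None else "void"
    macro = fake_macro(fun_name, ret_type, all_params)

    header = [
        "%s %s( %s );" % (ret_type, fun_name, ", ".join(t for t, _ in all_params)),
        "DECLARE_" + macro,
    ]
    prefix = "" if ret is None else ret + " "  # constructors have no return type
    ret_kw = "" if ret_type == "void" else "return "
    qualifier = " const" if const else ""
    method_def = "%s%s::%s(%s)%s { %s%s( %s ); }" % (
        prefix, cls, method,
        ", ".join("%s %s" % p for p in params),  # original parameter list
        qualifier, ret_kw, fun_name,
        ", ".join(n for _, n in all_params),     # forward 'this' plus every argument
    )
    return header, ["DEFINE_" + macro, method_def]


if __name__ == "__main__":
    print(fake_macro("read_sensor", "int", [("int", "channel")]))
    header, source = class_method_shim("Led", "set", "void", [("int", "level")])
    for line in header + source:
        print(line)
```

Passing `this` as an ordinary argument means a test can inspect the captured `arg0_val` of the fake to see which object a method was called on.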
--------------------------------------------------------------------------------
/src/tools/run_coverage.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 |
6 | rm -rf tests/build/coverage/
7 | mkdir -p tests/build/coverage/
8 |
9 | ALL_INFO_FILES=""
10 |
11 | run_single()
12 | {
13 |     lcov --directory=tests/build/src/ --zerocounters
14 |     BIN_NAME=$1
15 |     echo "*** Running: $BIN_NAME"
16 |     $BIN_NAME
17 |     INFO_NAME="tests/build/coverage/$(basename $BIN_NAME).info.part"
18 |     lcov --directory=tests/build/src/ --capture --output-file=$INFO_NAME.extra
19 |     lcov --remove $INFO_NAME.extra "tests/fakes/*" --output-file=$INFO_NAME
20 |     ALL_INFO_FILES="$ALL_INFO_FILES -a $INFO_NAME"
21 | }
22 |
23 |
24 | for tbin in tests/build/bin/*test_*; do
25 |     run_single $tbin
26 | done;
27 |
28 |
29 |
30 | lcov $ALL_INFO_FILES --output-file=tests/build/coverage/full.info
31 | genhtml --output=tests/build/coverage/ tests/build/coverage/full.info
32 |
33 | echo "*********************************************"
34 | echo "DONE: check tests/build/coverage/index.html"
35 | echo "*********************************************"
36 |
--------------------------------------------------------------------------------
/website/screen_coverage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/susundberg/arduino-simple-unittest/ed72c17cd8eea76fb51ed22ceb8a5b7a29daa4a1/website/screen_coverage.png
--------------------------------------------------------------------------------
/website/screen_debug.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/susundberg/arduino-simple-unittest/ed72c17cd8eea76fb51ed22ceb8a5b7a29daa4a1/website/screen_debug.png
--------------------------------------------------------------------------------
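For each test binary, `run_coverage.sh` resets the coverage counters, runs the binary, captures a separate tracefile, and strips `tests/fakes/*` from it. Only at the end does it merge the per-binary tracefiles with `lcov -a` and render them with `genhtml`. The hypothetical dry-run helper `coverage_plan` below spells out that order by returning the command lines instead of executing them. It is a sketch, not part of the repository: the paths are copied from the script, and the initial `rm -rf`/`mkdir -p` step is omitted.

```python
# Hypothetical dry-run helper: lists the commands run_coverage.sh executes,
# in order, without running anything. Paths are copied from the script.
import os
from typing import List

SRC_DIR = "tests/build/src/"
COV_DIR = "tests/build/coverage/"


def coverage_plan(test_binaries: List[str]) -> List[List[str]]:
    commands: List[List[str]] = []
    merge_args: List[str] = []
    for binary in test_binaries:
        info = COV_DIR + os.path.basename(binary) + ".info.part"
        commands += [
            # reset counters so this tracefile covers one binary only
            ["lcov", "--directory=" + SRC_DIR, "--zerocounters"],
            [binary],
            ["lcov", "--directory=" + SRC_DIR, "--capture",
             "--output-file=" + info + ".extra"],
            # drop the fake Arduino sources from the report
            ["lcov", "--remove", info + ".extra", "tests/fakes/*",
             "--output-file=" + info],
        ]
        merge_args += ["-a", info]  # lcov -a is --add-tracefile
    commands.append(["lcov"] + merge_args + ["--output-file=" + COV_DIR + "full.info"])
    commands.append(["genhtml", "--output=" + COV_DIR, COV_DIR + "full.info"])
    return commands


if __name__ == "__main__":
    for cmd in coverage_plan(["tests/build/bin/test_led", "tests/build/bin/test_stimer"]):
        print(" ".join(cmd))
```

Zeroing the counters before each binary keeps the tracefiles independent. Without that step, each capture would also contain the counts from every earlier binary, and merging them would count those runs more than once.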