├── CMakeLists.txt ├── GUI ├── CMakeLists.txt ├── TargetRepo.cpp ├── TargetRepo.h ├── main.cpp ├── thumbnailctrl.cpp ├── thumbnailctrl.h ├── util.h ├── win.rc ├── wxplayer.cpp └── wxplayer.h ├── LICENSE ├── README.md ├── config.h.in ├── detection ├── CMakeLists.txt ├── include │ └── Detector.h └── src │ ├── Darknet.cpp │ ├── Darknet.h │ ├── Detector.cpp │ ├── darknet_parsing.cpp │ ├── darknet_parsing.h │ └── letterbox.h ├── models ├── yolov3-tiny.cfg └── yolov3.cfg ├── processing ├── CMakeLists.txt ├── TargetStorage.cpp ├── TargetStorage.h ├── main.cpp └── util.h ├── snapshots ├── UI-offline.png ├── UI-online.png ├── detection.png └── tracking.png └── tracking ├── CMakeLists.txt ├── include ├── DeepSORT.h ├── SORT.h └── Track.h └── src ├── DeepSORT.cpp ├── Extractor.cpp ├── Extractor.h ├── Hungarian.cpp ├── Hungarian.h ├── KalmanTracker.cpp ├── KalmanTracker.h ├── SORT.cpp ├── TrackerManager.cpp ├── TrackerManager.h ├── nn_matching.cpp └── nn_matching.h /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10 FATAL_ERROR) 2 | project(libtorch-yolov3-deepsot) 3 | 4 | set(CMAKE_CXX_STANDARD 17) 5 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 6 | 7 | # For profiler on Ubuntu 8 | if (CMAKE_BUILD_TYPE STREQUAL Debug AND NOT MSVC) 9 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0") 10 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0") 11 | endif () 12 | 13 | # Output directory structure 14 | set(OUTPUT_DIR "result") 15 | set(TARGETS_DIR_NAME "targets") 16 | set(TRAJ_TXT_NAME "trajectories.txt") 17 | set(SNAPSHOTS_DIR_NAME "snapshots") 18 | set(VIDEO_NAME "compressed.flv") 19 | configure_file( 20 | "${PROJECT_SOURCE_DIR}/config.h.in" 21 | "${PROJECT_BINARY_DIR}/config.h" 22 | ) 23 | 24 | # GCC need to link against stdc++fs 25 | if(MSVC) 26 | set(STDCXXFS "") 27 | else() 28 | set(STDCXXFS "stdc++fs") 29 | endif() 30 | 31 | # .exe and .dll 32 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 33 | 34 
| add_subdirectory(detection) 35 | add_subdirectory(tracking) 36 | add_subdirectory(processing) 37 | add_subdirectory(GUI) 38 | -------------------------------------------------------------------------------- /GUI/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(OpenCV REQUIRED) 2 | find_package(wxWidgets COMPONENTS core base REQUIRED) 3 | 4 | include(${wxWidgets_USE_FILE}) 5 | 6 | aux_source_directory(. GUI_SRCS) 7 | 8 | if (WIN32) 9 | set(WIN_RC "win.rc") 10 | else () 11 | set(WIN_RC "") 12 | endif () 13 | 14 | add_executable(GUI WIN32 "${GUI_SRCS}" ${WIN_RC}) 15 | target_link_libraries(GUI ${wxWidgets_LIBRARIES} ${OpenCV_LIBS} ${STDCXXFS}) 16 | target_include_directories(GUI PRIVATE "${PROJECT_BINARY_DIR}") 17 | -------------------------------------------------------------------------------- /GUI/TargetRepo.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "TargetRepo.h" 5 | #include "config.h" 6 | 7 | using namespace std; 8 | namespace fs = std::experimental::filesystem; 9 | 10 | void TargetRepo::load(const std::function &show_progress) { 11 | int num_targets = 0; 12 | for (auto &trk_dir: fs::directory_iterator(fs::path(out_dir) / TARGETS_DIR_NAME)) { 13 | ++num_targets; 14 | } 15 | 16 | int i_target = 0; 17 | for (auto &trk_dir: fs::directory_iterator(fs::path(out_dir) / TARGETS_DIR_NAME)) { 18 | auto id = stoi(trk_dir.path().filename()); 19 | targets.emplace(id, Target()); 20 | auto &t = targets[id]; 21 | 22 | for (auto &ss_p: fs::directory_iterator(trk_dir / SNAPSHOTS_DIR_NAME)) { 23 | auto &ss_path = ss_p.path(); 24 | if (!t.snapshots.count(stoi(ss_path.stem()))) { 25 | auto img = cv::imread(ss_path.string()); 26 | t.snapshots[stoi(ss_path.stem())] = img; 27 | assert(!img.empty()); 28 | } 29 | } 30 | 31 | auto traj_txt = ifstream((trk_dir / TRAJ_TXT_NAME).string()); 32 | int frame; 33 | cv::Rect2f box; 34 | while 
(traj_txt >> frame >> box.x >> box.y >> box.width >> box.height) { 35 | t.trajectories[frame] = box; 36 | } 37 | traj_txt.close(); 38 | 39 | show_progress(100 * ++i_target / num_targets); 40 | } 41 | } 42 | 43 | std::string TargetRepo::video_path() { 44 | return (fs::path(out_dir) / VIDEO_NAME).string(); 45 | } 46 | -------------------------------------------------------------------------------- /GUI/TargetRepo.h: -------------------------------------------------------------------------------- 1 | #ifndef TARGET_H 2 | #define TARGET_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | struct Target { 10 | std::map trajectories; 11 | std::map snapshots; 12 | }; 13 | 14 | class TargetRepo { 15 | public: 16 | TargetRepo(const std::string &dir) : out_dir(dir) {} 17 | 18 | using container_t = std::map; 19 | 20 | container_t &get() { return targets; } 21 | 22 | std::string video_path(); 23 | 24 | void load(const std::function &show_progress); 25 | 26 | private: 27 | container_t targets; 28 | 29 | const std::string out_dir; 30 | }; 31 | 32 | #endif //TARGET_H 33 | -------------------------------------------------------------------------------- /GUI/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #ifndef WX_PRECOMP 4 | 5 | #include 6 | 7 | #endif 8 | 9 | #include 10 | #include 11 | 12 | #include "thumbnailctrl.h" 13 | #include "TargetRepo.h" 14 | #include "util.h" 15 | #include "wxplayer.h" 16 | 17 | class MyApp : public wxApp { 18 | public: 19 | bool OnInit() override; 20 | }; 21 | 22 | wxIMPLEMENT_APP(MyApp); 23 | 24 | class MyFrame : public wxFrame { 25 | public: 26 | MyFrame(const wxString &dir); 27 | 28 | private: 29 | wxThumbnailCtrl *InitThumbnails(); 30 | 31 | wxPlayer *player; 32 | 33 | TargetRepo repo; 34 | 35 | int hovered = -1; 36 | 37 | enum { 38 | ID_Hello = 1, 39 | ID_Timer, 40 | ID_List, 41 | }; 42 | }; 43 | 44 | 45 | bool MyApp::OnInit() { 46 | auto dialog = wxDirDialog(nullptr, 
47 | "Select result directory", 48 | "", 49 | wxDD_DIR_MUST_EXIST); 50 | if (dialog.ShowModal() == wxID_OK) { 51 | auto frame = new MyFrame(dialog.GetPath()); 52 | frame->Show(true); 53 | return true; 54 | } else { 55 | return false; 56 | } 57 | } 58 | 59 | MyFrame::MyFrame(const wxString &dir) 60 | : wxFrame(nullptr, wxID_ANY, "YOLO+DeepSORT+wxWidgets"), 61 | repo(dir.ToStdString()) { 62 | player = new wxPlayer(this, wxID_ANY, repo.video_path(), 63 | [this](cv::Mat &mat, int display_frame) { 64 | for (auto &[id, t]:repo.get()) { 65 | if (t.trajectories.count(display_frame)) { 66 | auto color = 67 | hovered == -1 ? color_map(id) : hovered == id ? cv::Scalar(0, 0, 255) 68 | : cv::Scalar(0, 0, 0); 69 | draw_trajectories(mat, t.trajectories, display_frame, color); 70 | draw_bbox(mat, t.trajectories.at(display_frame), 71 | std::to_string(id), color); 72 | } 73 | } 74 | }); 75 | 76 | auto sizer = new wxBoxSizer(wxHORIZONTAL); 77 | sizer->Add(player, 3, wxEXPAND | wxALL); 78 | sizer->Add(InitThumbnails(), 1, wxEXPAND | wxALL); 79 | SetSizerAndFit(sizer); 80 | } 81 | 82 | wxThumbnailCtrl *MyFrame::InitThumbnails() { 83 | auto dialog = wxProgressDialog("Loading results", wxEmptyString); 84 | repo.load([&dialog](int value) { dialog.Update(value / 2, "Loading targets..."); }); 85 | auto thumbnails = new wxThumbnailCtrl(this, ID_List); 86 | thumbnails->SetThumbnailImageSize(wxSize(50, 50)); 87 | int i_target = 0; 88 | for (auto &[id, t]:repo.get()) { 89 | for (auto &[s_t, s]:t.snapshots) { 90 | cv::cvtColor(s, s, cv::COLOR_BGR2RGB); 91 | cv::resize(s, s, cv::Size(50, 50)); 92 | } 93 | auto item = new wxThumbnailItem(wxString::Format("%d", id)); 94 | item->SetBitmap(cvMat2wxImage(t.snapshots.begin()->second)); 95 | thumbnails->Append(item); 96 | dialog.Update(50 + 50 * ++i_target / repo.get().size(), "Loading resources..."); 97 | } 98 | 99 | static std::map::const_iterator it{}; 100 | auto timer = new wxTimer(thumbnails, ID_Timer); 101 | thumbnails->Bind(wxEVT_TIMER, 102 | 
[this, thumbnails](wxTimerEvent &) { 103 | if (thumbnails->GetMouseHoverItem() != wxNOT_FOUND) { 104 | auto &item = *thumbnails->GetItem(thumbnails->GetMouseHoverItem()); 105 | auto &snapshots = repo.get().at(wxAtoi(item.GetLabel())).snapshots; 106 | item.SetBitmap(cvMat2wxImage(it->second)); 107 | item.Refresh(thumbnails, thumbnails->GetMouseHoverItem()); 108 | if (++it == snapshots.end()) { 109 | it = snapshots.begin(); 110 | } 111 | } 112 | }, ID_Timer); 113 | timer->Start(1000 * 5 / player->GetFPS()); 114 | thumbnails->Bind(wxEVT_COMMAND_THUMBNAIL_ITEM_HOVER_CHANGED, 115 | [this, thumbnails](wxThumbnailEvent &event) { 116 | if (event.GetIndex() != wxNOT_FOUND) { 117 | auto &item = *thumbnails->GetItem(event.GetIndex()); 118 | item.SetBitmap(cvMat2wxImage( 119 | repo.get().at(wxAtoi(item.GetLabel())).snapshots.begin()->second)); 120 | } 121 | 122 | if (thumbnails->GetMouseHoverItem() != wxNOT_FOUND) { 123 | auto &item = *thumbnails->GetItem(thumbnails->GetMouseHoverItem()); 124 | it = repo.get().at(wxAtoi(item.GetLabel())).snapshots.begin(); 125 | } 126 | }, ID_List); 127 | 128 | thumbnails->Bind(wxEVT_COMMAND_THUMBNAIL_ITEM_SELECTED, 129 | [this, thumbnails](wxThumbnailEvent &event) { 130 | auto id = wxAtoi(thumbnails->GetItem(event.GetIndex())->GetLabel()); 131 | player->Seek(repo.get().at(id).trajectories.begin()->first); 132 | }, ID_List); 133 | 134 | thumbnails->Bind(wxEVT_COMMAND_THUMBNAIL_ITEM_HOVER_CHANGED, 135 | [this, thumbnails](wxThumbnailEvent &event) { 136 | if (thumbnails->GetMouseHoverItem() != wxNOT_FOUND) { 137 | hovered = wxAtoi(thumbnails->GetItem(thumbnails->GetMouseHoverItem())->GetLabel()); 138 | player->Refresh(); 139 | } else { 140 | hovered = -1; 141 | player->Refresh(); 142 | } 143 | Refresh(); 144 | event.Skip(); 145 | }, ID_List); 146 | 147 | return thumbnails; 148 | } 149 | -------------------------------------------------------------------------------- /GUI/thumbnailctrl.cpp: 
-------------------------------------------------------------------------------- 1 | ///////////////////////////////////////////////////////////////////////////// 2 | // Name: thumbnailctrl.cpp 3 | // Purpose: Displays a scrolling window of thumbnails 4 | // Author: Julian Smart 5 | // Modified by: Anil Kumar 6 | // Created: 03/08/04 17:22:46 7 | // RCS-ID: 8 | // Copyright: (c) Julian Smart 9 | // Licence: wxWidgets Licence 10 | ///////////////////////////////////////////////////////////////////////////// 11 | 12 | #if defined(__GNUG__) && !defined(__APPLE__) 13 | #pragma implementation "thumbnailctrl.h" 14 | #endif 15 | 16 | // For compilers that support precompilation, includes "wx.h". 17 | #include "wx/wxprec.h" 18 | 19 | #ifdef __BORLANDC__ 20 | #pragma hdrstop 21 | #endif 22 | 23 | #ifndef WX_PRECOMP 24 | 25 | #include "wx/wx.h" 26 | 27 | #endif 28 | 29 | #include "thumbnailctrl.h" 30 | 31 | #include "wx/settings.h" 32 | #include "wx/arrimpl.cpp" 33 | #include "wx/image.h" 34 | #include "wx/filename.h" 35 | #include "wx/dcbuffer.h" 36 | #include 37 | 38 | WX_DEFINE_OBJARRAY(wxThumbnailItemArray); 39 | 40 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_SELECTION_CHANGED, wxThumbnailEvent); 41 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_ITEM_SELECTED, wxThumbnailEvent); 42 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_ITEM_DESELECTED, wxThumbnailEvent); 43 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_LEFT_CLICK, wxThumbnailEvent); 44 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_RIGHT_CLICK, wxThumbnailEvent); 45 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_VIEW_RIGHT_CLICK, wxThumbnailEvent); 46 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_LEFT_DCLICK, wxThumbnailEvent); 47 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_RETURN, wxThumbnailEvent); 48 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_DRAG_START, wxThumbnailEvent); 49 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_SORTED, wxThumbnailEvent); 50 | 51 | wxDEFINE_EVENT(wxEVT_COMMAND_THUMBNAIL_ITEM_HOVER_CHANGED, wxThumbnailEvent); 52 | 53 | 
IMPLEMENT_CLASS(wxThumbnailCtrl, wxScrolledWindow) 54 | 55 | IMPLEMENT_CLASS(wxThumbnailEvent, wxNotifyEvent) 56 | 57 | BEGIN_EVENT_TABLE(wxThumbnailCtrl, wxScrolledWindow) 58 | EVT_MOUSE_EVENTS(wxThumbnailCtrl::OnMouse) 59 | EVT_MOTION(wxThumbnailCtrl::OnMouseMotion) 60 | EVT_LEAVE_WINDOW(wxThumbnailCtrl::OnMouseLeave) 61 | EVT_CHAR(wxThumbnailCtrl::OnChar) 62 | EVT_SIZE(wxThumbnailCtrl::OnSize) 63 | EVT_SET_FOCUS(wxThumbnailCtrl::OnSetFocus) 64 | EVT_KILL_FOCUS(wxThumbnailCtrl::OnKillFocus) 65 | 66 | EVT_MENU(wxID_SELECTALL, wxThumbnailCtrl::OnSelectAll) 67 | EVT_UPDATE_UI(wxID_SELECTALL, wxThumbnailCtrl::OnUpdateSelectAll) 68 | END_EVENT_TABLE() 69 | 70 | /*! 71 | * wxThumbnailCtrl 72 | */ 73 | 74 | wxThumbnailCtrl::wxThumbnailCtrl() { 75 | m_thumbnailOverallSize = wxTHUMBNAIL_DEFAULT_OVERALL_SIZE; 76 | m_thumbnailImageSize = wxTHUMBNAIL_DEFAULT_IMAGE_SIZE; 77 | m_spacing = wxTHUMBNAIL_DEFAULT_SPACING; 78 | m_thumbnailMargin = wxTHUMBNAIL_DEFAULT_MARGIN; 79 | m_firstSelection = -1; 80 | m_lastSelection = -1; 81 | m_focussedThumbnailBackgroundColour = wxTHUMBNAIL_DEFAULT_FOCUSSED_BACKGROUND; 82 | m_unfocussedThumbnailBackgroundColour = wxTHUMBNAIL_DEFAULT_UNFOCUSSED_BACKGROUND; 83 | m_unselectedThumbnailBackgroundColour = wxTHUMBNAIL_DEFAULT_UNSELECTED_BACKGROUND; 84 | m_typeColour = wxTHUMBNAIL_DEFAULT_TYPE_COLOUR; 85 | m_focusRectColour = wxTHUMBNAIL_DEFAULT_FOCUS_RECT_COLOUR; 86 | m_focusItem = -1; 87 | m_showOutlines = false; 88 | } 89 | 90 | wxThumbnailCtrl::wxThumbnailCtrl(wxWindow *parent, wxWindowID id, const wxPoint &pos, const wxSize &size, long style) 91 | : wxThumbnailCtrl() { 92 | Create(parent, id, pos, size, style); 93 | } 94 | 95 | /// Creation 96 | bool wxThumbnailCtrl::Create(wxWindow *parent, wxWindowID id, const wxPoint &pos, const wxSize &size, long style) { 97 | if (!wxScrolledCanvas::Create(parent, id, pos, size, style | wxFULL_REPAINT_ON_RESIZE)) 98 | return false; 99 | 100 | if (!GetFont().Ok()) { 101 | 
SetFont(wxSystemSettings::GetFont(wxSYS_DEFAULT_GUI_FONT)); 102 | } 103 | CalculateOverallThumbnailSize(); 104 | 105 | SetBackgroundColour(wxSystemSettings::GetColour(wxSYS_COLOUR_3DFACE)); 106 | SetBackgroundStyle(wxBG_STYLE_CUSTOM); 107 | 108 | // Tell the sizers to use the given or best size 109 | SetInitialSize(size); 110 | 111 | return true; 112 | } 113 | 114 | /// Append a single item 115 | int wxThumbnailCtrl::Append(wxThumbnailItem *item) { 116 | int sz = (int) GetCount(); 117 | m_items.Add(item); 118 | m_firstSelection = -1; 119 | m_lastSelection = -1; 120 | m_focusItem = -1; 121 | 122 | if (!IsFrozen()) { 123 | SetupScrollbars(); 124 | Refresh(); 125 | } 126 | return sz; 127 | } 128 | 129 | /// Insert a single item 130 | int wxThumbnailCtrl::Insert(wxThumbnailItem *item, int pos) { 131 | m_items.Insert(item, pos); 132 | m_firstSelection = -1; 133 | m_lastSelection = -1; 134 | m_focusItem = -1; 135 | 136 | // Must now change selection indices because 137 | // items above it have moved up 138 | size_t i; 139 | for (i = 0; i < m_selections.GetCount(); i++) { 140 | if (m_selections[i] >= pos) 141 | m_selections[i] = m_selections[i] + 1; 142 | } 143 | 144 | if (!IsFrozen()) { 145 | SetupScrollbars(); 146 | Refresh(); 147 | } 148 | return pos; 149 | } 150 | 151 | /// Clear all items 152 | void wxThumbnailCtrl::Clear() { 153 | m_firstSelection = -1; 154 | m_lastSelection = -1; 155 | m_focusItem = -1; 156 | m_items.Clear(); 157 | m_selections.Clear(); 158 | m_hoverItem = wxNOT_FOUND; 159 | 160 | if (!IsFrozen()) { 161 | SetupScrollbars(); 162 | Refresh(); 163 | } 164 | } 165 | 166 | /// Delete this item 167 | void wxThumbnailCtrl::Delete(int n) { 168 | if (m_firstSelection == n) 169 | m_firstSelection = -1; 170 | if (m_lastSelection == n) 171 | m_lastSelection = -1; 172 | if (m_focusItem == n) 173 | m_focusItem = -1; 174 | 175 | if (m_selections.Index(n) != wxNOT_FOUND) { 176 | m_selections.Remove(n); 177 | 178 | wxThumbnailEvent 
event(wxEVT_COMMAND_THUMBNAIL_ITEM_DESELECTED, GetId()); 179 | event.SetEventObject(this); 180 | event.SetIndex(n); 181 | GetEventHandler()->ProcessEvent(event); 182 | 183 | wxThumbnailEvent cmdEvent(wxEVT_COMMAND_THUMBNAIL_SELECTION_CHANGED, GetId()); 184 | cmdEvent.SetEventObject(this); 185 | GetEventHandler()->ProcessEvent(cmdEvent); 186 | } 187 | 188 | m_items.RemoveAt(n); 189 | 190 | // Must now change selection indices because 191 | // items have moved down 192 | size_t i; 193 | for (i = 0; i < m_selections.GetCount(); i++) { 194 | if (m_selections[i] > n) 195 | m_selections[i] = m_selections[i] - 1; 196 | } 197 | 198 | if (!IsFrozen()) { 199 | SetupScrollbars(); 200 | Refresh(); 201 | } 202 | } 203 | 204 | /// Get the nth item 205 | wxThumbnailItem *wxThumbnailCtrl::GetItem(int n) { 206 | wxASSERT(n < GetCount()); 207 | 208 | if (n < GetCount()) { 209 | return &m_items[(size_t) n]; 210 | } else { 211 | return nullptr; 212 | } 213 | } 214 | 215 | /// Get the overall rect of the given item 216 | bool wxThumbnailCtrl::GetItemRect(int n, wxRect &rect, bool transform) { 217 | wxASSERT(n < GetCount()); 218 | if (n < GetCount()) { 219 | int row, col; 220 | if (!GetRowCol(n, GetClientSize(), row, col)) 221 | return false; 222 | 223 | int x = col * (m_thumbnailOverallSize.x + m_spacing) + m_spacing; 224 | int y = row * (m_thumbnailOverallSize.y + m_spacing) + m_spacing; 225 | 226 | if (transform) { 227 | int startX, startY; 228 | int xppu, yppu; 229 | GetScrollPixelsPerUnit(&xppu, &yppu); 230 | GetViewStart(&startX, &startY); 231 | x = x - startX * xppu; 232 | y = y - startY * yppu; 233 | } 234 | 235 | rect.x = x; 236 | rect.y = y; 237 | rect.width = m_thumbnailOverallSize.x; 238 | rect.height = m_thumbnailOverallSize.y; 239 | 240 | return true; 241 | } 242 | 243 | return false; 244 | } 245 | 246 | /// Get the image rect of the given item 247 | bool wxThumbnailCtrl::GetItemRectImage(int n, wxRect &rect, bool transform) { 248 | wxASSERT(n < GetCount()); 249 | 250 | 
wxRect outerRect; 251 | if (!GetItemRect(n, outerRect, transform)) 252 | return false; 253 | 254 | rect.width = m_thumbnailImageSize.x; 255 | rect.height = m_thumbnailImageSize.y; 256 | rect.x = outerRect.x + (outerRect.width - rect.width) / 2; 257 | rect.y = outerRect.y + (outerRect.height - rect.height) / 2; 258 | rect.y -= m_thumbnailTextHeight / 2; 259 | 260 | return true; 261 | } 262 | 263 | /// The size of the image part 264 | void wxThumbnailCtrl::SetThumbnailImageSize(const wxSize &sz) { 265 | m_thumbnailImageSize = sz; 266 | CalculateOverallThumbnailSize(); 267 | 268 | if (GetCount() > 0 && !IsFrozen()) { 269 | SetupScrollbars(); 270 | Refresh(); 271 | } 272 | } 273 | 274 | /// Calculate the outer thumbnail size based 275 | /// on font used for text and inner size 276 | void wxThumbnailCtrl::CalculateOverallThumbnailSize() { 277 | wxCoord w; 278 | wxClientDC dc(this); 279 | dc.SetFont(GetFont()); 280 | dc.GetTextExtent(wxT("X"), &w, &m_thumbnailTextHeight); 281 | 282 | // From left to right: margin, image, margin 283 | m_thumbnailOverallSize.x = m_thumbnailMargin * 2 + m_thumbnailImageSize.x; 284 | 285 | // From top to bottom: margin, image, margin, text, margin 286 | m_thumbnailOverallSize.y = m_thumbnailMargin * 3 + m_thumbnailTextHeight + m_thumbnailImageSize.y; 287 | 288 | SetMinClientSize(m_thumbnailOverallSize); 289 | } 290 | 291 | /// Return the row and column given the client 292 | /// size and a left-to-right, top-to-bottom layout 293 | /// assumption 294 | bool wxThumbnailCtrl::GetRowCol(int item, const wxSize &clientSize, int &row, int &col) { 295 | wxASSERT(item < GetCount()); 296 | if (item >= GetCount()) 297 | return false; 298 | 299 | // How many can we fit in a row? 
300 | 301 | int perRow = clientSize.x / (m_thumbnailOverallSize.x + m_spacing); 302 | if (perRow < 1) 303 | perRow = 1; 304 | 305 | row = item / perRow; 306 | col = item % perRow; 307 | 308 | return true; 309 | } 310 | 311 | 312 | /// Select or deselect an item 313 | void wxThumbnailCtrl::Select(int n, bool select) { 314 | wxASSERT(n < GetCount()); 315 | 316 | if (select) { 317 | if (m_selections.Index(n) == wxNOT_FOUND) 318 | m_selections.Add(n); 319 | } else { 320 | if (m_selections.Index(n) != wxNOT_FOUND) 321 | m_selections.Remove(n); 322 | } 323 | 324 | m_firstSelection = n; 325 | m_lastSelection = n; 326 | int oldFocusItem = m_focusItem; 327 | m_focusItem = n; 328 | 329 | if (!IsFrozen()) { 330 | wxRect rect; 331 | GetItemRect(n, rect); 332 | RefreshRect(rect); 333 | 334 | if (oldFocusItem != -1 && oldFocusItem != n) { 335 | GetItemRect(oldFocusItem, rect); 336 | RefreshRect(rect); 337 | } 338 | } 339 | } 340 | 341 | /// Select or deselect a range 342 | void wxThumbnailCtrl::SelectRange(int from, int to, bool select) { 343 | int first = from; 344 | int last = to; 345 | if (first < last) { 346 | first = to; 347 | last = from; 348 | } 349 | wxASSERT(first >= 0 && first < GetCount()); 350 | wxASSERT(last >= 0 && last < GetCount()); 351 | 352 | Freeze(); 353 | int i; 354 | for (i = first; i < last; i++) { 355 | Select(i, select); 356 | } 357 | m_focusItem = to; 358 | Thaw(); 359 | } 360 | 361 | /// Select all 362 | void wxThumbnailCtrl::SelectAll() { 363 | Freeze(); 364 | int i; 365 | for (i = 0; i < GetCount(); i++) { 366 | Select(i, true); 367 | } 368 | if (GetCount() > 0) { 369 | m_focusItem = GetCount() - 1; 370 | } else { 371 | m_focusItem = -1; 372 | } 373 | Thaw(); 374 | } 375 | 376 | /// Select none 377 | void wxThumbnailCtrl::SelectNone() { 378 | Freeze(); 379 | int i; 380 | for (i = 0; i < GetCount(); i++) { 381 | Select(i, false); 382 | } 383 | Thaw(); 384 | } 385 | 386 | /// Get the index of the single selection, if not multi-select. 
387 | /// Returns -1 if there is no selection. 388 | int wxThumbnailCtrl::GetSelection() const { 389 | if (m_selections.GetCount() > 0) 390 | return m_selections[0u]; 391 | else 392 | return -1; 393 | } 394 | 395 | /// Returns true if the item is selected 396 | bool wxThumbnailCtrl::IsSelected(int n) const { 397 | return (m_selections.Index(n) != wxNOT_FOUND); 398 | } 399 | 400 | /// Clears all selections 401 | void wxThumbnailCtrl::ClearSelections() { 402 | int count = GetCount(); 403 | 404 | m_selections.Clear(); 405 | m_firstSelection = -1; 406 | m_lastSelection = -1; 407 | m_focusItem = -1; 408 | 409 | if (count > 0 && !IsFrozen()) { 410 | Refresh(); 411 | } 412 | } 413 | 414 | /// Set the focus item 415 | void wxThumbnailCtrl::SetFocusItem(int item) { 416 | wxASSERT(item < GetCount()); 417 | if (item != m_focusItem) { 418 | int oldFocusItem = m_focusItem; 419 | m_focusItem = item; 420 | 421 | if (!IsFrozen()) { 422 | wxRect rect; 423 | if (oldFocusItem != -1) { 424 | GetItemRect(oldFocusItem, rect); 425 | RefreshRect(rect); 426 | } 427 | if (m_focusItem != -1) { 428 | GetItemRect(m_focusItem, rect); 429 | RefreshRect(rect); 430 | } 431 | } 432 | } 433 | } 434 | 435 | /// Painting 436 | void wxThumbnailCtrl::OnDraw(wxDC &dc) { 437 | if (IsFrozen()) 438 | return; 439 | 440 | // Paint the background 441 | PaintBackground(dc); 442 | 443 | if (GetCount() == 0) 444 | return; 445 | 446 | wxRegion dirtyRegion = GetUpdateRegion(); 447 | bool isFocussed = (FindFocus() == this); 448 | 449 | int i; 450 | int count = GetCount(); 451 | int style = 0; 452 | wxRect rect, untransformedRect, imageRect, untransformedImageRect; 453 | for (i = 0; i < count; i++) { 454 | GetItemRect(i, rect); 455 | 456 | wxRegionContain c = dirtyRegion.Contains(rect); 457 | if (c != wxOutRegion) { 458 | GetItemRectImage(i, imageRect); 459 | style = 0; 460 | 461 | if (i == GetMouseHoverItem()) 462 | style |= wxTHUMBNAIL_IS_HOVER; 463 | if (IsSelected(i)) 464 | style |= wxTHUMBNAIL_SELECTED; 465 | if 
(isFocussed) 466 | style |= wxTHUMBNAIL_FOCUSSED; 467 | if (isFocussed && i == m_focusItem) 468 | style |= wxTHUMBNAIL_IS_FOCUS; 469 | 470 | GetItemRect(i, untransformedRect, false); 471 | GetItemRectImage(i, untransformedImageRect, false); 472 | 473 | DrawItem(i, dc, untransformedRect, untransformedImageRect, style); 474 | } 475 | } 476 | } 477 | 478 | void wxThumbnailCtrl::OnSetFocus(wxFocusEvent &) { 479 | if (GetCount() > 0) 480 | Refresh(); 481 | } 482 | 483 | void wxThumbnailCtrl::OnKillFocus(wxFocusEvent &) { 484 | if (GetCount() > 0) 485 | Refresh(); 486 | } 487 | 488 | /// Mouse-event 489 | void wxThumbnailCtrl::OnMouse(wxMouseEvent &event) { 490 | if (event.GetEventType() == wxEVT_MOUSEWHEEL) { 491 | // let the base handle mouse wheel events. 492 | event.Skip(); 493 | return; 494 | } 495 | 496 | if (event.LeftDown()) { 497 | OnLeftClickDown(event); 498 | } else if (event.LeftUp()) { 499 | OnLeftClickUp(event); 500 | } else if (event.LeftDClick()) { 501 | OnLeftDClick(event); 502 | } else if (event.RightDown()) { 503 | OnRightClickDown(event); 504 | } else if (event.RightUp()) { 505 | OnRightClickUp(event); 506 | } else if (event.Dragging() && event.LeftIsDown() && GetSelections().Count()) { 507 | wxThumbnailEvent cmdEvent( 508 | wxEVT_COMMAND_THUMBNAIL_DRAG_START, 509 | GetId()); 510 | 511 | if (wxDefaultPosition == m_dragStartPosition) 512 | m_dragStartPosition = event.GetPosition(); 513 | 514 | if ((GetWindowStyle() & wxTH_MULTIPLE_SELECT) != 0) 515 | cmdEvent.SetItemsIndex(GetSelections()); 516 | else 517 | cmdEvent.SetIndex(GetSelection()); 518 | 519 | cmdEvent.SetPosition(m_dragStartPosition); 520 | cmdEvent.SetEventObject(this); 521 | GetEventHandler()->ProcessEvent(cmdEvent); 522 | } else { 523 | m_dragStartPosition = wxDefaultPosition; 524 | } 525 | 526 | event.Skip(); 527 | } 528 | 529 | /// Left-click-down 530 | void wxThumbnailCtrl::OnLeftClickDown(wxMouseEvent &event) { 531 | SetFocus(); 532 | int n; 533 | if (HitTest(event.GetPosition(), n)) 
{ 534 | int flags = 0; 535 | if (event.ControlDown()) 536 | flags |= wxTHUMBNAIL_CTRL_DOWN; 537 | if (event.ShiftDown()) 538 | flags |= wxTHUMBNAIL_SHIFT_DOWN; 539 | if (event.AltDown()) 540 | flags |= wxTHUMBNAIL_ALT_DOWN; 541 | 542 | EnsureVisible(n); 543 | 544 | auto change = false; 545 | 546 | if (((GetWindowStyle() & wxTH_MULTIPLE_SELECT) != 0)) 547 | change = !IsSelected(n) || (flags != 0); 548 | else 549 | change = !IsSelected(n); 550 | 551 | if (change) 552 | DoSelection(n, flags); 553 | 554 | wxThumbnailEvent cmdEvent( 555 | wxEVT_COMMAND_THUMBNAIL_LEFT_CLICK, 556 | GetId()); 557 | cmdEvent.SetEventObject(this); 558 | cmdEvent.SetIndex(n); 559 | cmdEvent.SetFlags(flags); 560 | GetEventHandler()->ProcessEvent(cmdEvent); 561 | } 562 | } 563 | 564 | /// Left-click-up 565 | void wxThumbnailCtrl::OnLeftClickUp(wxMouseEvent &event) { 566 | SetFocus(); 567 | int n; 568 | if (HitTest(event.GetPosition(), n)) { 569 | int flags = 0; 570 | if (event.ControlDown()) 571 | flags |= wxTHUMBNAIL_CTRL_DOWN; 572 | if (event.ShiftDown()) 573 | flags |= wxTHUMBNAIL_SHIFT_DOWN; 574 | if (event.AltDown()) 575 | flags |= wxTHUMBNAIL_ALT_DOWN; 576 | 577 | EnsureVisible(n); 578 | 579 | if ((GetWindowStyle() & wxTH_MULTIPLE_SELECT) != 0) { 580 | if ((GetSelections().Count() > 1) && IsSelected(n) && flags == 0) 581 | DoSelection(n, flags); 582 | } 583 | } 584 | } 585 | 586 | /// Right-click-down 587 | void wxThumbnailCtrl::OnRightClickDown(wxMouseEvent &event) { 588 | SetFocus(); 589 | int n; 590 | if (HitTest(event.GetPosition(), n)) { 591 | int flags = 0; 592 | if (event.ControlDown()) 593 | flags |= wxTHUMBNAIL_CTRL_DOWN; 594 | if (event.ShiftDown()) 595 | flags |= wxTHUMBNAIL_SHIFT_DOWN; 596 | if (event.AltDown()) 597 | flags |= wxTHUMBNAIL_ALT_DOWN; 598 | 599 | SetFocusItem(n); 600 | 601 | const wxArrayInt &selections = GetSelections(); 602 | if (std::find(selections.begin(), selections.end(), n) == selections.end()) { 603 | SelectNone(); 604 | Select(n); 605 | wxThumbnailEvent 
cmdEvent( 606 | wxEVT_COMMAND_THUMBNAIL_ITEM_SELECTED, 607 | GetId()); 608 | cmdEvent.SetEventObject(this); 609 | cmdEvent.SetIndex(n); 610 | cmdEvent.SetFlags(flags); 611 | GetEventHandler()->ProcessEvent(cmdEvent); 612 | 613 | wxThumbnailEvent event(wxEVT_COMMAND_THUMBNAIL_SELECTION_CHANGED, GetId()); 614 | event.SetEventObject(this); 615 | GetEventHandler()->ProcessEvent(event); 616 | } 617 | } 618 | 619 | event.Skip(); 620 | } 621 | 622 | /// Right-click-up 623 | void wxThumbnailCtrl::OnRightClickUp(wxMouseEvent &event) { 624 | SetFocus(); 625 | 626 | int flags = 0; 627 | if (event.ControlDown()) 628 | flags |= wxTHUMBNAIL_CTRL_DOWN; 629 | if (event.ShiftDown()) 630 | flags |= wxTHUMBNAIL_SHIFT_DOWN; 631 | if (event.AltDown()) 632 | flags |= wxTHUMBNAIL_ALT_DOWN; 633 | 634 | int n; 635 | if (HitTest(event.GetPosition(), n)) { 636 | SetFocusItem(n); 637 | 638 | const wxArrayInt &selections = GetSelections(); 639 | if (std::find(selections.begin(), selections.end(), n) != selections.end()) { 640 | wxThumbnailEvent cmdEvent( 641 | wxEVT_COMMAND_THUMBNAIL_RIGHT_CLICK, 642 | GetId()); 643 | cmdEvent.SetEventObject(this); 644 | cmdEvent.SetIndex(n); 645 | cmdEvent.SetFlags(flags); 646 | cmdEvent.SetPosition(event.GetPosition()); 647 | GetEventHandler()->ProcessEvent(cmdEvent); 648 | } 649 | } else { 650 | wxThumbnailEvent cmdEvent( 651 | wxEVT_COMMAND_THUMBNAIL_VIEW_RIGHT_CLICK, 652 | GetId()); 653 | cmdEvent.SetEventObject(this); 654 | cmdEvent.SetFlags(flags); 655 | cmdEvent.SetPosition(event.GetPosition()); 656 | GetEventHandler()->ProcessEvent(cmdEvent); 657 | } 658 | } 659 | 660 | /// Left-double-click 661 | void wxThumbnailCtrl::OnLeftDClick(wxMouseEvent &event) { 662 | int n; 663 | if (HitTest(event.GetPosition(), n)) { 664 | int flags = 0; 665 | if (event.ControlDown()) 666 | flags |= wxTHUMBNAIL_CTRL_DOWN; 667 | if (event.ShiftDown()) 668 | flags |= wxTHUMBNAIL_SHIFT_DOWN; 669 | if (event.AltDown()) 670 | flags |= wxTHUMBNAIL_ALT_DOWN; 671 | 672 | 
wxThumbnailEvent cmdEvent( 673 | wxEVT_COMMAND_THUMBNAIL_LEFT_DCLICK, 674 | GetId()); 675 | cmdEvent.SetEventObject(this); 676 | cmdEvent.SetIndex(n); 677 | cmdEvent.SetFlags(flags); 678 | GetEventHandler()->ProcessEvent(cmdEvent); 679 | } 680 | } 681 | 682 | /// Mouse motion 683 | void wxThumbnailCtrl::OnMouseMotion(wxMouseEvent &event) { 684 | int flags = 0; 685 | if (event.ControlDown()) 686 | flags |= wxTHUMBNAIL_CTRL_DOWN; 687 | if (event.ShiftDown()) 688 | flags |= wxTHUMBNAIL_SHIFT_DOWN; 689 | if (event.AltDown()) 690 | flags |= wxTHUMBNAIL_ALT_DOWN; 691 | 692 | int n; 693 | if (HitTest(event.GetPosition(), n)) { 694 | SetMouseHoverItem(n, flags); 695 | } else { 696 | SetMouseHoverItem(wxNOT_FOUND, flags); 697 | } 698 | 699 | event.Skip(); 700 | } 701 | 702 | /// Mouse leave 703 | void wxThumbnailCtrl::OnMouseLeave(wxMouseEvent &event) { 704 | int flags = 0; 705 | if (event.ControlDown()) 706 | flags |= wxTHUMBNAIL_CTRL_DOWN; 707 | if (event.ShiftDown()) 708 | flags |= wxTHUMBNAIL_SHIFT_DOWN; 709 | if (event.AltDown()) 710 | flags |= wxTHUMBNAIL_ALT_DOWN; 711 | SetMouseHoverItem(wxNOT_FOUND, flags); 712 | 713 | event.Skip(); 714 | } 715 | 716 | /// Key press 717 | void wxThumbnailCtrl::OnChar(wxKeyEvent &event) { 718 | int flags = 0; 719 | if (event.ControlDown()) 720 | flags |= wxTHUMBNAIL_CTRL_DOWN; 721 | if (event.ShiftDown()) 722 | flags |= wxTHUMBNAIL_SHIFT_DOWN; 723 | if (event.AltDown()) 724 | flags |= wxTHUMBNAIL_ALT_DOWN; 725 | 726 | if (event.GetKeyCode() == WXK_LEFT || 727 | event.GetKeyCode() == WXK_RIGHT || 728 | event.GetKeyCode() == WXK_UP || 729 | event.GetKeyCode() == WXK_DOWN || 730 | event.GetKeyCode() == WXK_HOME || 731 | event.GetKeyCode() == WXK_PAGEUP || 732 | event.GetKeyCode() == WXK_PAGEDOWN || 733 | //event.GetKeyCode() == WXK_PRIOR || 734 | //event.GetKeyCode() == WXK_NEXT || 735 | event.GetKeyCode() == WXK_END) { 736 | Navigate(event.GetKeyCode(), flags); 737 | } else if (event.GetKeyCode() == WXK_RETURN) { 738 | wxThumbnailEvent 
cmdEvent( 739 | wxEVT_COMMAND_THUMBNAIL_RETURN, 740 | GetId()); 741 | cmdEvent.SetEventObject(this); 742 | cmdEvent.SetFlags(flags); 743 | GetEventHandler()->ProcessEvent(cmdEvent); 744 | } else 745 | event.Skip(); 746 | } 747 | 748 | /// Keyboard navigation 749 | bool wxThumbnailCtrl::Navigate(int keyCode, int flags) { 750 | if (GetCount() == 0) 751 | return false; 752 | 753 | wxSize clientSize = GetClientSize(); 754 | int perRow = clientSize.x / (m_thumbnailOverallSize.x + m_spacing); 755 | if (perRow < 1) 756 | perRow = 1; 757 | 758 | int rowsInView = clientSize.y / (m_thumbnailOverallSize.y + m_spacing); 759 | if (rowsInView < 1) 760 | rowsInView = 1; 761 | 762 | int focus = m_focusItem; 763 | if (focus == -1) 764 | focus = m_lastSelection; 765 | 766 | if (focus == -1 || focus >= GetCount()) { 767 | m_lastSelection = 0; 768 | DoSelection(m_lastSelection, flags); 769 | ScrollIntoView(m_lastSelection, keyCode); 770 | return true; 771 | } 772 | 773 | if (keyCode == WXK_RIGHT) { 774 | int next = focus + 1; 775 | if (next < GetCount()) { 776 | DoSelection(next, flags); 777 | ScrollIntoView(next, keyCode); 778 | } 779 | } else if (keyCode == WXK_LEFT) { 780 | int next = focus - 1; 781 | if (next >= 0) { 782 | DoSelection(next, flags); 783 | ScrollIntoView(next, keyCode); 784 | } 785 | } else if (keyCode == WXK_UP) { 786 | int next = focus - perRow; 787 | if (next >= 0) { 788 | DoSelection(next, flags); 789 | ScrollIntoView(next, keyCode); 790 | } 791 | } else if (keyCode == WXK_DOWN) { 792 | int next = focus + perRow; 793 | if (next < GetCount()) { 794 | DoSelection(next, flags); 795 | ScrollIntoView(next, keyCode); 796 | } 797 | } else if (keyCode == WXK_PAGEUP /*|| keyCode == WXK_PRIOR*/) { 798 | int next = focus - (perRow * rowsInView); 799 | if (next < 0) 800 | next = 0; 801 | if (next >= 0) { 802 | DoSelection(next, flags); 803 | ScrollIntoView(next, keyCode); 804 | } 805 | } else if (keyCode == WXK_PAGEDOWN /*|| keyCode == WXK_NEXT*/) { 806 | int next = focus + 
(perRow * rowsInView); 807 | if (next >= GetCount()) 808 | next = GetCount() - 1; 809 | if (next < GetCount()) { 810 | DoSelection(next, flags); 811 | ScrollIntoView(next, keyCode); 812 | } 813 | } else if (keyCode == WXK_HOME) { 814 | DoSelection(0, flags); 815 | ScrollIntoView(0, keyCode); 816 | } else if (keyCode == WXK_END) { 817 | DoSelection(GetCount() - 1, flags); 818 | ScrollIntoView(GetCount() - 1, keyCode); 819 | } 820 | return true; 821 | } 822 | 823 | /// Scroll to see the image 824 | void wxThumbnailCtrl::ScrollIntoView(int n, int keyCode) { 825 | wxRect rect; 826 | GetItemRect(n, rect, false); // _Not_ relative to scroll start 827 | 828 | int ppuX, ppuY; 829 | GetScrollPixelsPerUnit(&ppuX, &ppuY); 830 | 831 | int startX, startY; 832 | GetViewStart(&startX, &startY); 833 | startX = 0; 834 | startY = startY * ppuY; 835 | 836 | int sx, sy; 837 | GetVirtualSize(&sx, &sy); 838 | sx = 0; 839 | if (ppuY != 0) 840 | sy = sy / ppuY; 841 | 842 | wxSize clientSize = GetClientSize(); 843 | 844 | // Going down 845 | if (keyCode == WXK_DOWN || keyCode == WXK_RIGHT || keyCode == WXK_END /*|| keyCode == WXK_NEXT*/ || 846 | keyCode == WXK_PAGEDOWN) { 847 | if ((rect.y + rect.height) > (clientSize.y + startY)) { 848 | // Make it scroll so this item is at the bottom 849 | // of the window 850 | int y = rect.y - (clientSize.y - m_thumbnailOverallSize.y - m_spacing); 851 | SetScrollbars(ppuX, ppuY, sx, sy, 0, (int) (0.5 + y / ppuY)); 852 | } else if (rect.y < startY) { 853 | // Make it scroll so this item is at the top 854 | // of the window 855 | int y = rect.y; 856 | SetScrollbars(ppuX, ppuY, sx, sy, 0, (int) (0.5 + y / ppuY)); 857 | } 858 | } 859 | // Going up 860 | else if (keyCode == WXK_UP || keyCode == WXK_LEFT || keyCode == WXK_HOME /*|| keyCode == WXK_PRIOR*/ || 861 | keyCode == WXK_PAGEUP) { 862 | if (rect.y < startY) { 863 | // Make it scroll so this item is at the top 864 | // of the window 865 | int y = rect.y; 866 | SetScrollbars(ppuX, ppuY, sx, sy, 0, (int) 
(0.5 + y / ppuY)); 867 | } else if ((rect.y + rect.height) > (clientSize.y + startY)) { 868 | // Make it scroll so this item is at the bottom 869 | // of the window 870 | int y = rect.y - (clientSize.y - m_thumbnailOverallSize.y - m_spacing); 871 | SetScrollbars(ppuX, ppuY, sx, sy, 0, (int) (0.5 + y / ppuY)); 872 | } 873 | } 874 | } 875 | 876 | /// Scrolls the item into view if necessary 877 | void wxThumbnailCtrl::EnsureVisible(int n) { 878 | wxRect rect; 879 | GetItemRect(n, rect, false); // _Not_ relative to scroll start 880 | 881 | int ppuX, ppuY; 882 | GetScrollPixelsPerUnit(&ppuX, &ppuY); 883 | 884 | if (ppuY == 0) 885 | return; 886 | 887 | int startX, startY; 888 | GetViewStart(&startX, &startY); 889 | startX = 0; 890 | startY = startY * ppuY; 891 | 892 | int sx, sy; 893 | GetVirtualSize(&sx, &sy); 894 | sx = 0; 895 | if (ppuY != 0) 896 | sy = sy / ppuY; 897 | 898 | wxSize clientSize = GetClientSize(); 899 | 900 | if ((rect.y + rect.height) > (clientSize.y + startY)) { 901 | // Make it scroll so this item is at the bottom 902 | // of the window 903 | int y = rect.y - (clientSize.y - m_thumbnailOverallSize.y - m_spacing); 904 | SetScrollbars(ppuX, ppuY, sx, sy, 0, (int) (0.5 + y / ppuY)); 905 | } else if (rect.y < startY) { 906 | // Make it scroll so this item is at the top 907 | // of the window 908 | int y = rect.y; 909 | SetScrollbars(ppuX, ppuY, sx, sy, 0, (int) (0.5 + y / ppuY)); 910 | } 911 | } 912 | 913 | /// Sizing 914 | void wxThumbnailCtrl::OnSize(wxSizeEvent &event) { 915 | SetupScrollbars(); 916 | event.Skip(); 917 | } 918 | 919 | /// Set up scrollbars, e.g. 
after a resize 920 | void wxThumbnailCtrl::SetupScrollbars() { 921 | if (IsFrozen()) 922 | return; 923 | 924 | if (GetCount() == 0) { 925 | SetScrollbars(0, 0, 0, 0, 0, 0); 926 | return; 927 | } 928 | 929 | int lastItem = wxMax(0, GetCount() - 1); 930 | int pixelsPerUnit = 10; 931 | wxSize clientSize = GetClientSize(); 932 | 933 | int row, col; 934 | GetRowCol(lastItem, clientSize, row, col); 935 | 936 | int maxHeight = (row + 1) * (m_thumbnailOverallSize.y + m_spacing) + m_spacing; 937 | 938 | int unitsY = maxHeight / pixelsPerUnit; 939 | 940 | int startX, startY; 941 | GetViewStart(&startX, &startY); 942 | 943 | int maxPositionX = 0; // wxMax(sz.x - clientSize.x, 0); 944 | int maxPositionY = (wxMax(maxHeight - clientSize.y, 0)) / pixelsPerUnit; 945 | 946 | // Move to previous scroll position if 947 | // possible 948 | SetScrollbars(0, pixelsPerUnit, 949 | 0, unitsY, 950 | wxMin(maxPositionX, startX), wxMin(maxPositionY, startY)); 951 | } 952 | 953 | /// Draws the background for the item, including bevel 954 | bool wxThumbnailCtrl::DrawItem(int n, wxDC &dc, const wxRect &rect, const wxRect &imageRect, int style) { 955 | auto item = GetItem(n); 956 | if (item) { 957 | return item->DrawBackground(dc, this, rect, imageRect, style, n) && item->Draw(dc, this, imageRect, style, n); 958 | } else { 959 | return false; 960 | } 961 | } 962 | 963 | /// Do (de)selection 964 | void wxThumbnailCtrl::DoSelection(int n, int flags) { 965 | bool isSelected = IsSelected(n); 966 | 967 | wxArrayInt stateChanged; 968 | 969 | bool multiSelect = (GetWindowStyle() & wxTH_MULTIPLE_SELECT) != 0; 970 | 971 | if (multiSelect && (flags & wxTHUMBNAIL_CTRL_DOWN) == wxTHUMBNAIL_CTRL_DOWN) { 972 | Select(n, !isSelected); 973 | stateChanged.Add(n); 974 | } else if (multiSelect && (flags & wxTHUMBNAIL_SHIFT_DOWN) == wxTHUMBNAIL_SHIFT_DOWN) { 975 | // We need to find the last item selected, 976 | // and select all in between. 
977 | 978 | int first = m_firstSelection; 979 | 980 | // Want to keep the 'first' selection 981 | // if we're extending the selection 982 | bool keepFirstSelection = false; 983 | wxArrayInt oldSelections = m_selections; 984 | 985 | m_selections.Clear(); // TODO: need to refresh those that become unselected. Store old selections, compare with new 986 | 987 | if (m_firstSelection != -1 && m_firstSelection < GetCount() && m_firstSelection != n) { 988 | int step = (n < m_firstSelection) ? -1 : 1; 989 | int i; 990 | for (i = m_firstSelection; i != n; i += step) { 991 | if (!IsSelected(i)) { 992 | m_selections.Add(i); 993 | stateChanged.Add(i); 994 | 995 | wxRect rect; 996 | GetItemRect(i, rect); 997 | RefreshRect(rect); 998 | } 999 | } 1000 | keepFirstSelection = true; 1001 | } 1002 | 1003 | // Refresh all the previously selected items that became unselected 1004 | size_t i; 1005 | for (i = 0; i < oldSelections.GetCount(); i++) { 1006 | if (!IsSelected(oldSelections[i])) { 1007 | wxRect rect; 1008 | GetItemRect(oldSelections[i], rect); 1009 | RefreshRect(rect); 1010 | } 1011 | } 1012 | 1013 | Select(n, true); 1014 | if (stateChanged.Index(n) == wxNOT_FOUND) 1015 | stateChanged.Add(n); 1016 | 1017 | if (keepFirstSelection) 1018 | m_firstSelection = first; 1019 | } else { 1020 | size_t i = 0; 1021 | for (i = 0; i < m_selections.GetCount(); i++) { 1022 | wxRect rect; 1023 | GetItemRect(m_selections[i], rect); 1024 | RefreshRect(rect); 1025 | 1026 | stateChanged.Add(i); 1027 | } 1028 | 1029 | m_selections.Clear(); 1030 | Select(n, true); 1031 | if (stateChanged.Index(n) == wxNOT_FOUND) 1032 | stateChanged.Add(n); 1033 | } 1034 | 1035 | // Now notify the app of any selection changes 1036 | size_t i = 0; 1037 | for (i = 0; i < stateChanged.GetCount(); i++) { 1038 | wxThumbnailEvent event( 1039 | m_selections.Index(stateChanged[i]) != wxNOT_FOUND ? 
wxEVT_COMMAND_THUMBNAIL_ITEM_SELECTED 1040 | : wxEVT_COMMAND_THUMBNAIL_ITEM_DESELECTED, 1041 | GetId()); 1042 | event.SetEventObject(this); 1043 | event.SetIndex(stateChanged[i]); 1044 | GetEventHandler()->ProcessEvent(event); 1045 | } 1046 | 1047 | if (stateChanged.GetCount() > 0) { 1048 | wxThumbnailEvent event(wxEVT_COMMAND_THUMBNAIL_SELECTION_CHANGED, GetId()); 1049 | event.SetEventObject(this); 1050 | GetEventHandler()->ProcessEvent(event); 1051 | } 1052 | } 1053 | 1054 | /// Find the item under the given point 1055 | bool wxThumbnailCtrl::HitTest(const wxPoint &pt, int &n) { 1056 | wxSize clientSize = GetClientSize(); 1057 | int startX, startY; 1058 | int ppuX, ppuY; 1059 | GetViewStart(&startX, &startY); 1060 | GetScrollPixelsPerUnit(&ppuX, &ppuY); 1061 | 1062 | int perRow = clientSize.x / (m_thumbnailOverallSize.x + m_spacing); 1063 | if (perRow < 1) 1064 | perRow = 1; 1065 | 1066 | int colPos = (int) (pt.x / (m_thumbnailOverallSize.x + m_spacing)); 1067 | int rowPos = (int) ((pt.y + startY * ppuY) / (m_thumbnailOverallSize.y + m_spacing)); 1068 | 1069 | int itemN = (rowPos * perRow + colPos); 1070 | if (itemN >= GetCount()) 1071 | return false; 1072 | 1073 | wxRect rect; 1074 | GetItemRect(itemN, rect); 1075 | if (rect.Contains(pt)) { 1076 | n = itemN; 1077 | return true; 1078 | } 1079 | 1080 | return false; 1081 | } 1082 | 1083 | void wxThumbnailCtrl::OnSelectAll(wxCommandEvent &) { 1084 | SelectAll(); 1085 | } 1086 | 1087 | void wxThumbnailCtrl::OnUpdateSelectAll(wxUpdateUIEvent &event) { 1088 | event.Enable(GetCount() > 0); 1089 | } 1090 | 1091 | /// Paint the background 1092 | void wxThumbnailCtrl::PaintBackground(wxDC &dc) { 1093 | wxColour backgroundColour = GetBackgroundColour(); 1094 | if (!backgroundColour.Ok()) 1095 | backgroundColour = wxSystemSettings::GetColour(wxSYS_COLOUR_3DFACE); 1096 | 1097 | // Clear the background 1098 | dc.SetBrush(wxBrush(backgroundColour)); 1099 | dc.SetPen(*wxTRANSPARENT_PEN); 1100 | wxRect windowRect(wxPoint(0, 0), 
GetClientSize()); 1101 | windowRect.x -= 2; 1102 | windowRect.y -= 2; 1103 | windowRect.width += 4; 1104 | windowRect.height += 4; 1105 | 1106 | // We need to shift the rectangle to take into account 1107 | // scrolling. Converting device to logical coordinates. 1108 | CalcUnscrolledPosition(windowRect.x, windowRect.y, &windowRect.x, &windowRect.y); 1109 | dc.DrawRectangle(windowRect); 1110 | } 1111 | 1112 | void wxThumbnailCtrl::SetMouseHoverItem(int n, int flags) { 1113 | if (m_hoverItem != n) { 1114 | wxThumbnailEvent cmdEvent( 1115 | wxEVT_COMMAND_THUMBNAIL_ITEM_HOVER_CHANGED, 1116 | GetId()); 1117 | cmdEvent.SetEventObject(this); 1118 | cmdEvent.SetFlags(flags); 1119 | cmdEvent.SetIndex(m_hoverItem); 1120 | m_hoverItem = n; 1121 | GetEventHandler()->ProcessEvent(cmdEvent); 1122 | } 1123 | } 1124 | 1125 | /*! 1126 | * wxThumbnailItem 1127 | */ 1128 | 1129 | IMPLEMENT_CLASS(wxThumbnailItem, wxObject) 1130 | 1131 | /// Refresh Item 1132 | bool wxThumbnailItem::Refresh(wxThumbnailCtrl *ctrl, int index) { 1133 | wxRect r; 1134 | ctrl->GetItemRect(index, r); 1135 | ctrl->RefreshRect(r); 1136 | return true; 1137 | } 1138 | 1139 | /// Draw the item background. It has the default implementation. 1140 | /// You have to ovveride this function inorder to provide your implementaion. 
1141 | bool 1142 | wxThumbnailItem::DrawBackground(wxDC &dc, wxThumbnailCtrl *ctrl, const wxRect &rect, const wxRect &imageRect, int style, 1143 | int) { 1144 | auto mediumGrey = ctrl->GetUnselectedThumbnailBackgroundColour(); 1145 | auto unfocussedDarkGrey = ctrl->GetSelectedThumbnailUnfocussedBackgroundColour(); 1146 | auto focussedDarkGrey = ctrl->GetSelectedThumbnailFocussedBackgroundColour(); 1147 | wxColour darkGrey; 1148 | if (style & wxTHUMBNAIL_FOCUSSED) 1149 | darkGrey = focussedDarkGrey; 1150 | else 1151 | darkGrey = unfocussedDarkGrey; 1152 | 1153 | if (style & wxTHUMBNAIL_SELECTED) { 1154 | wxBrush brush(darkGrey); 1155 | wxPen pen(darkGrey); 1156 | dc.SetBrush(brush); 1157 | dc.SetPen(pen); 1158 | } else { 1159 | wxBrush brush(mediumGrey); 1160 | wxPen pen(mediumGrey); 1161 | dc.SetBrush(brush); 1162 | dc.SetPen(pen); 1163 | } 1164 | 1165 | dc.DrawRectangle(rect); 1166 | 1167 | if (ctrl->IsOutlinesShown()) { 1168 | if (style & wxTHUMBNAIL_SELECTED) { 1169 | dc.SetPen(*wxWHITE_PEN); 1170 | dc.DrawLine(rect.GetRight(), rect.GetTop(), rect.GetRight(), rect.GetBottom()); 1171 | dc.DrawLine(rect.GetLeft(), rect.GetBottom(), rect.GetRight() + 1, rect.GetBottom()); 1172 | 1173 | dc.SetPen(*wxBLACK_PEN); 1174 | dc.DrawLine(rect.GetLeft(), rect.GetTop(), rect.GetRight(), rect.GetTop()); 1175 | dc.DrawLine(rect.GetLeft(), rect.GetTop(), rect.GetLeft(), rect.GetBottom()); 1176 | } else { 1177 | dc.SetPen(*wxBLACK_PEN); 1178 | dc.DrawLine(rect.GetRight(), rect.GetTop(), rect.GetRight(), rect.GetBottom()); 1179 | dc.DrawLine(rect.GetLeft(), rect.GetBottom(), rect.GetRight() + 1, rect.GetBottom()); 1180 | 1181 | dc.SetPen(*wxWHITE_PEN); 1182 | dc.DrawLine(rect.GetLeft(), rect.GetTop(), rect.GetRight(), rect.GetTop()); 1183 | dc.DrawLine(rect.GetLeft(), rect.GetTop(), rect.GetLeft(), rect.GetBottom()); 1184 | } 1185 | } 1186 | 1187 | if (!m_label.IsEmpty() && (ctrl->GetWindowStyle() & wxTH_TEXT_LABEL)) { 1188 | dc.SetFont(ctrl->GetFont()); 1189 | if (style & 
wxTHUMBNAIL_SELECTED) 1190 | dc.SetTextForeground(*wxWHITE); 1191 | else 1192 | dc.SetTextForeground(*wxBLACK); 1193 | dc.SetBackgroundMode(wxTRANSPARENT); 1194 | 1195 | int margin = ctrl->GetThumbnailMargin(); 1196 | 1197 | wxRect fRect; 1198 | fRect.x = rect.x + margin; 1199 | fRect.y = rect.y + margin + imageRect.height + margin; 1200 | fRect.width = rect.width - 2 * margin; 1201 | fRect.height = rect.height - margin - imageRect.height - margin - margin; 1202 | 1203 | wxCoord textW, textH; 1204 | dc.GetTextExtent(m_label, &textW, &textH); 1205 | 1206 | dc.SetClippingRegion(fRect); 1207 | int x = fRect.x + wxMax(0, (fRect.width - textW) / 2); 1208 | int y = fRect.y; 1209 | dc.DrawText(m_label, x, y); 1210 | dc.DestroyClippingRegion(); 1211 | } 1212 | 1213 | // If the item itself is the focus, draw a dotted 1214 | // rectangle around it 1215 | if (style & wxTHUMBNAIL_IS_FOCUS) { 1216 | wxPen dottedPen(ctrl->GetFocusRectColour(), 1, wxDOT); 1217 | dc.SetPen(dottedPen); 1218 | dc.SetBrush(*wxTRANSPARENT_BRUSH); 1219 | wxRect focusRect = imageRect; 1220 | focusRect.x--; 1221 | focusRect.y--; 1222 | focusRect.width += 2; 1223 | focusRect.height += 2; 1224 | dc.DrawRectangle(focusRect); 1225 | } 1226 | 1227 | return true; 1228 | } 1229 | 1230 | /// Draw the item 1231 | bool wxThumbnailItem::Draw(wxDC &dc, wxThumbnailCtrl *, const wxRect &rect, int, int) { 1232 | if (m_bitmap.Ok()) { 1233 | dc.DrawBitmap(m_bitmap, rect.GetX(), rect.GetY()); 1234 | } 1235 | return true; 1236 | } 1237 | -------------------------------------------------------------------------------- /GUI/thumbnailctrl.h: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////////////////////////////////////////// 2 | // Name: thumbnailctrl.h 3 | // Purpose: Displays a scrolling window of thumbnails 4 | // Author: Julian Smart 5 | // Modified by: Anil Kumar 6 | // Created: 03/08/04 17:22:46 7 | // RCS-ID: 8 | // Copyright: (c) Julian 
Smart 9 | // Licence: wxWidgets Licence 10 | ///////////////////////////////////////////////////////////////////////////// 11 | 12 | #ifndef _WX_THUMBNAILCTRL_H_ 13 | #define _WX_THUMBNAILCTRL_H_ 14 | 15 | #if defined(__GNUG__) && !defined(__APPLE__) 16 | #pragma interface "thumbnailctrl.cpp" 17 | #endif 18 | 19 | #include "wx/dynarray.h" 20 | 21 | /*! Styles 22 | */ 23 | 24 | #define wxTH_MULTIPLE_SELECT 0x0010 25 | #define wxTH_SINGLE_SELECT 0x0000 26 | #define wxTH_TEXT_LABEL 0x0020 27 | #define wxTH_IMAGE_LABEL 0x0040 28 | 29 | /*! Flags 30 | */ 31 | 32 | #define wxTHUMBNAIL_SHIFT_DOWN 0x01 33 | #define wxTHUMBNAIL_CTRL_DOWN 0x02 34 | #define wxTHUMBNAIL_ALT_DOWN 0x04 35 | 36 | /*! Defaults 37 | */ 38 | 39 | #define wxTHUMBNAIL_DEFAULT_OVERALL_SIZE wxSize(-1, -1) 40 | #define wxTHUMBNAIL_DEFAULT_IMAGE_SIZE wxSize(80, 80) 41 | #define wxTHUMBNAIL_DEFAULT_SPACING 6 42 | #define wxTHUMBNAIL_DEFAULT_MARGIN 3 43 | #define wxTHUMBNAIL_DEFAULT_UNFOCUSSED_BACKGROUND wxColour(175, 175, 175) 44 | #define wxTHUMBNAIL_DEFAULT_FOCUSSED_BACKGROUND wxColour(140, 140, 140) 45 | #define wxTHUMBNAIL_DEFAULT_UNSELECTED_BACKGROUND wxSystemSettings::GetColour(wxSYS_COLOUR_3DFACE) 46 | #define wxTHUMBNAIL_DEFAULT_TYPE_COLOUR wxColour(0, 0, 200) 47 | #define wxTHUMBNAIL_DEFAULT_FOCUS_RECT_COLOUR wxColour(100, 80, 80) 48 | 49 | class wxThumbnailCtrl; 50 | 51 | // Drawing styles/states 52 | #define wxTHUMBNAIL_SELECTED 0x01 53 | // The control is focussed 54 | #define wxTHUMBNAIL_FOCUSSED 0x04 55 | // The item itself has the focus 56 | #define wxTHUMBNAIL_IS_FOCUS 0x08 57 | #define wxTHUMBNAIL_IS_HOVER 0x10 58 | 59 | class wxThumbnailItem : public wxObject { 60 | DECLARE_CLASS(wxThumbnailItem) 61 | 62 | public: 63 | explicit wxThumbnailItem(const wxString &label = wxEmptyString) : m_label(label) {} 64 | 65 | /// Label 66 | void SetLabel(const wxString &filename) { 67 | m_label = filename; 68 | m_bitmap = wxNullBitmap; 69 | } 70 | 71 | const wxString &GetLabel() const { return m_label; 
} 72 | 73 | /// Refresh the item 74 | virtual bool Refresh(wxThumbnailCtrl *ctrl, int index); 75 | 76 | /// Draw the background 77 | virtual bool 78 | DrawBackground(wxDC &dc, wxThumbnailCtrl *ctrl, const wxRect &rect, const wxRect &imageRect, int style, int index); 79 | 80 | /// Draw the item 81 | virtual bool Draw(wxDC &dc, wxThumbnailCtrl *ctrl, const wxRect &rect, int style, int index); 82 | 83 | wxBitmap const &GetBitmap() { return m_bitmap; } 84 | void SetBitmap(const wxBitmap &x) { m_bitmap = x ; } 85 | 86 | private: 87 | wxBitmap m_bitmap; 88 | wxString m_label; 89 | }; 90 | 91 | WX_DECLARE_OBJARRAY(wxThumbnailItem, wxThumbnailItemArray); 92 | 93 | class wxThumbnailCtrl : public wxScrolledCanvas { 94 | DECLARE_CLASS(wxThumbnailCtrl) 95 | 96 | DECLARE_EVENT_TABLE() 97 | 98 | public: 99 | wxThumbnailCtrl(); 100 | 101 | explicit wxThumbnailCtrl(wxWindow *parent, wxWindowID id = -1, const wxPoint &pos = wxDefaultPosition, 102 | const wxSize &size = wxDefaultSize, 103 | long style = wxTH_TEXT_LABEL | wxTH_IMAGE_LABEL | wxBORDER_THEME); 104 | 105 | /// Creation 106 | bool Create(wxWindow *parent, wxWindowID id = -1, const wxPoint &pos = wxDefaultPosition, 107 | const wxSize &size = wxDefaultSize, 108 | long style = wxTH_TEXT_LABEL | wxTH_IMAGE_LABEL | wxBORDER_THEME); 109 | 110 | /// Scrolls the item into view if necessary 111 | void EnsureVisible(int n); 112 | 113 | /// Append a single item 114 | virtual int Append(wxThumbnailItem *item); 115 | 116 | /// Insert a single item 117 | virtual int Insert(wxThumbnailItem *item, int pos = 0); 118 | 119 | /// Clear all items 120 | virtual void Clear(); 121 | 122 | /// Delete this item 123 | virtual void Delete(int n); 124 | 125 | /// Get the number of items in the control 126 | virtual int GetCount() const { return m_items.GetCount(); } 127 | 128 | /// Is the control empty? 
129 | bool IsEmpty() const { return GetCount() == 0; } 130 | 131 | /// Get the nth item 132 | wxThumbnailItem *GetItem(int n); 133 | 134 | /// Get the overall rect of the given item 135 | /// If transform is true, rect is relative to the scroll viewport 136 | /// (i.e. may be negative) 137 | bool GetItemRect(int item, wxRect &rect, bool transform = true); 138 | 139 | /// Get the image rect of the given item 140 | bool GetItemRectImage(int item, wxRect &rect, bool transform = true); 141 | 142 | /// Return the row and column given the client 143 | /// size and a left-to-right, top-to-bottom layout 144 | /// assumption 145 | bool GetRowCol(int item, const wxSize &clientSize, int &row, int &col); 146 | 147 | /// Get the focus item, or -1 if there is none 148 | int GetFocusItem() const { return m_focusItem; } 149 | 150 | /// Set the focus item 151 | void SetFocusItem(int item); 152 | 153 | /// Select or deselect an item 154 | void Select(int n, bool select = true); 155 | 156 | /// Select or deselect a range 157 | void SelectRange(int from, int to, bool select = true); 158 | 159 | /// Select all 160 | void SelectAll(); 161 | 162 | /// Select none 163 | void SelectNone(); 164 | 165 | /// Get the index of the single selection, if not multi-select. 166 | /// Returns -1 if there is no selection. 167 | int GetSelection() const; 168 | 169 | /// Get indexes of all selections, if multi-select 170 | const wxArrayInt &GetSelections() const { return m_selections; } 171 | 172 | /// Returns true if the item is selected 173 | bool IsSelected(int n) const; 174 | 175 | /// Clears all selections 176 | void ClearSelections(); 177 | 178 | /// Get mouse hover item 179 | int GetMouseHoverItem() const { return m_hoverItem; } 180 | 181 | /// Find the item under the given point 182 | bool HitTest(const wxPoint &pt, int &n); 183 | 184 | /// The overall size of the thumbnail, including decorations. 
185 | /// DON'T USE THIS from the application, since it will 186 | /// normally be calculated by SetThumbnailImageSize. 187 | void SetThumbnailOverallSize(const wxSize &sz) { m_thumbnailOverallSize = sz; } 188 | 189 | const wxSize &GetThumbnailOverallSize() const { return m_thumbnailOverallSize; } 190 | 191 | /// The size of the image part 192 | void SetThumbnailImageSize(const wxSize &sz); 193 | 194 | const wxSize &GetThumbnailImageSize() const { return m_thumbnailImageSize; } 195 | 196 | /// The inter-item spacing 197 | void SetSpacing(int spacing) { m_spacing = spacing; } 198 | 199 | int GetSpacing() const { return m_spacing; } 200 | 201 | /// The margin between elements within the thumbnail 202 | void SetThumbnailMargin(int margin) { m_thumbnailMargin = margin; } 203 | 204 | int GetThumbnailMargin() const { return m_thumbnailMargin; } 205 | 206 | /// The height required for text in the thumbnail 207 | void SetThumbnailTextHeight(int h) { m_thumbnailTextHeight = h; } 208 | 209 | int GetThumbnailTextHeight() const { return m_thumbnailTextHeight; } 210 | 211 | /// The focussed and unfocussed background colour for a 212 | /// selected thumbnail 213 | void SetSelectedThumbnailBackgroundColour(const wxColour &focussedColour, const wxColour &unfocussedColour) { 214 | m_focussedThumbnailBackgroundColour = focussedColour; 215 | m_unfocussedThumbnailBackgroundColour = unfocussedColour; 216 | } 217 | 218 | const wxColour &GetSelectedThumbnailFocussedBackgroundColour() const { return m_focussedThumbnailBackgroundColour; } 219 | 220 | const wxColour & 221 | GetSelectedThumbnailUnfocussedBackgroundColour() const { return m_unfocussedThumbnailBackgroundColour; } 222 | 223 | /// The unselected background colour for a thumbnail 224 | void 225 | SetUnselectedThumbnailBackgroundColour(const wxColour &colour) { m_unselectedThumbnailBackgroundColour = colour; } 226 | 227 | const wxColour &GetUnselectedThumbnailBackgroundColour() const { return m_unselectedThumbnailBackgroundColour; 
} 228 | 229 | /// The colour for the type text (top left of thumbnail) 230 | void SetTypeColour(const wxColour &colour) { m_typeColour = colour; } 231 | 232 | const wxColour &GetTypeColour() const { return m_typeColour; } 233 | 234 | /// The focus rectangle pen colour 235 | void SetFocusRectColour(const wxColour &colour) { m_focusRectColour = colour; } 236 | 237 | const wxColour &GetFocusRectColour() const { return m_focusRectColour; } 238 | 239 | /// The thumbnail outlines show or not 240 | void ShowOutlines(bool flag = true) { m_showOutlines = flag; } 241 | 242 | bool IsOutlinesShown() const { return m_showOutlines; } 243 | 244 | /// Painting 245 | void OnDraw(wxDC &dc) override; 246 | 247 | protected: 248 | /// Command handlers 249 | void OnSelectAll(wxCommandEvent &event); 250 | 251 | void OnUpdateSelectAll(wxUpdateUIEvent &event); 252 | 253 | /// Mouse-events 254 | void OnMouse(wxMouseEvent &event); 255 | 256 | /// Left-click-down 257 | void OnLeftClickDown(wxMouseEvent &event); 258 | 259 | /// Left-click-up 260 | void OnLeftClickUp(wxMouseEvent &event); 261 | 262 | /// Left-double-click 263 | void OnLeftDClick(wxMouseEvent &event); 264 | 265 | /// Mouse-motion 266 | void OnMouseMotion(wxMouseEvent &event); 267 | 268 | /// Mouse-leave 269 | void OnMouseLeave(wxMouseEvent &event); 270 | 271 | /// Right-click-down 272 | void OnRightClickDown(wxMouseEvent &event); 273 | 274 | /// Right-click-up 275 | void OnRightClickUp(wxMouseEvent &event); 276 | 277 | /// Key press 278 | void OnChar(wxKeyEvent &event); 279 | 280 | /// Sizing 281 | void OnSize(wxSizeEvent &event); 282 | 283 | /// Setting/losing focus 284 | void OnSetFocus(wxFocusEvent &event); 285 | 286 | void OnKillFocus(wxFocusEvent &event); 287 | 288 | // Implementation 289 | 290 | /// Draws the item 291 | bool DrawItem(int n, wxDC &dc, const wxRect &rect, const wxRect &imageRect, int style); 292 | 293 | void SetMouseHoverItem(int n, int flags = 0); 294 | 295 | /// Set up scrollbars, e.g. 
after a resize 296 | void SetupScrollbars(); 297 | 298 | /// Calculate the outer thumbnail size based 299 | /// on font used for text and inner size 300 | void CalculateOverallThumbnailSize(); 301 | 302 | /// Do (de)selection 303 | void DoSelection(int n, int flags); 304 | 305 | /// Keyboard navigation 306 | virtual bool Navigate(int keyCode, int flags); 307 | 308 | /// Scroll to see the image 309 | void ScrollIntoView(int n, int keyCode); 310 | 311 | /// Paint the background 312 | void PaintBackground(wxDC &dc); 313 | 314 | private: 315 | /// The items 316 | wxThumbnailItemArray m_items; 317 | 318 | /// The selections 319 | wxArrayInt m_selections; 320 | 321 | /// Outer size of the thumbnail item 322 | wxSize m_thumbnailOverallSize; 323 | 324 | /// Image size of the thumbnail item 325 | wxSize m_thumbnailImageSize; 326 | 327 | /// The inter-item spacing 328 | int m_spacing; 329 | 330 | /// The margin between the image/text and the edge of the thumbnail 331 | int m_thumbnailMargin; 332 | 333 | /// The height of thumbnail text in the current font 334 | int m_thumbnailTextHeight; 335 | 336 | /// First selection in a range 337 | int m_firstSelection; 338 | 339 | /// Last selection 340 | int m_lastSelection; 341 | 342 | /// Focus item 343 | int m_focusItem; 344 | 345 | /// Outlines flag 346 | bool m_showOutlines; 347 | 348 | /// Mouse hover item 349 | int m_hoverItem = wxNOT_FOUND; 350 | 351 | /// Focussed/unfocussed selected thumbnail background colours 352 | wxColour m_focussedThumbnailBackgroundColour; 353 | wxColour m_unfocussedThumbnailBackgroundColour; 354 | wxColour m_unselectedThumbnailBackgroundColour; 355 | wxColour m_focusRectColour; 356 | 357 | /// Type text colour 358 | wxColour m_typeColour; 359 | 360 | /// Drag start position 361 | wxPoint m_dragStartPosition = wxDefaultPosition; 362 | }; 363 | 364 | /*! 
365 | * wxThumbnailEvent - the event class for wxThumbnailCtrl notifications 366 | */ 367 | 368 | class wxThumbnailEvent : public wxNotifyEvent { 369 | public: 370 | wxThumbnailEvent(wxEventType commandType = wxEVT_NULL, int winid = 0) 371 | : wxNotifyEvent(commandType, winid), 372 | m_itemIndex(-1), m_flags(0) {} 373 | 374 | wxThumbnailEvent(const wxThumbnailEvent &event) 375 | : wxNotifyEvent(event), 376 | m_itemIndex(event.m_itemIndex), m_flags(event.m_flags) {} 377 | 378 | int GetIndex() const { return m_itemIndex; } 379 | 380 | void SetIndex(int n) { m_itemIndex = n; } 381 | 382 | const wxArrayInt &GetItemsIndex() const { return m_itemsIndex; } 383 | 384 | void SetItemsIndex(const wxArrayInt &itemsIndex) { m_itemsIndex = itemsIndex; } 385 | 386 | int GetFlags() const { return m_flags; } 387 | 388 | void SetFlags(int flags) { m_flags = flags; } 389 | 390 | const wxPoint &GetPosition() const { return m_position; } 391 | 392 | void SetPosition(const wxPoint &position) { m_position = position; } 393 | 394 | virtual wxEvent *Clone() const { return new wxThumbnailEvent(*this); } 395 | 396 | protected: 397 | int m_itemIndex; 398 | int m_flags; 399 | wxPoint m_position; 400 | wxArrayInt m_itemsIndex; 401 | 402 | private: 403 | DECLARE_DYNAMIC_CLASS_NO_ASSIGN(wxThumbnailEvent) 404 | }; 405 | 406 | /*! 
407 | * wxThumbnailCtrl event macros 408 | */ 409 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_SELECTION_CHANGED, wxThumbnailEvent); 410 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_ITEM_SELECTED, wxThumbnailEvent); 411 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_ITEM_DESELECTED, wxThumbnailEvent); 412 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_LEFT_CLICK, wxThumbnailEvent); 413 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_RIGHT_CLICK, wxThumbnailEvent); 414 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_VIEW_RIGHT_CLICK, wxThumbnailEvent); 415 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_LEFT_DCLICK, wxThumbnailEvent); 416 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_RETURN, wxThumbnailEvent); 417 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_DRAG_START, wxThumbnailEvent); 418 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_SORTED, wxThumbnailEvent); 419 | wxDECLARE_EVENT(wxEVT_COMMAND_THUMBNAIL_ITEM_HOVER_CHANGED, wxThumbnailEvent); 420 | 421 | #endif 422 | // _WX_THUMBNAILCTRL_H_ 423 | -------------------------------------------------------------------------------- /GUI/util.h: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_H 2 | #define UTIL_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | #ifndef WX_PRECOMP 15 | 16 | #include 17 | 18 | #endif 19 | 20 | #include "TargetRepo.h" 21 | 22 | namespace { 23 | auto load_classes(const std::string &path) { 24 | std::vector out; 25 | std::ifstream fp(path); 26 | while (fp) { 27 | std::string cls; 28 | std::getline(fp, cls); 29 | if (!cls.empty()) { 30 | out.push_back(cls); 31 | } 32 | } 33 | fp.close(); 34 | return out; 35 | } 36 | 37 | cv::Scalar color_map(int64_t n) { 38 | auto bit_get = [](int64_t x, int64_t i) { return x & (1 << i); }; 39 | 40 | int64_t r = 0, g = 0, b = 0; 41 | int64_t i = n; 42 | for (int64_t j = 7; j >= 0; --j) { 43 | r |= bit_get(i, 0) << j; 44 | g |= bit_get(i, 1) << j; 45 | b |= bit_get(i, 2) << j; 
46 | i >>= 3; 47 | } 48 | return cv::Scalar(b, g, r); 49 | } 50 | 51 | void draw_text(cv::Mat &img, const std::string &str, 52 | const cv::Scalar &color, cv::Point pos, bool reverse = false) { 53 | auto t_size = cv::getTextSize(str, cv::FONT_HERSHEY_PLAIN, 1, 1, nullptr); 54 | cv::Point bottom_left, upper_right; 55 | if (reverse) { 56 | upper_right = pos; 57 | bottom_left = cv::Point(upper_right.x - t_size.width, upper_right.y + t_size.height); 58 | } else { 59 | bottom_left = pos; 60 | upper_right = cv::Point(bottom_left.x + t_size.width, bottom_left.y - t_size.height); 61 | } 62 | 63 | cv::rectangle(img, bottom_left, upper_right, color, -1); 64 | cv::putText(img, str, bottom_left, cv::FONT_HERSHEY_PLAIN, 1, cv::Scalar(255, 255, 255) - color); 65 | } 66 | 67 | cv::Rect2f unnormalize_rect(cv::Rect2f rect, float w, float h) { 68 | rect.x *= w; 69 | rect.y *= h; 70 | rect.width *= w; 71 | rect.height *= h; 72 | return rect; 73 | } 74 | 75 | void draw_bbox(cv::Mat &img, cv::Rect2f bbox, 76 | const std::string &label = "", const cv::Scalar &color = {0, 0, 0}) { 77 | bbox = unnormalize_rect(bbox, img.cols, img.rows); 78 | cv::rectangle(img, bbox, color); 79 | if (!label.empty()) { 80 | draw_text(img, label, color, bbox.tl()); 81 | } 82 | } 83 | 84 | void draw_trajectories(cv::Mat &img, const std::map &traj, int curent, 85 | const cv::Scalar &color = {0, 0, 0}) { 86 | auto it = traj.begin(); 87 | for (; it != traj.end() && it->first < curent - 20; ++it) {} 88 | if (it == traj.end()) return; 89 | 90 | auto cur = it->second; 91 | auto pt1 = cur.br(); 92 | pt1.x -= cur.width / 2; 93 | pt1.x *= img.cols; 94 | pt1.y *= img.rows; 95 | 96 | for (; it != traj.end() && it->first <= curent; ++it) { 97 | cur = it->second; 98 | auto pt2 = cur.br(); 99 | pt2.x -= cur.width / 2; 100 | pt2.x *= img.cols; 101 | pt2.y *= img.rows; 102 | cv::line(img, pt1, pt2, color); 103 | pt1 = pt2; 104 | } 105 | } 106 | 107 | wxImage cvMat2wxImage(cv::Mat mat) { 108 | return wxImage(mat.cols, 
mat.rows, mat.data, true); 109 | } 110 | } 111 | #endif //UTIL_H 112 | -------------------------------------------------------------------------------- /GUI/win.rc: -------------------------------------------------------------------------------- 1 | #include "wx/msw/wx.rc" 2 | -------------------------------------------------------------------------------- /GUI/wxplayer.cpp: -------------------------------------------------------------------------------- 1 | #include "wxplayer.h" 2 | #include "util.h" 3 | 4 | wxPlayer::wxPlayer(wxWindow *parent, wxWindowID id, 5 | const wxString &file, 6 | const std::function &post) 7 | : wxWindow(parent, id), post(post) { 8 | if (!capture.open(file.ToStdString())) { 9 | throw std::runtime_error("Cannot open video!"); 10 | } 11 | 12 | auto video_size = wxSize(capture.get(cv::CAP_PROP_FRAME_WIDTH), capture.get(cv::CAP_PROP_FRAME_HEIGHT)); 13 | mat = cv::Mat::zeros(video_size.GetWidth(), video_size.GetHeight(), CV_8UC3); 14 | bitmap = new wxGenericStaticBitmap(this, wxID_ANY, wxNullBitmap); 15 | bitmap->Bind(wxEVT_SIZE, [this](wxSizeEvent &) { RescaleToBitmap(); }); 16 | bitmap->Bind(wxEVT_ERASE_BACKGROUND, [](wxEraseEvent& event) {}); 17 | bitmap->SetMinClientSize(video_size / 3); 18 | 19 | timer = new wxTimer(this, ID_Timer); 20 | Bind(wxEVT_TIMER, 21 | [this](wxTimerEvent &) { LoadNext(); }, 22 | ID_Timer); 23 | 24 | auto play = new wxButton(this, ID_Start_Pause, "Start/Pause", wxDefaultPosition, wxDefaultSize, wxBU_EXACTFIT); 25 | auto stop = new wxButton(this, ID_Reload, "Reload", wxDefaultPosition, wxDefaultSize, wxBU_EXACTFIT); 26 | progress = new wxSlider(this, ID_Progress, 0, 0, capture.get(cv::CAP_PROP_FRAME_COUNT) - 1); 27 | auto bar = new wxBoxSizer(wxHORIZONTAL); 28 | bar->Add(play, 0, wxEXPAND | wxALL); 29 | bar->Add(stop, 0, wxEXPAND | wxALL); 30 | bar->Add(progress, 1, wxEXPAND | wxALL); 31 | 32 | Bind(wxEVT_BUTTON, [this](wxCommandEvent &) { 33 | if (timer->IsRunning()) { 34 | timer->Stop(); 35 | } else { 36 | 
timer->Start(1000 / GetFPS()); 37 | } 38 | }, ID_Start_Pause); 39 | Bind(wxEVT_BUTTON, [this](wxCommandEvent &) { 40 | Seek(0); 41 | }, ID_Reload); 42 | Bind(wxEVT_SLIDER, [this](wxCommandEvent &) { 43 | if (GetFrame() != progress->GetValue()) { 44 | Seek(progress->GetValue()); 45 | } 46 | }, ID_Progress); 47 | 48 | auto sizer = new wxBoxSizer(wxVERTICAL); 49 | sizer->Add(bitmap, 1, wxEXPAND | wxALL); 50 | sizer->Add(bar, 0, wxEXPAND | wxALL); 51 | 52 | SetSizerAndFit(sizer); 53 | 54 | LoadNext(); 55 | } 56 | 57 | void wxPlayer::LoadNext() { 58 | if (capture.read(mat)) { 59 | RescaleToBitmap(); 60 | progress->SetValue(GetFrame()); 61 | } 62 | } 63 | 64 | void wxPlayer::Seek(int frame) { 65 | capture.set(cv::CAP_PROP_POS_FRAMES, frame); 66 | LoadNext(); 67 | } 68 | 69 | void wxPlayer::RescaleToBitmap() { 70 | cv::Mat resized; 71 | auto size = bitmap->GetClientSize(); 72 | cv::resize(mat, resized, {size.GetWidth(), size.GetHeight()}); 73 | post(resized, GetFrame()); 74 | cv::cvtColor(resized, resized, cv::COLOR_BGR2RGB); 75 | bitmap->SetBitmap(wxBitmap(cvMat2wxImage(resized))); 76 | } 77 | -------------------------------------------------------------------------------- /GUI/wxplayer.h: -------------------------------------------------------------------------------- 1 | #ifndef WXPLAYER_H 2 | #define WXPLAYER_H 3 | 4 | #include 5 | 6 | #ifndef WX_PRECOMP 7 | 8 | #include 9 | 10 | #endif 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | class wxPlayer : public wxWindow { 19 | public: 20 | explicit wxPlayer(wxWindow *parent, wxWindowID id, 21 | const wxString &file, 22 | const std::function &post); 23 | 24 | int GetFrame() { return capture.get(cv::CAP_PROP_POS_FRAMES) - 1; } 25 | 26 | int GetFPS() { return capture.get(cv::CAP_PROP_FPS); } 27 | 28 | bool isPlaying() { return timer->IsRunning(); } 29 | 30 | void Seek(int frame); 31 | 32 | void Refresh() { 33 | Seek(GetFrame()); 34 | } 35 | 36 | private: 37 | void LoadNext(); 38 | 39 | 
wxGenericStaticBitmap *bitmap = nullptr; 40 | wxSlider *progress = nullptr; 41 | wxTimer *timer = nullptr; 42 | 43 | cv::VideoCapture capture; 44 | 45 | std::function post; 46 | cv::Mat mat; 47 | 48 | void RescaleToBitmap(); 49 | 50 | enum { 51 | ID_Timer = 1, 52 | ID_Start_Pause, 53 | ID_Reload, 54 | ID_Progress 55 | }; 56 | }; 57 | 58 | #endif //WXPLAYER_H 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Xu Wei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | *It is for my undergrad thesis in Tsinghua University.* 3 | 4 | There are four modules in the project: 5 | 6 | - Detection: YOLOv3 7 | - Tracking: SORT and DeepSORT 8 | - Processing: Run detection and tracking, then display and save the results (a compressed video, a few snapshots for each target) 9 | - GUI: Display the results 10 | 11 | # YOLOv3 12 | A Libtorch implementation of the YOLO v3 object detection algorithm, written with modern C++. 13 | 14 | The code is based on the [walktree](https://github.com/walktree/libtorch-yolov3). 15 | 16 | The config file in .\models can be found at [Darknet](https://github.com/pjreddie/darknet/tree/master/cfg). 17 | 18 | # SORT 19 | I also merged [SORT](https://github.com/mcximing/sort-cpp) to do tracking. 20 | 21 | A similar software in Python is [here](https://github.com/weixu000/pytorch-yolov3), which also rewrite form [the most starred version](https://github.com/ayooshkathuria/pytorch-yolo-v3) and [SORT](https://github.com/abewley/sort) 22 | 23 | ## DeepSORT 24 | Recently I reimplement [DeepSORT](https://github.com/nwojke/deep_sort) which employs another CNN for re-id. 25 | It seems it gives better result but also slows the program a bit. 26 | Also, a PyTorch version is available at [ZQPei](https://github.com/ZQPei/deep_sort_pytorch), thanks! 27 | 28 | # Performance 29 | Currently on a GTX 1060 6G it consumes about 1G RAM and have 37 FPS. 30 | 31 | The video I test is [TownCentreXVID.avi](http://www.robots.ox.ac.uk/ActiveVision/Research/Projects/2009bbenfold_headpose/Datasets/TownCentreXVID.avi). 32 | 33 | # GUI 34 | With [wxWidgets](https://www.wxwidgets.org/), I developed the GUI module for visualization of results. 35 | 36 | Previously I used [Dear ImGui](https://github.com/ocornut/imgui). 
37 | However, I do not think it suits my purpose. 38 | 39 | # Pre-trained network 40 | This project uses pre-trained network weights from others 41 | - [YOLOv3](https://pjreddie.com/media/files/yolov3.weights) 42 | - [YOLOv3-tiny](https://pjreddie.com/media/files/yolov3-tiny.weights) 43 | - [DeepSORT](https://drive.google.com/drive/folders/1xhG0kRH1EX5B9_Iz8gQJb7UNnn_riXi6) 44 | 45 | # How to build 46 | This project requires [LibTorch](https://pytorch.org/), [OpenCV](https://opencv.org/), [wxWidgets](https://www.wxwidgets.org/) and [CMake](https://cmake.org/) to build. 47 | 48 | LibTorch can be easily integrated with CMake, but there are a lot of strange things... 49 | 50 | On Ubuntu 16.04, I use `apt install` to install the others. Everything is fine. 51 | On Windows 10 + Visual Studio 2017, I use the latest stable version of the others from their official websites. 52 | 53 | # Snapshots 54 | Here are some intermediate output from detection and tracking module: 55 | ![Detection](https://github.com/weixu000/libtorch-yolov3-deepsort/blob/master/snapshots/detection.png) 56 | ![Tracking](https://github.com/weixu000/libtorch-yolov3-deepsort/blob/master/snapshots/tracking.png) 57 | 58 | Here is the snapshot of processing module: 59 | ![Processing](https://github.com/weixu000/libtorch-yolov3-deepsort/blob/master/snapshots/UI-online.png) 60 | 61 | Here is the snapshot of GUI module: 62 | ![GUI](https://github.com/weixu000/libtorch-yolov3-deepsort/blob/master/snapshots/UI-offline.png) 63 | -------------------------------------------------------------------------------- /config.h.in: -------------------------------------------------------------------------------- 1 | #ifndef CONFIG_H 2 | #define CONFIG_H 3 | 4 | #include 5 | 6 | // output directory structure 7 | const std::string OUTPUT_DIR = "@OUTPUT_DIR@"; 8 | const std::string TARGETS_DIR_NAME = "@TARGETS_DIR_NAME@"; 9 | const std::string TRAJ_TXT_NAME = "@TRAJ_TXT_NAME@"; 10 | const std::string SNAPSHOTS_DIR_NAME = 
"@SNAPSHOTS_DIR_NAME@"; 11 | const std::string VIDEO_NAME = "@VIDEO_NAME@"; 12 | 13 | #endif //CONFIG_H 14 | -------------------------------------------------------------------------------- /detection/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(OpenCV REQUIRED) 2 | find_package(Torch REQUIRED) 3 | 4 | aux_source_directory(src DETECTION_SRCS) 5 | add_library(detection SHARED ${DETECTION_SRCS}) 6 | 7 | include(GenerateExportHeader) 8 | GENERATE_EXPORT_HEADER(detection) 9 | 10 | target_link_libraries(detection PUBLIC ${OpenCV_LIBS} PRIVATE "${TORCH_LIBRARIES}") 11 | target_include_directories(detection 12 | PUBLIC include ${CMAKE_CURRENT_BINARY_DIR} 13 | PRIVATE src) -------------------------------------------------------------------------------- /detection/include/Detector.h: -------------------------------------------------------------------------------- 1 | #ifndef DETECTOR_H 2 | #define DETECTOR_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "detection_export.h" 9 | 10 | enum class YOLOType { 11 | YOLOv3, 12 | YOLOv3_TINY 13 | }; 14 | 15 | class DETECTION_EXPORT Detector { 16 | public: 17 | explicit Detector(const std::array &_inp_dim, YOLOType type = YOLOType::YOLOv3); 18 | 19 | ~Detector(); 20 | 21 | std::vector detect(cv::Mat image); 22 | 23 | private: 24 | class Darknet; 25 | 26 | std::unique_ptr net; 27 | 28 | std::array inp_dim; 29 | static const float NMS_threshold; 30 | static const float confidence_threshold; 31 | }; 32 | 33 | #endif //DETECTOR_H 34 | -------------------------------------------------------------------------------- /detection/src/Darknet.cpp: -------------------------------------------------------------------------------- 1 | #include "Darknet.h" 2 | #include "darknet_parsing.h" 3 | 4 | using namespace std; 5 | 6 | struct EmptyLayerImpl : torch::nn::Module { 7 | EmptyLayerImpl() = default; 8 | 9 | torch::Tensor forward(torch::Tensor x) { 10 | return x; 11 
| } 12 | }; 13 | 14 | TORCH_MODULE(EmptyLayer); 15 | 16 | struct UpsampleLayerImpl : torch::nn::Module { 17 | int _stride; 18 | 19 | explicit UpsampleLayerImpl(int stride) : _stride(stride) {} 20 | 21 | torch::Tensor forward(torch::Tensor x) { 22 | auto sizes = x.sizes(); 23 | auto w = sizes[2] * _stride; 24 | auto h = sizes[3] * _stride; 25 | 26 | return torch::upsample_nearest2d(x, {w, h}); 27 | } 28 | }; 29 | 30 | TORCH_MODULE(UpsampleLayer); 31 | 32 | struct MaxPoolLayer2DImpl : torch::nn::Module { 33 | int _kernel_size; 34 | int _stride; 35 | 36 | MaxPoolLayer2DImpl(int kernel_size, int stride) : _kernel_size(kernel_size), _stride(stride) {} 37 | 38 | torch::Tensor forward(torch::Tensor x) { 39 | if (_stride != 1) { 40 | x = torch::max_pool2d(x, {_kernel_size, _kernel_size}, {_stride, _stride}); 41 | } else { 42 | auto pad = _kernel_size - 1; 43 | torch::Tensor padded_x = torch::replication_pad2d(x, {0, pad, 0, pad}); 44 | x = torch::max_pool2d(padded_x, {_kernel_size, _kernel_size}, {_stride, _stride}); 45 | } 46 | 47 | return x; 48 | } 49 | }; 50 | 51 | TORCH_MODULE(MaxPoolLayer2D); 52 | 53 | struct DetectionLayerImpl : torch::nn::Module { 54 | torch::Tensor anchors; 55 | std::array grid; 56 | 57 | explicit DetectionLayerImpl(const ::std::vector &_anchors) 58 | : anchors(register_buffer("anchors", 59 | torch::from_blob((void *) _anchors.data(), 60 | {static_cast(_anchors.size() / 2), 2}).clone())), 61 | grid({torch::empty({0}), torch::empty({0})}) {} 62 | 63 | torch::Tensor forward(torch::Tensor prediction, torch::IntArrayRef inp_dim) { 64 | auto grid_size = prediction.sizes().slice(2); 65 | if (grid_size[0] != grid[0].size(0) || grid_size[1] != grid[1].size(0)) { 66 | // update grid if size not match 67 | grid = {torch::arange(grid_size[0], prediction.options()), 68 | torch::arange(grid_size[1], prediction.options())}; 69 | } 70 | 71 | auto batch_size = prediction.size(0); 72 | int64_t stride[] = {inp_dim[0] / grid_size[0], inp_dim[1] / grid_size[1]}; 73 | 
auto num_anchors = anchors.size(0); 74 | auto bbox_attrs = prediction.size(1) / num_anchors; 75 | prediction = prediction.view({batch_size, num_anchors, bbox_attrs, grid_size[0], grid_size[1]}); 76 | 77 | // sigmoid object confidence 78 | prediction.select(2, 4).sigmoid_(); 79 | 80 | // softmax the class scores 81 | prediction.slice(2, 5) = prediction.slice(2, 5).softmax(-1); 82 | 83 | // sigmoid the centre_X, centre_Y 84 | prediction.select(2, 0).sigmoid_().add_(grid[1].view({1, 1, 1, -1})).mul_(stride[1]); 85 | prediction.select(2, 1).sigmoid_().add_(grid[0].view({1, 1, -1, 1})).mul_(stride[0]); 86 | 87 | // log space transform height and the width 88 | prediction.select(2, 2).exp_().mul_(anchors.select(1, 0).view({1, -1, 1, 1})); 89 | prediction.select(2, 3).exp_().mul_(anchors.select(1, 1).view({1, -1, 1, 1})); 90 | 91 | return prediction.transpose(2, -1).contiguous().view({prediction.size(0), -1, prediction.size(2)}); 92 | } 93 | }; 94 | 95 | TORCH_MODULE(DetectionLayer); 96 | 97 | 98 | Detector::Darknet::Darknet(const string &cfg_file) { 99 | blocks = load_cfg(cfg_file); 100 | 101 | create_modules(); 102 | } 103 | 104 | void Detector::Darknet::load_weights(const string &weight_file) { 105 | ::load_weights(weight_file, blocks, module_list); // TODO: remove this function 106 | } 107 | 108 | // TODO: reimplement the python version 109 | torch::Tensor Detector::Darknet::forward(torch::Tensor x) { 110 | auto inp_dim = x.sizes().slice(2); 111 | auto module_count = module_list.size(); 112 | 113 | std::vector outputs(module_count); 114 | 115 | vector result; 116 | 117 | for (int i = 0; i < module_count; i++) { 118 | auto block = blocks[i + 1]; 119 | 120 | auto layer_type = block["type"]; 121 | 122 | if (layer_type == "net") 123 | continue; 124 | else if (layer_type == "convolutional" || layer_type == "upsample" || layer_type == "maxpool") { 125 | x = module_list[i]->forward(x); 126 | outputs[i] = x; 127 | } else if (layer_type == "route") { 128 | int start = 
std::stoi(block["start"]); 129 | int end = std::stoi(block["end"]); 130 | 131 | if (start > 0) start = start - i; 132 | 133 | if (end == 0) { 134 | x = outputs[i + start]; 135 | } else { 136 | if (end > 0) end = end - i; 137 | 138 | torch::Tensor map_1 = outputs[i + start]; 139 | torch::Tensor map_2 = outputs[i + end]; 140 | 141 | x = torch::cat({map_1, map_2}, 1); 142 | } 143 | 144 | outputs[i] = x; 145 | } else if (layer_type == "shortcut") { 146 | int from = std::stoi(block["from"]); 147 | x = outputs[i - 1] + outputs[i + from]; 148 | outputs[i] = x; 149 | } else if (layer_type == "yolo") { 150 | x = module_list[i]->forward(x, inp_dim); 151 | result.push_back(x); 152 | outputs[i] = outputs[i - 1]; 153 | } 154 | } 155 | return torch::cat(result, 1); 156 | } 157 | 158 | // TODO: reimplement the python version 159 | void Detector::Darknet::create_modules() { 160 | int prev_filters = 3; 161 | 162 | vector output_filters; 163 | 164 | int index = 0; 165 | 166 | int filters = 0; 167 | 168 | for (auto &block:blocks) { 169 | auto layer_type = block["type"]; 170 | 171 | torch::nn::Sequential module; 172 | 173 | if (layer_type == "net") 174 | continue; 175 | 176 | if (layer_type == "convolutional") { 177 | string activation = get_string_from_cfg(block, "activation", ""); 178 | int batch_normalize = get_int_from_cfg(block, "batch_normalize", 0); 179 | filters = get_int_from_cfg(block, "filters", 0); 180 | int padding = get_int_from_cfg(block, "pad", 0); 181 | int kernel_size = get_int_from_cfg(block, "size", 0); 182 | int stride = get_int_from_cfg(block, "stride", 1); 183 | 184 | int pad = padding > 0 ? 
(kernel_size - 1) / 2 : 0; 185 | bool with_bias = batch_normalize <= 0; 186 | 187 | torch::nn::Conv2d conv = torch::nn::Conv2d( 188 | conv_options(prev_filters, filters, kernel_size, stride, pad, 1, with_bias)); 189 | module->push_back(conv); 190 | 191 | if (batch_normalize > 0) { 192 | torch::nn::BatchNorm bn = torch::nn::BatchNorm(bn_options(filters)); 193 | module->push_back(bn); 194 | } 195 | 196 | if (activation == "leaky") { 197 | module->push_back(torch::nn::Functional(at::leaky_relu, /*slope=*/0.1)); 198 | } 199 | } else if (layer_type == "upsample") { 200 | int stride = get_int_from_cfg(block, "stride", 1); 201 | 202 | UpsampleLayer uplayer(stride); 203 | module->push_back(uplayer); 204 | } else if (layer_type == "maxpool") { 205 | int stride = get_int_from_cfg(block, "stride", 1); 206 | int size = get_int_from_cfg(block, "size", 1); 207 | 208 | MaxPoolLayer2D poolLayer(size, stride); 209 | module->push_back(poolLayer); 210 | } else if (layer_type == "shortcut") { 211 | // skip connection 212 | int from = get_int_from_cfg(block, "from", 0); 213 | block["from"] = to_string(from); 214 | 215 | // placeholder 216 | EmptyLayer layer; 217 | module->push_back(layer); 218 | } else if (layer_type == "route") { 219 | // L 85: -1, 61 220 | string layers_info = get_string_from_cfg(block, "layers", ""); 221 | 222 | vector layers; 223 | split(layers_info, layers, ","); 224 | 225 | string::size_type sz; 226 | signed int start = stoi(layers[0], &sz); 227 | signed int end = 0; 228 | 229 | if (layers.size() > 1) { 230 | end = stoi(layers[1], &sz); 231 | } 232 | 233 | if (start > 0) start = start - index; 234 | 235 | if (end > 0) end = end - index; 236 | 237 | block["start"] = to_string(start); 238 | block["end"] = to_string(end); 239 | 240 | // placeholder 241 | EmptyLayer layer; 242 | module->push_back(layer); 243 | 244 | if (end < 0) { 245 | filters = output_filters[index + start] + output_filters[index + end]; 246 | } else { 247 | filters = output_filters[index + start]; 
248 | } 249 | } else if (layer_type == "yolo") { 250 | string mask_info = get_string_from_cfg(block, "mask", ""); 251 | vector masks; 252 | split(mask_info, masks, ","); 253 | 254 | string anchor_info = get_string_from_cfg(block, "anchors", ""); 255 | vector anchors; 256 | split(anchor_info, anchors, ","); 257 | 258 | vector anchor_points; 259 | for (auto mask : masks) { 260 | anchor_points.push_back(anchors[mask * 2]); 261 | anchor_points.push_back(anchors[mask * 2 + 1]); 262 | } 263 | 264 | DetectionLayer layer(anchor_points); 265 | module->push_back(layer); 266 | } else { 267 | cout << "unsupported operator:" << layer_type << endl; 268 | } 269 | 270 | prev_filters = filters; 271 | output_filters.push_back(filters); 272 | module_list.push_back(module); 273 | 274 | register_module("layer_" + to_string(index), module); 275 | 276 | index += 1; 277 | } 278 | } 279 | -------------------------------------------------------------------------------- /detection/src/Darknet.h: -------------------------------------------------------------------------------- 1 | #ifndef DARKNET_H 2 | #define DARKNET_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "Detector.h" 10 | 11 | struct Detector::Darknet : torch::nn::Module { 12 | public: 13 | explicit Darknet(const std::string &cfg_file); 14 | 15 | const std::map &net_info() { 16 | assert(!blocks.empty() && blocks[0]["type"] == "net"); 17 | return blocks[0]; 18 | } 19 | 20 | void load_weights(const std::string &weight_file); 21 | 22 | torch::Tensor forward(torch::Tensor x); 23 | 24 | private: 25 | std::vector> blocks; 26 | 27 | std::vector module_list; 28 | 29 | void create_modules(); 30 | }; 31 | 32 | #endif //DARKNET_H -------------------------------------------------------------------------------- /detection/src/Detector.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "Detector.h" 4 | #include "Darknet.h" 5 | #include "letterbox.h" 6 | 7 
| namespace { 8 | void center_to_corner(torch::Tensor bbox) { 9 | bbox.select(1, 0) -= bbox.select(1, 2) / 2; 10 | bbox.select(1, 1) -= bbox.select(1, 3) / 2; 11 | } 12 | 13 | auto threshold_confidence(torch::Tensor pred, float threshold) { 14 | auto[max_cls_score, max_cls] = pred.slice(1, 5).max(1); 15 | max_cls_score *= pred.select(1, 4); 16 | auto prob_thresh = max_cls_score > threshold; 17 | 18 | pred = pred.slice(1, 0, 4); 19 | 20 | auto index = prob_thresh.nonzero().squeeze_(); 21 | return std::make_tuple(pred.index_select(0, index), 22 | max_cls.index_select(0, index), 23 | max_cls_score.index_select(0, index)); 24 | } 25 | 26 | float iou(const cv::Rect2f &bb_test, const cv::Rect2f &bb_gt) { 27 | auto in = (bb_test & bb_gt).area(); 28 | auto un = bb_test.area() + bb_gt.area() - in; 29 | 30 | return in / un; 31 | } 32 | 33 | struct Detection { 34 | cv::Rect2f bbox; 35 | float scr; 36 | }; 37 | 38 | void NMS(std::vector &dets, float threshold) { 39 | std::sort(dets.begin(), dets.end(), 40 | [](const Detection &a, const Detection &b) { return a.scr > b.scr; }); 41 | 42 | for (size_t i = 0; i < dets.size(); ++i) { 43 | dets.erase(std::remove_if(dets.begin() + i + 1, dets.end(), 44 | [&](const Detection &d) { 45 | return iou(dets[i].bbox, d.bbox) > threshold; 46 | }), 47 | dets.end()); 48 | } 49 | } 50 | } 51 | 52 | const float Detector::NMS_threshold = 0.4f; 53 | const float Detector::confidence_threshold = 0.1f; 54 | 55 | Detector::Detector(const std::array &_inp_dim, YOLOType type) { 56 | switch (type) { 57 | case YOLOType::YOLOv3: 58 | net = std::make_unique("models/yolov3.cfg"); 59 | net->load_weights("weights/yolov3.weights"); 60 | break; 61 | case YOLOType::YOLOv3_TINY: 62 | net = std::make_unique("models/yolov3-tiny.cfg"); 63 | net->load_weights("weights/yolov3-tiny.weights"); 64 | break; 65 | default: 66 | break; 67 | } 68 | net->to(torch::kCUDA); 69 | net->eval(); 70 | 71 | inp_dim = _inp_dim; 72 | } 73 | 74 | Detector::~Detector() = default; 75 | 76 | 
std::vector Detector::detect(cv::Mat image) { 77 | torch::NoGradGuard no_grad; 78 | 79 | int64_t orig_dim[] = {image.rows, image.cols}; 80 | image = letterbox_img(image, inp_dim); 81 | cv::cvtColor(image, image, cv::COLOR_RGB2BGR); 82 | image.convertTo(image, CV_32F, 1.0 / 255); 83 | 84 | auto img_tensor = torch::from_blob(image.data, {1, inp_dim[0], inp_dim[1], 3}) 85 | .permute({0, 3, 1, 2}).to(torch::kCUDA); 86 | auto prediction = net->forward(img_tensor).squeeze_(0); 87 | auto[bbox, cls, scr] = threshold_confidence(prediction, confidence_threshold); 88 | bbox = bbox.cpu(); 89 | cls = cls.cpu(); 90 | scr = scr.cpu(); 91 | 92 | auto cls_mask = cls == 0; 93 | bbox = bbox.index_select(0, cls_mask.nonzero().squeeze_()); 94 | scr = scr.masked_select(cls_mask); 95 | 96 | center_to_corner(bbox); 97 | inv_letterbox_bbox(bbox, inp_dim, orig_dim); 98 | 99 | auto bbox_acc = bbox.accessor(); 100 | auto scr_acc = scr.accessor(); 101 | std::vector dets; 102 | for (int64_t i = 0; i < bbox_acc.size(0); ++i) { 103 | auto d = Detection{cv::Rect2f(bbox_acc[i][0], bbox_acc[i][1], bbox_acc[i][2], bbox_acc[i][3]), 104 | scr_acc[i]}; 105 | dets.emplace_back(d); 106 | } 107 | 108 | NMS(dets, NMS_threshold); 109 | 110 | auto img_box = cv::Rect2f(0, 0, orig_dim[1], orig_dim[0]); 111 | std::vector out; 112 | for (auto &d:dets) { 113 | out.push_back(d.bbox & img_box); 114 | } 115 | 116 | return out; 117 | } 118 | -------------------------------------------------------------------------------- /detection/src/darknet_parsing.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "darknet_parsing.h" 4 | 5 | using namespace std; 6 | 7 | namespace { 8 | // trim from start (in place) 9 | void ltrim(string &s) { 10 | s.erase(s.begin(), find_if(s.begin(), s.end(), [](char ch) { 11 | return !isspace(ch); 12 | })); 13 | } 14 | 15 | // trim from end (in place) 16 | void rtrim(string &s) { 17 | s.erase(find_if(s.rbegin(), s.rend(), [](char 
ch) { 18 | return !isspace(ch); 19 | }).base(), s.end()); 20 | } 21 | 22 | // trim from both ends (in place) 23 | void trim(string &s) { 24 | ltrim(s); 25 | rtrim(s); 26 | } 27 | 28 | void load_tensor(torch::Tensor t, ifstream &fs) { 29 | fs.read(static_cast(t.data_ptr()), t.numel() * sizeof(float)); 30 | } 31 | } 32 | 33 | int split(const string &str, vector &ret_, string sep) { 34 | if (str.empty()) { 35 | return 0; 36 | } 37 | 38 | string tmp; 39 | string::size_type pos_begin = str.find_first_not_of(sep); 40 | string::size_type comma_pos = 0; 41 | 42 | while (pos_begin != string::npos) { 43 | comma_pos = str.find(sep, pos_begin); 44 | if (comma_pos != string::npos) { 45 | tmp = str.substr(pos_begin, comma_pos - pos_begin); 46 | pos_begin = comma_pos + sep.length(); 47 | } else { 48 | tmp = str.substr(pos_begin); 49 | pos_begin = comma_pos; 50 | } 51 | 52 | if (!tmp.empty()) { 53 | trim(tmp); 54 | ret_.push_back(tmp); 55 | tmp.clear(); 56 | } 57 | } 58 | return 0; 59 | } 60 | 61 | int split(const string &str, vector &ret_, string sep) { 62 | vector tmp; 63 | split(str, tmp, sep); 64 | 65 | for (int i = 0; i < tmp.size(); i++) { 66 | ret_.push_back(stoi(tmp[i])); 67 | } 68 | return ret_.size(); 69 | } 70 | 71 | int get_int_from_cfg(map block, string key, int default_value) { 72 | if (block.find(key) != block.end()) { 73 | return stoi(block.at(key)); 74 | } 75 | return default_value; 76 | } 77 | 78 | string get_string_from_cfg(map block, string key, string default_value) { 79 | if (block.find(key) != block.end()) { 80 | return block.at(key); 81 | } 82 | return default_value; 83 | } 84 | 85 | torch::nn::Conv2dOptions conv_options(int64_t in_planes, int64_t out_planes, int64_t kerner_size, 86 | int64_t stride, int64_t padding, int64_t groups, bool with_bias) { 87 | torch::nn::Conv2dOptions conv_options = torch::nn::Conv2dOptions(in_planes, out_planes, kerner_size); 88 | conv_options.stride_ = stride; 89 | conv_options.padding_ = padding; 90 | conv_options.groups_ = 
groups; 91 | conv_options.with_bias_ = with_bias; 92 | return conv_options; 93 | } 94 | 95 | torch::nn::BatchNormOptions bn_options(int64_t features) { 96 | torch::nn::BatchNormOptions bn_options = torch::nn::BatchNormOptions(features); 97 | bn_options.affine_ = true; 98 | bn_options.stateful_ = true; 99 | return bn_options; 100 | } 101 | 102 | Blocks load_cfg(const string &cfg_file) { 103 | ifstream fs(cfg_file); 104 | string line; 105 | 106 | Blocks blocks; 107 | 108 | if (!fs) { 109 | throw "Fail to load cfg file"; 110 | } 111 | 112 | while (getline(fs, line)) { 113 | trim(line); 114 | 115 | if (line.empty()) { 116 | continue; 117 | } 118 | 119 | if (line.substr(0, 1) == "[") { 120 | map block; 121 | 122 | string key = line.substr(1, line.length() - 2); 123 | block["type"] = key; 124 | 125 | blocks.push_back(block); 126 | } else { 127 | auto &block = blocks.back(); 128 | 129 | vector op_info; 130 | 131 | split(line, op_info, "="); 132 | 133 | if (op_info.size() == 2) { 134 | string p_key = op_info[0]; 135 | string p_value = op_info[1]; 136 | block[p_key] = p_value; 137 | } 138 | } 139 | } 140 | fs.close(); 141 | 142 | return blocks; 143 | } 144 | 145 | void load_weights(const string &weight_file, const Blocks &blocks, vector &module_list) { 146 | ifstream fs(weight_file, ios_base::binary); 147 | if (!fs) { 148 | throw std::runtime_error("No weight file for Darknet!"); 149 | } 150 | 151 | fs.seekg(sizeof(int32_t) * 5, ios_base::beg); 152 | 153 | for (size_t i = 0; i < module_list.size(); i++) { 154 | auto &module_info = blocks[i + 1]; 155 | 156 | // only conv layer need to load weight 157 | if (module_info.at("type") != "convolutional") continue; 158 | 159 | auto seq_module = module_list[i]; 160 | 161 | auto conv = dynamic_pointer_cast(seq_module[0]); 162 | 163 | if (get_int_from_cfg(module_info, "batch_normalize", 0)) { 164 | // second module 165 | auto bn = dynamic_pointer_cast(seq_module[1]); 166 | 167 | load_tensor(bn->bias, fs); 168 | load_tensor(bn->weight, 
fs); 169 | load_tensor(bn->running_mean, fs); 170 | load_tensor(bn->running_var, fs); 171 | } else { 172 | load_tensor(conv->bias, fs); 173 | } 174 | load_tensor(conv->weight, fs); 175 | } 176 | } 177 | 178 | -------------------------------------------------------------------------------- /detection/src/darknet_parsing.h: -------------------------------------------------------------------------------- 1 | #ifndef DARKNET_PARSING_H 2 | #define DARKNET_PARSING_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | int split(const std::string &str, std::vector &ret_, std::string sep = ","); 9 | 10 | int split(const std::string &str, std::vector &ret_, std::string sep = ","); 11 | 12 | int get_int_from_cfg(std::map block, std::string key, int default_value); 13 | 14 | std::string get_string_from_cfg(std::map block, std::string key, std::string default_value); 15 | 16 | torch::nn::Conv2dOptions conv_options(int64_t in_planes, int64_t out_planes, int64_t kerner_size, 17 | int64_t stride, int64_t padding, int64_t groups, bool with_bias = false); 18 | 19 | torch::nn::BatchNormOptions bn_options(int64_t features); 20 | 21 | using Blocks = std::vector>; 22 | 23 | Blocks load_cfg(const std::string &cfg_file); 24 | 25 | void load_weights(const std::string &weight_file, const Blocks &blocks, 26 | std::vector &module_list); 27 | 28 | #endif //DARKNET_PARSING_H 29 | -------------------------------------------------------------------------------- /detection/src/letterbox.h: -------------------------------------------------------------------------------- 1 | #ifndef LETTERBOX_H 2 | #define LETTERBOX_H 3 | 4 | #include 5 | #include 6 | 7 | static inline std::array letterbox_dim(torch::IntArrayRef img, torch::IntArrayRef box) { 8 | auto h = box[0], w = box[1]; 9 | auto img_h = img[0], img_w = img[1]; 10 | auto s = std::min(1.0f * w / img_w, 1.0f * h / img_h); 11 | return std::array{int64_t(img_h * s), int64_t(img_w * s)}; 12 | } 13 | 14 | static inline cv::Mat letterbox_img(const 
cv::Mat &img, torch::IntArrayRef box) { 15 | auto h = box[0], w = box[1]; 16 | auto[new_h, new_w] = letterbox_dim({img.rows, img.cols}, box); 17 | 18 | cv::Mat out = (cv::Mat::zeros(h, w, CV_8UC3) + 1) * 128; 19 | 20 | cv::resize(img, 21 | out({int((h - new_h) / 2), int((h - new_h) / 2 + new_h)}, 22 | {int((w - new_w) / 2), int((w - new_w) / 2 + new_w)}), 23 | {int(new_w), int(new_h)}, 24 | 0, 0, cv::INTER_CUBIC); 25 | return out; 26 | } 27 | 28 | static inline void inv_letterbox_bbox(torch::Tensor bbox, torch::IntArrayRef box_dim, torch::IntArrayRef img_dim) { 29 | auto img_h = img_dim[0], img_w = img_dim[1]; 30 | auto h = box_dim[0], w = box_dim[1]; 31 | auto[new_h, new_w] = letterbox_dim(img_dim, box_dim); 32 | 33 | bbox.select(1, 0).add_(-(w - new_w) / 2).mul_(1.0f * img_w / new_w); 34 | bbox.select(1, 2).mul_(1.0f * img_w / new_w); 35 | 36 | bbox.select(1, 1).add_(-(h - new_h) / 2).mul_(1.0f * img_h / new_h); 37 | bbox.select(1, 3).mul_(1.0f * img_h / new_h); 38 | } 39 | 40 | #endif //LETTERBOX_H 41 | -------------------------------------------------------------------------------- /models/yolov3-tiny.cfg: -------------------------------------------------------------------------------- 1 | [net] 2 | # Testing 3 | batch=1 4 | subdivisions=1 5 | # Training 6 | # batch=64 7 | # subdivisions=2 8 | width=416 9 | height=416 10 | channels=3 11 | momentum=0.9 12 | decay=0.0005 13 | angle=0 14 | saturation = 1.5 15 | exposure = 1.5 16 | hue=.1 17 | 18 | learning_rate=0.001 19 | burn_in=1000 20 | max_batches = 500200 21 | policy=steps 22 | steps=400000,450000 23 | scales=.1,.1 24 | 25 | [convolutional] 26 | batch_normalize=1 27 | filters=16 28 | size=3 29 | stride=1 30 | pad=1 31 | activation=leaky 32 | 33 | [maxpool] 34 | size=2 35 | stride=2 36 | 37 | [convolutional] 38 | batch_normalize=1 39 | filters=32 40 | size=3 41 | stride=1 42 | pad=1 43 | activation=leaky 44 | 45 | [maxpool] 46 | size=2 47 | stride=2 48 | 49 | [convolutional] 50 | batch_normalize=1 51 | 
filters=64 52 | size=3 53 | stride=1 54 | pad=1 55 | activation=leaky 56 | 57 | [maxpool] 58 | size=2 59 | stride=2 60 | 61 | [convolutional] 62 | batch_normalize=1 63 | filters=128 64 | size=3 65 | stride=1 66 | pad=1 67 | activation=leaky 68 | 69 | [maxpool] 70 | size=2 71 | stride=2 72 | 73 | [convolutional] 74 | batch_normalize=1 75 | filters=256 76 | size=3 77 | stride=1 78 | pad=1 79 | activation=leaky 80 | 81 | [maxpool] 82 | size=2 83 | stride=2 84 | 85 | [convolutional] 86 | batch_normalize=1 87 | filters=512 88 | size=3 89 | stride=1 90 | pad=1 91 | activation=leaky 92 | 93 | [maxpool] 94 | size=2 95 | stride=1 96 | 97 | [convolutional] 98 | batch_normalize=1 99 | filters=1024 100 | size=3 101 | stride=1 102 | pad=1 103 | activation=leaky 104 | 105 | ########### 106 | 107 | [convolutional] 108 | batch_normalize=1 109 | filters=256 110 | size=1 111 | stride=1 112 | pad=1 113 | activation=leaky 114 | 115 | [convolutional] 116 | batch_normalize=1 117 | filters=512 118 | size=3 119 | stride=1 120 | pad=1 121 | activation=leaky 122 | 123 | [convolutional] 124 | size=1 125 | stride=1 126 | pad=1 127 | filters=255 128 | activation=linear 129 | 130 | 131 | 132 | [yolo] 133 | mask = 3,4,5 134 | anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319 135 | classes=80 136 | num=6 137 | jitter=.3 138 | ignore_thresh = .7 139 | truth_thresh = 1 140 | random=1 141 | 142 | [route] 143 | layers = -4 144 | 145 | [convolutional] 146 | batch_normalize=1 147 | filters=128 148 | size=1 149 | stride=1 150 | pad=1 151 | activation=leaky 152 | 153 | [upsample] 154 | stride=2 155 | 156 | [route] 157 | layers = -1, 8 158 | 159 | [convolutional] 160 | batch_normalize=1 161 | filters=256 162 | size=3 163 | stride=1 164 | pad=1 165 | activation=leaky 166 | 167 | [convolutional] 168 | size=1 169 | stride=1 170 | pad=1 171 | filters=255 172 | activation=linear 173 | 174 | [yolo] 175 | mask = 0,1,2 176 | anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319 177 | classes=80 178 | num=6 
179 | jitter=.3 180 | ignore_thresh = .7 181 | truth_thresh = 1 182 | random=1 -------------------------------------------------------------------------------- /models/yolov3.cfg: -------------------------------------------------------------------------------- 1 | [net] 2 | # Testing 3 | # batch=1 4 | # subdivisions=1 5 | # Training 6 | batch = 64 7 | subdivisions = 16 8 | width = 416 9 | height = 416 10 | channels=3 11 | momentum=0.9 12 | decay=0.0005 13 | angle=0 14 | saturation = 1.5 15 | exposure = 1.5 16 | hue=.1 17 | 18 | learning_rate=0.001 19 | burn_in=1000 20 | max_batches = 500200 21 | policy=steps 22 | steps=400000,450000 23 | scales=.1,.1 24 | 25 | [convolutional] 26 | batch_normalize=1 27 | filters=32 28 | size=3 29 | stride=1 30 | pad=1 31 | activation=leaky 32 | 33 | # Downsample 34 | 35 | [convolutional] 36 | batch_normalize=1 37 | filters=64 38 | size=3 39 | stride=2 40 | pad=1 41 | activation=leaky 42 | 43 | [convolutional] 44 | batch_normalize=1 45 | filters=32 46 | size=1 47 | stride=1 48 | pad=1 49 | activation=leaky 50 | 51 | [convolutional] 52 | batch_normalize=1 53 | filters=64 54 | size=3 55 | stride=1 56 | pad=1 57 | activation=leaky 58 | 59 | [shortcut] 60 | from=-3 61 | activation=linear 62 | 63 | # Downsample 64 | 65 | [convolutional] 66 | batch_normalize=1 67 | filters=128 68 | size=3 69 | stride=2 70 | pad=1 71 | activation=leaky 72 | 73 | [convolutional] 74 | batch_normalize=1 75 | filters=64 76 | size=1 77 | stride=1 78 | pad=1 79 | activation=leaky 80 | 81 | [convolutional] 82 | batch_normalize=1 83 | filters=128 84 | size=3 85 | stride=1 86 | pad=1 87 | activation=leaky 88 | 89 | [shortcut] 90 | from=-3 91 | activation=linear 92 | 93 | [convolutional] 94 | batch_normalize=1 95 | filters=64 96 | size=1 97 | stride=1 98 | pad=1 99 | activation=leaky 100 | 101 | [convolutional] 102 | batch_normalize=1 103 | filters=128 104 | size=3 105 | stride=1 106 | pad=1 107 | activation=leaky 108 | 109 | [shortcut] 110 | from=-3 111 | 
activation=linear 112 | 113 | # Downsample 114 | 115 | [convolutional] 116 | batch_normalize=1 117 | filters=256 118 | size=3 119 | stride=2 120 | pad=1 121 | activation=leaky 122 | 123 | [convolutional] 124 | batch_normalize=1 125 | filters=128 126 | size=1 127 | stride=1 128 | pad=1 129 | activation=leaky 130 | 131 | [convolutional] 132 | batch_normalize=1 133 | filters=256 134 | size=3 135 | stride=1 136 | pad=1 137 | activation=leaky 138 | 139 | [shortcut] 140 | from=-3 141 | activation=linear 142 | 143 | [convolutional] 144 | batch_normalize=1 145 | filters=128 146 | size=1 147 | stride=1 148 | pad=1 149 | activation=leaky 150 | 151 | [convolutional] 152 | batch_normalize=1 153 | filters=256 154 | size=3 155 | stride=1 156 | pad=1 157 | activation=leaky 158 | 159 | [shortcut] 160 | from=-3 161 | activation=linear 162 | 163 | [convolutional] 164 | batch_normalize=1 165 | filters=128 166 | size=1 167 | stride=1 168 | pad=1 169 | activation=leaky 170 | 171 | [convolutional] 172 | batch_normalize=1 173 | filters=256 174 | size=3 175 | stride=1 176 | pad=1 177 | activation=leaky 178 | 179 | [shortcut] 180 | from=-3 181 | activation=linear 182 | 183 | [convolutional] 184 | batch_normalize=1 185 | filters=128 186 | size=1 187 | stride=1 188 | pad=1 189 | activation=leaky 190 | 191 | [convolutional] 192 | batch_normalize=1 193 | filters=256 194 | size=3 195 | stride=1 196 | pad=1 197 | activation=leaky 198 | 199 | [shortcut] 200 | from=-3 201 | activation=linear 202 | 203 | 204 | [convolutional] 205 | batch_normalize=1 206 | filters=128 207 | size=1 208 | stride=1 209 | pad=1 210 | activation=leaky 211 | 212 | [convolutional] 213 | batch_normalize=1 214 | filters=256 215 | size=3 216 | stride=1 217 | pad=1 218 | activation=leaky 219 | 220 | [shortcut] 221 | from=-3 222 | activation=linear 223 | 224 | [convolutional] 225 | batch_normalize=1 226 | filters=128 227 | size=1 228 | stride=1 229 | pad=1 230 | activation=leaky 231 | 232 | [convolutional] 233 | 
batch_normalize=1 234 | filters=256 235 | size=3 236 | stride=1 237 | pad=1 238 | activation=leaky 239 | 240 | [shortcut] 241 | from=-3 242 | activation=linear 243 | 244 | [convolutional] 245 | batch_normalize=1 246 | filters=128 247 | size=1 248 | stride=1 249 | pad=1 250 | activation=leaky 251 | 252 | [convolutional] 253 | batch_normalize=1 254 | filters=256 255 | size=3 256 | stride=1 257 | pad=1 258 | activation=leaky 259 | 260 | [shortcut] 261 | from=-3 262 | activation=linear 263 | 264 | [convolutional] 265 | batch_normalize=1 266 | filters=128 267 | size=1 268 | stride=1 269 | pad=1 270 | activation=leaky 271 | 272 | [convolutional] 273 | batch_normalize=1 274 | filters=256 275 | size=3 276 | stride=1 277 | pad=1 278 | activation=leaky 279 | 280 | [shortcut] 281 | from=-3 282 | activation=linear 283 | 284 | # Downsample 285 | 286 | [convolutional] 287 | batch_normalize=1 288 | filters=512 289 | size=3 290 | stride=2 291 | pad=1 292 | activation=leaky 293 | 294 | [convolutional] 295 | batch_normalize=1 296 | filters=256 297 | size=1 298 | stride=1 299 | pad=1 300 | activation=leaky 301 | 302 | [convolutional] 303 | batch_normalize=1 304 | filters=512 305 | size=3 306 | stride=1 307 | pad=1 308 | activation=leaky 309 | 310 | [shortcut] 311 | from=-3 312 | activation=linear 313 | 314 | 315 | [convolutional] 316 | batch_normalize=1 317 | filters=256 318 | size=1 319 | stride=1 320 | pad=1 321 | activation=leaky 322 | 323 | [convolutional] 324 | batch_normalize=1 325 | filters=512 326 | size=3 327 | stride=1 328 | pad=1 329 | activation=leaky 330 | 331 | [shortcut] 332 | from=-3 333 | activation=linear 334 | 335 | 336 | [convolutional] 337 | batch_normalize=1 338 | filters=256 339 | size=1 340 | stride=1 341 | pad=1 342 | activation=leaky 343 | 344 | [convolutional] 345 | batch_normalize=1 346 | filters=512 347 | size=3 348 | stride=1 349 | pad=1 350 | activation=leaky 351 | 352 | [shortcut] 353 | from=-3 354 | activation=linear 355 | 356 | 357 | [convolutional] 
358 | batch_normalize=1 359 | filters=256 360 | size=1 361 | stride=1 362 | pad=1 363 | activation=leaky 364 | 365 | [convolutional] 366 | batch_normalize=1 367 | filters=512 368 | size=3 369 | stride=1 370 | pad=1 371 | activation=leaky 372 | 373 | [shortcut] 374 | from=-3 375 | activation=linear 376 | 377 | [convolutional] 378 | batch_normalize=1 379 | filters=256 380 | size=1 381 | stride=1 382 | pad=1 383 | activation=leaky 384 | 385 | [convolutional] 386 | batch_normalize=1 387 | filters=512 388 | size=3 389 | stride=1 390 | pad=1 391 | activation=leaky 392 | 393 | [shortcut] 394 | from=-3 395 | activation=linear 396 | 397 | 398 | [convolutional] 399 | batch_normalize=1 400 | filters=256 401 | size=1 402 | stride=1 403 | pad=1 404 | activation=leaky 405 | 406 | [convolutional] 407 | batch_normalize=1 408 | filters=512 409 | size=3 410 | stride=1 411 | pad=1 412 | activation=leaky 413 | 414 | [shortcut] 415 | from=-3 416 | activation=linear 417 | 418 | 419 | [convolutional] 420 | batch_normalize=1 421 | filters=256 422 | size=1 423 | stride=1 424 | pad=1 425 | activation=leaky 426 | 427 | [convolutional] 428 | batch_normalize=1 429 | filters=512 430 | size=3 431 | stride=1 432 | pad=1 433 | activation=leaky 434 | 435 | [shortcut] 436 | from=-3 437 | activation=linear 438 | 439 | [convolutional] 440 | batch_normalize=1 441 | filters=256 442 | size=1 443 | stride=1 444 | pad=1 445 | activation=leaky 446 | 447 | [convolutional] 448 | batch_normalize=1 449 | filters=512 450 | size=3 451 | stride=1 452 | pad=1 453 | activation=leaky 454 | 455 | [shortcut] 456 | from=-3 457 | activation=linear 458 | 459 | # Downsample 460 | 461 | [convolutional] 462 | batch_normalize=1 463 | filters=1024 464 | size=3 465 | stride=2 466 | pad=1 467 | activation=leaky 468 | 469 | [convolutional] 470 | batch_normalize=1 471 | filters=512 472 | size=1 473 | stride=1 474 | pad=1 475 | activation=leaky 476 | 477 | [convolutional] 478 | batch_normalize=1 479 | filters=1024 480 | size=3 481 
| stride=1 482 | pad=1 483 | activation=leaky 484 | 485 | [shortcut] 486 | from=-3 487 | activation=linear 488 | 489 | [convolutional] 490 | batch_normalize=1 491 | filters=512 492 | size=1 493 | stride=1 494 | pad=1 495 | activation=leaky 496 | 497 | [convolutional] 498 | batch_normalize=1 499 | filters=1024 500 | size=3 501 | stride=1 502 | pad=1 503 | activation=leaky 504 | 505 | [shortcut] 506 | from=-3 507 | activation=linear 508 | 509 | [convolutional] 510 | batch_normalize=1 511 | filters=512 512 | size=1 513 | stride=1 514 | pad=1 515 | activation=leaky 516 | 517 | [convolutional] 518 | batch_normalize=1 519 | filters=1024 520 | size=3 521 | stride=1 522 | pad=1 523 | activation=leaky 524 | 525 | [shortcut] 526 | from=-3 527 | activation=linear 528 | 529 | [convolutional] 530 | batch_normalize=1 531 | filters=512 532 | size=1 533 | stride=1 534 | pad=1 535 | activation=leaky 536 | 537 | [convolutional] 538 | batch_normalize=1 539 | filters=1024 540 | size=3 541 | stride=1 542 | pad=1 543 | activation=leaky 544 | 545 | [shortcut] 546 | from=-3 547 | activation=linear 548 | 549 | ###################### 550 | 551 | [convolutional] 552 | batch_normalize=1 553 | filters=512 554 | size=1 555 | stride=1 556 | pad=1 557 | activation=leaky 558 | 559 | [convolutional] 560 | batch_normalize=1 561 | size=3 562 | stride=1 563 | pad=1 564 | filters=1024 565 | activation=leaky 566 | 567 | [convolutional] 568 | batch_normalize=1 569 | filters=512 570 | size=1 571 | stride=1 572 | pad=1 573 | activation=leaky 574 | 575 | [convolutional] 576 | batch_normalize=1 577 | size=3 578 | stride=1 579 | pad=1 580 | filters=1024 581 | activation=leaky 582 | 583 | [convolutional] 584 | batch_normalize=1 585 | filters=512 586 | size=1 587 | stride=1 588 | pad=1 589 | activation=leaky 590 | 591 | [convolutional] 592 | batch_normalize=1 593 | size=3 594 | stride=1 595 | pad=1 596 | filters=1024 597 | activation=leaky 598 | 599 | [convolutional] 600 | size=1 601 | stride=1 602 | pad=1 603 
| filters=255 604 | activation=linear 605 | 606 | 607 | [yolo] 608 | mask = 6,7,8 609 | anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 610 | classes=80 611 | num=9 612 | jitter=.3 613 | ignore_thresh = .7 614 | truth_thresh = 1 615 | random=1 616 | 617 | 618 | [route] 619 | layers = -4 620 | 621 | [convolutional] 622 | batch_normalize=1 623 | filters=256 624 | size=1 625 | stride=1 626 | pad=1 627 | activation=leaky 628 | 629 | [upsample] 630 | stride=2 631 | 632 | [route] 633 | layers = -1, 61 634 | 635 | 636 | 637 | [convolutional] 638 | batch_normalize=1 639 | filters=256 640 | size=1 641 | stride=1 642 | pad=1 643 | activation=leaky 644 | 645 | [convolutional] 646 | batch_normalize=1 647 | size=3 648 | stride=1 649 | pad=1 650 | filters=512 651 | activation=leaky 652 | 653 | [convolutional] 654 | batch_normalize=1 655 | filters=256 656 | size=1 657 | stride=1 658 | pad=1 659 | activation=leaky 660 | 661 | [convolutional] 662 | batch_normalize=1 663 | size=3 664 | stride=1 665 | pad=1 666 | filters=512 667 | activation=leaky 668 | 669 | [convolutional] 670 | batch_normalize=1 671 | filters=256 672 | size=1 673 | stride=1 674 | pad=1 675 | activation=leaky 676 | 677 | [convolutional] 678 | batch_normalize=1 679 | size=3 680 | stride=1 681 | pad=1 682 | filters=512 683 | activation=leaky 684 | 685 | [convolutional] 686 | size=1 687 | stride=1 688 | pad=1 689 | filters=255 690 | activation=linear 691 | 692 | 693 | [yolo] 694 | mask = 3,4,5 695 | anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 696 | classes=80 697 | num=9 698 | jitter=.3 699 | ignore_thresh = .7 700 | truth_thresh = 1 701 | random=1 702 | 703 | 704 | 705 | [route] 706 | layers = -4 707 | 708 | [convolutional] 709 | batch_normalize=1 710 | filters=128 711 | size=1 712 | stride=1 713 | pad=1 714 | activation=leaky 715 | 716 | [upsample] 717 | stride=2 718 | 719 | [route] 720 | layers = -1, 36 721 | 722 | 723 | 724 | [convolutional] 725 | 
batch_normalize=1 726 | filters=128 727 | size=1 728 | stride=1 729 | pad=1 730 | activation=leaky 731 | 732 | [convolutional] 733 | batch_normalize=1 734 | size=3 735 | stride=1 736 | pad=1 737 | filters=256 738 | activation=leaky 739 | 740 | [convolutional] 741 | batch_normalize=1 742 | filters=128 743 | size=1 744 | stride=1 745 | pad=1 746 | activation=leaky 747 | 748 | [convolutional] 749 | batch_normalize=1 750 | size=3 751 | stride=1 752 | pad=1 753 | filters=256 754 | activation=leaky 755 | 756 | [convolutional] 757 | batch_normalize=1 758 | filters=128 759 | size=1 760 | stride=1 761 | pad=1 762 | activation=leaky 763 | 764 | [convolutional] 765 | batch_normalize=1 766 | size=3 767 | stride=1 768 | pad=1 769 | filters=256 770 | activation=leaky 771 | 772 | [convolutional] 773 | size=1 774 | stride=1 775 | pad=1 776 | filters=255 777 | activation=linear 778 | 779 | 780 | [yolo] 781 | mask = 0,1,2 782 | anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 783 | classes=80 784 | num=9 785 | jitter=.3 786 | ignore_thresh = .7 787 | truth_thresh = 1 788 | random=1 789 | -------------------------------------------------------------------------------- /processing/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(OpenCV REQUIRED) 2 | 3 | aux_source_directory(. 
PROCESSING_SRCS) 4 | 5 | add_executable(processing ${PROCESSING_SRCS}) 6 | target_link_libraries(processing ${OpenCV_LIBS} detection tracking ${STDCXXFS}) 7 | target_include_directories(processing PRIVATE "${PROJECT_BINARY_DIR}") -------------------------------------------------------------------------------- /processing/TargetStorage.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "TargetStorage.h" 7 | #include "config.h" 8 | 9 | using namespace std; 10 | namespace fs = std::experimental::filesystem; 11 | 12 | TargetStorage::TargetStorage(const array &orig_dim, int video_FPS) { 13 | fs::create_directories(OUTPUT_DIR); 14 | writer.open((fs::path(OUTPUT_DIR) / VIDEO_NAME).string(), 15 | cv::VideoWriter::fourcc('a', 'v', 'c', '1'), 16 | video_FPS, cv::Size(orig_dim[1], orig_dim[0])); 17 | if (!writer.isOpened()) { 18 | throw std::runtime_error("Cannot open cv::VideoWriter"); 19 | } 20 | } 21 | 22 | void TargetStorage::update(const vector &trks, int frame, const cv::Mat &image) { 23 | for (auto[id, box]:trks) { 24 | // save normalized boxes 25 | box = normalize_rect(box, image.cols, image.rows); 26 | 27 | auto &t = targets[id]; 28 | t.trajectories.emplace(frame, box); 29 | if ((frame - t.last_snap) > 5) { 30 | t.snapshots[frame] = image(unnormalize_rect(pad_rect(box, padding), image.cols, image.rows)).clone(); 31 | t.last_snap = frame; 32 | } 33 | } 34 | 35 | record(20); 36 | 37 | writer.write(image); 38 | } 39 | 40 | void TargetStorage::record(int remain) { 41 | for (auto&[id, t]:targets) { 42 | auto dir_path = fs::path(OUTPUT_DIR) / TARGETS_DIR_NAME / to_string(id); 43 | fs::create_directories(dir_path); 44 | 45 | ofstream fp(dir_path / TRAJ_TXT_NAME, ios::app); 46 | fp << right << fixed << setprecision(3); 47 | while (t.trajectories.size() > remain) { 48 | auto &[frame, box] = *t.trajectories.begin(); 49 | fp << setw(6) << frame 50 | << setw(6) << box.x 51 | 
<< setw(6) << box.y 52 | << setw(6) << box.width 53 | << setw(6) << box.height 54 | << setw(6) << endl; 55 | t.trajectories.erase(t.trajectories.begin()); 56 | } 57 | 58 | dir_path /= SNAPSHOTS_DIR_NAME; 59 | fs::create_directories(dir_path); 60 | while (!t.snapshots.empty()) { 61 | auto &[frame, ss] = *t.snapshots.begin(); 62 | if (t.trajectories.empty() || frame < t.trajectories.begin()->first) { 63 | cv::imwrite((dir_path / (to_string(frame) + ".jpg")).string(), ss); 64 | t.snapshots.erase(t.snapshots.begin()); 65 | } 66 | else { 67 | break; 68 | } 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /processing/TargetStorage.h: -------------------------------------------------------------------------------- 1 | #ifndef TARGET_H 2 | #define TARGET_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "Track.h" 10 | #include "util.h" 11 | 12 | class TargetStorage { 13 | public: 14 | explicit TargetStorage(const std::array &orig_dim, int video_FPS); 15 | 16 | virtual ~TargetStorage() { record(0); } 17 | 18 | void update(const std::vector &trks, 19 | int frame, const cv::Mat &image); 20 | 21 | struct Target { 22 | std::map trajectories; 23 | std::map snapshots; 24 | int last_snap = 0; 25 | }; 26 | 27 | const std::map &get() const { return targets; } 28 | 29 | private: 30 | void record(int remain); 31 | 32 | static constexpr float padding = 0.1f; 33 | 34 | std::map targets; 35 | 36 | cv::VideoWriter writer; 37 | }; 38 | 39 | #endif //TARGET_H 40 | -------------------------------------------------------------------------------- /processing/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "Detector.h" 7 | #include "DeepSORT.h" 8 | #include "TargetStorage.h" 9 | 10 | using namespace std; 11 | 12 | int main(int argc, const char *argv[]) { 13 | if (argc < 2 || argc > 3) { 14 | throw 
runtime_error("usage: processing []"); 15 | } 16 | auto input_path = string(argv[1]); 17 | auto scale_factor = argc == 3 ? stoi(argv[2]) : 1; 18 | 19 | cv::VideoCapture cap(input_path); 20 | if (!cap.isOpened()) { 21 | throw runtime_error("Cannot open cv::VideoCapture"); 22 | } 23 | 24 | array orig_dim{int64_t(cap.get(cv::CAP_PROP_FRAME_HEIGHT)), int64_t(cap.get(cv::CAP_PROP_FRAME_WIDTH))}; 25 | array inp_dim; 26 | for (size_t i = 0; i < 2; ++i) { 27 | auto factor = 1 << 5; 28 | inp_dim[i] = (orig_dim[i] / scale_factor / factor + 1) * factor; 29 | } 30 | Detector detector(inp_dim); 31 | DeepSORT tracker(orig_dim); 32 | 33 | TargetStorage repo(orig_dim, static_cast(cap.get(cv::CAP_PROP_FPS))); 34 | 35 | auto image = cv::Mat(); 36 | cv::namedWindow("Output", cv::WINDOW_NORMAL | cv::WINDOW_KEEPRATIO); 37 | while (cap.read(image)) { 38 | auto frame_processed = static_cast(cap.get(cv::CAP_PROP_POS_FRAMES)) - 1; 39 | 40 | auto start = chrono::steady_clock::now(); 41 | 42 | auto dets = detector.detect(image); 43 | auto trks = tracker.update(dets, image); 44 | 45 | repo.update(trks, frame_processed, image); 46 | 47 | stringstream str; 48 | str << "Frame: " << frame_processed << "/" << cap.get(cv::CAP_PROP_FRAME_COUNT) << ", " 49 | << "FPS: " << fixed << setprecision(2) 50 | << 1000.0 / chrono::duration_cast(chrono::steady_clock::now() - start).count(); 51 | draw_text(image, str.str(), {0, 0, 0}, {image.cols, 0}, true); 52 | 53 | for (auto &d:dets) { 54 | draw_bbox(image, d); 55 | } 56 | for (auto &t:trks) { 57 | draw_bbox(image, t.box, to_string(t.id), color_map(t.id)); 58 | draw_trajectories(image, repo.get().at(t.id).trajectories, color_map(t.id)); 59 | } 60 | 61 | cv::imshow("Output", image); 62 | 63 | switch (cv::waitKey(1) & 0xFF) { 64 | case 'q': 65 | return 0; 66 | case ' ': 67 | cv::imwrite(to_string(frame_processed) + ".jpg", image); 68 | break; 69 | default: 70 | break; 71 | } 72 | } 73 | } 
-------------------------------------------------------------------------------- /processing/util.h: -------------------------------------------------------------------------------- 1 | #ifndef UTIL_H 2 | #define UTIL_H 3 | 4 | #include 5 | 6 | namespace { 7 | cv::Rect2f pad_rect(cv::Rect2f rect, float padding) { 8 | rect.x = std::max(0.0f, rect.x - rect.width * padding); 9 | rect.y = std::max(0.0f, rect.y - rect.height * padding); 10 | rect.width = std::min(1 - rect.x, rect.width * (1 + 2 * padding)); 11 | rect.height = std::min(1 - rect.y, rect.height * (1 + 2 * padding)); 12 | 13 | return rect; 14 | } 15 | 16 | cv::Rect2f normalize_rect(cv::Rect2f rect, float w, float h) { 17 | rect.x /= w; 18 | rect.y /= h; 19 | rect.width /= w; 20 | rect.height /= h; 21 | return rect; 22 | } 23 | 24 | cv::Rect2f unnormalize_rect(cv::Rect2f rect, float w, float h) { 25 | rect.x *= w; 26 | rect.y *= h; 27 | rect.width *= w; 28 | rect.height *= h; 29 | return rect; 30 | } 31 | 32 | cv::Scalar color_map(int64_t n) { 33 | auto bit_get = [](int64_t x, int64_t i) { return x & (1 << i); }; 34 | 35 | int64_t r = 0, g = 0, b = 0; 36 | int64_t i = n; 37 | for (int64_t j = 7; j >= 0; --j) { 38 | r |= bit_get(i, 0) << j; 39 | g |= bit_get(i, 1) << j; 40 | b |= bit_get(i, 2) << j; 41 | i >>= 3; 42 | } 43 | return cv::Scalar(b, g, r); 44 | } 45 | 46 | void draw_text(cv::Mat &img, const std::string &str, 47 | const cv::Scalar &color, cv::Point pos, bool reverse = false) { 48 | auto t_size = cv::getTextSize(str, cv::FONT_HERSHEY_PLAIN, 1, 1, nullptr); 49 | cv::Point bottom_left, upper_right; 50 | if (reverse) { 51 | upper_right = pos; 52 | bottom_left = cv::Point(upper_right.x - t_size.width, upper_right.y + t_size.height); 53 | } else { 54 | bottom_left = pos; 55 | upper_right = cv::Point(bottom_left.x + t_size.width, bottom_left.y - t_size.height); 56 | } 57 | 58 | cv::rectangle(img, bottom_left, upper_right, color, -1); 59 | cv::putText(img, str, bottom_left, cv::FONT_HERSHEY_PLAIN, 1, 
cv::Scalar(255, 255, 255) - color); 60 | } 61 | 62 | void draw_bbox(cv::Mat &img, cv::Rect2f bbox, 63 | const std::string &label = "", const cv::Scalar &color = {0, 0, 0}) { 64 | cv::rectangle(img, bbox, color); 65 | if (!label.empty()) { 66 | draw_text(img, label, color, bbox.tl()); 67 | } 68 | } 69 | 70 | void draw_trajectories(cv::Mat &img, const std::map &traj, 71 | const cv::Scalar &color = {0, 0, 0}) { 72 | if (traj.size() < 2) return; 73 | 74 | auto cur = traj.begin()->second; 75 | auto pt1 = cur.br(); 76 | pt1.x -= cur.width / 2; 77 | pt1.x *= img.cols; 78 | pt1.y *= img.rows; 79 | 80 | for (auto it = ++traj.begin(); it != traj.end(); ++it) { 81 | cur = it->second; 82 | auto pt2 = cur.br(); 83 | pt2.x -= cur.width / 2; 84 | pt2.x *= img.cols; 85 | pt2.y *= img.rows; 86 | cv::line(img, pt1, pt2, color); 87 | pt1 = pt2; 88 | } 89 | } 90 | } 91 | 92 | #endif //UTIL_H -------------------------------------------------------------------------------- /snapshots/UI-offline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixu000/libtorch-yolov3-deepsort/a026aa378e01c4e66371e532b5ad69605517d939/snapshots/UI-offline.png -------------------------------------------------------------------------------- /snapshots/UI-online.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixu000/libtorch-yolov3-deepsort/a026aa378e01c4e66371e532b5ad69605517d939/snapshots/UI-online.png -------------------------------------------------------------------------------- /snapshots/detection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixu000/libtorch-yolov3-deepsort/a026aa378e01c4e66371e532b5ad69605517d939/snapshots/detection.png -------------------------------------------------------------------------------- /snapshots/tracking.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/weixu000/libtorch-yolov3-deepsort/a026aa378e01c4e66371e532b5ad69605517d939/snapshots/tracking.png -------------------------------------------------------------------------------- /tracking/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(OpenCV REQUIRED) 2 | find_package(Torch REQUIRED) 3 | 4 | aux_source_directory(src TRACKING_SRCS) 5 | add_library(tracking SHARED ${TRACKING_SRCS}) 6 | 7 | include(GenerateExportHeader) 8 | GENERATE_EXPORT_HEADER(tracking) 9 | 10 | target_link_libraries(tracking PUBLIC ${OpenCV_LIBS} PRIVATE "${TORCH_LIBRARIES}") 11 | target_include_directories(tracking 12 | PUBLIC include ${CMAKE_CURRENT_BINARY_DIR} 13 | PRIVATE src) -------------------------------------------------------------------------------- /tracking/include/DeepSORT.h: -------------------------------------------------------------------------------- 1 | #ifndef DEEPSORT_H 2 | #define DEEPSORT_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "tracking_export.h" 9 | #include "Track.h" 10 | 11 | class Extractor; 12 | 13 | template 14 | class TrackerManager; 15 | 16 | template 17 | class FeatureMetric; 18 | 19 | class TRACKING_EXPORT DeepSORT { 20 | public: 21 | explicit DeepSORT(const std::array &dim); 22 | 23 | ~DeepSORT(); 24 | 25 | std::vector update(const std::vector &detections, cv::Mat ori_img); 26 | 27 | private: 28 | class TrackData; 29 | 30 | std::vector data; 31 | std::unique_ptr extractor; 32 | std::unique_ptr> manager; 33 | std::unique_ptr> feat_metric; 34 | }; 35 | 36 | 37 | #endif //DEEPSORT_H 38 | -------------------------------------------------------------------------------- /tracking/include/SORT.h: -------------------------------------------------------------------------------- 1 | #ifndef SORT_H 2 | #define SORT_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | 
#include "tracking_export.h" 9 | #include "Track.h" 10 | 11 | template 12 | class TrackerManager; 13 | 14 | class TRACKING_EXPORT SORT { 15 | public: 16 | explicit SORT(const std::array &dim); 17 | 18 | ~SORT(); 19 | 20 | std::vector update(const std::vector &dets); 21 | 22 | private: 23 | class TrackData; 24 | 25 | std::vector data; 26 | 27 | std::unique_ptr> manager; 28 | }; 29 | 30 | #endif //SORT_H 31 | -------------------------------------------------------------------------------- /tracking/include/Track.h: -------------------------------------------------------------------------------- 1 | #ifndef DEFINES_H 2 | #define DEFINES_H 3 | 4 | #include 5 | 6 | struct Track { 7 | int id; 8 | cv::Rect2f box; 9 | }; 10 | 11 | 12 | #endif //DEFINES_H 13 | -------------------------------------------------------------------------------- /tracking/src/DeepSORT.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "DeepSORT.h" 4 | #include "Extractor.h" 5 | #include "TrackerManager.h" 6 | #include "nn_matching.h" 7 | 8 | using namespace std; 9 | 10 | struct DeepSORT::TrackData { 11 | KalmanTracker kalman; 12 | FeatureBundle feats; 13 | }; 14 | 15 | DeepSORT::DeepSORT(const array &dim) 16 | : extractor(make_unique()), 17 | manager(make_unique>(data, dim)), 18 | feat_metric(make_unique>(data)) {} 19 | 20 | 21 | DeepSORT::~DeepSORT() = default; 22 | 23 | vector DeepSORT::update(const std::vector &detections, cv::Mat ori_img) { 24 | manager->predict(); 25 | manager->remove_nan(); 26 | 27 | auto matched = manager->update( 28 | detections, 29 | [this, &detections, &ori_img](const std::vector &trk_ids, const std::vector &det_ids) { 30 | vector trks; 31 | for (auto t : trk_ids) { 32 | trks.push_back(data[t].kalman.rect()); 33 | } 34 | vector boxes; 35 | vector dets; 36 | for (auto d:det_ids) { 37 | dets.push_back(detections[d]); 38 | boxes.push_back(ori_img(detections[d])); 39 | } 40 | 41 | auto iou_mat = iou_dist(dets, 
trks); 42 | auto feat_mat = feat_metric->distance(extractor->extract(boxes), trk_ids); 43 | feat_mat.masked_fill_((iou_mat > 0.8f).__ior__(feat_mat > 0.2f), INVALID_DIST); 44 | return feat_mat; 45 | }, 46 | [this, &detections](const std::vector &trk_ids, const std::vector &det_ids) { 47 | vector trks; 48 | for (auto t : trk_ids) { 49 | trks.push_back(data[t].kalman.rect()); 50 | } 51 | vector dets; 52 | for (auto &d:det_ids) { 53 | dets.push_back(detections[d]); 54 | } 55 | auto iou_mat = iou_dist(dets, trks); 56 | iou_mat.masked_fill_(iou_mat > 0.7f, INVALID_DIST); 57 | return iou_mat; 58 | }); 59 | 60 | vector boxes; 61 | vector targets; 62 | for (auto[x, y]:matched) { 63 | targets.emplace_back(x); 64 | boxes.emplace_back(ori_img(detections[y])); 65 | } 66 | feat_metric->update(extractor->extract(boxes), targets); 67 | 68 | manager->remove_deleted(); 69 | 70 | return manager->visible_tracks(); 71 | } 72 | -------------------------------------------------------------------------------- /tracking/src/Extractor.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "Extractor.h" 4 | 5 | namespace nn = torch::nn; 6 | using namespace std; 7 | 8 | namespace { 9 | struct BasicBlockImpl : nn::Module { 10 | explicit BasicBlockImpl(int64_t c_in, int64_t c_out, bool is_downsample = false) { 11 | conv = register_module( 12 | "conv", 13 | nn::Sequential( 14 | nn::Conv2d(nn::Conv2dOptions(c_in, c_out, 3) 15 | .stride(is_downsample ? 
2 : 1) 16 | .padding(1).with_bias(false)), 17 | nn::BatchNorm(c_out), 18 | nn::Functional(torch::relu), 19 | nn::Conv2d(nn::Conv2dOptions(c_out, c_out, 3) 20 | .stride(1).padding(1).with_bias(false)), 21 | nn::BatchNorm(c_out))); 22 | 23 | if (is_downsample) { 24 | downsample = register_module( 25 | "downsample", 26 | nn::Sequential(nn::Conv2d(nn::Conv2dOptions(c_in, c_out, 1) 27 | .stride(2).with_bias(false)), 28 | nn::BatchNorm(c_out))); 29 | } else if (c_in != c_out) { 30 | downsample = register_module( 31 | "downsample", 32 | nn::Sequential(nn::Conv2d(nn::Conv2dOptions(c_in, c_out, 1) 33 | .stride(1).with_bias(false)), 34 | nn::BatchNorm(c_out))); 35 | } 36 | } 37 | 38 | torch::Tensor forward(torch::Tensor x) { 39 | auto y = conv->forward(x); 40 | if (!downsample.is_empty()) { 41 | x = downsample->forward(x); 42 | } 43 | return torch::relu(x + y); 44 | } 45 | 46 | nn::Sequential conv{nullptr}, downsample{nullptr}; 47 | }; 48 | 49 | TORCH_MODULE(BasicBlock); 50 | 51 | void load_tensor(torch::Tensor t, ifstream &fs) { 52 | fs.read(static_cast(t.data_ptr()), t.numel() * sizeof(float)); 53 | } 54 | 55 | void load_Conv2d(nn::Conv2d m, ifstream &fs) { 56 | load_tensor(m->weight, fs); 57 | if (m->options.with_bias()) { 58 | load_tensor(m->bias, fs); 59 | } 60 | } 61 | 62 | void load_BatchNorm(nn::BatchNorm m, ifstream &fs) { 63 | load_tensor(m->weight, fs); 64 | load_tensor(m->bias, fs); 65 | load_tensor(m->running_mean, fs); 66 | load_tensor(m->running_var, fs); 67 | } 68 | 69 | void load_Sequential(nn::Sequential s, ifstream &fs) { 70 | if (s.is_empty()) return; 71 | for (auto &m:s->children()) { 72 | if (auto c = dynamic_pointer_cast(m)) { 73 | load_Conv2d(c, fs); 74 | } else if (auto b = dynamic_pointer_cast(m)) { 75 | load_BatchNorm(b, fs); 76 | } 77 | } 78 | } 79 | 80 | nn::Sequential make_layers(int64_t c_in, int64_t c_out, size_t repeat_times, bool is_downsample = false) { 81 | nn::Sequential ret; 82 | for (size_t i = 0; i < repeat_times; ++i) { 83 | 
ret->push_back(BasicBlock(i == 0 ? c_in : c_out, c_out, i == 0 ? is_downsample : false)); 84 | } 85 | return ret; 86 | } 87 | } 88 | 89 | NetImpl::NetImpl() { 90 | conv1 = register_module("conv1", 91 | nn::Sequential( 92 | nn::Conv2d(nn::Conv2dOptions(3, 64, 3) 93 | .stride(1).padding(1)), 94 | nn::BatchNorm(64), 95 | nn::Functional(torch::relu))); 96 | conv2 = register_module("conv2", nn::Sequential()); 97 | conv2->extend(*make_layers(64, 64, 2, false)); 98 | conv2->extend(*make_layers(64, 128, 2, true)); 99 | conv2->extend(*make_layers(128, 256, 2, true)); 100 | conv2->extend(*make_layers(256, 512, 2, true)); 101 | } 102 | 103 | torch::Tensor NetImpl::forward(torch::Tensor x) { 104 | x = conv1->forward(x); 105 | x = torch::max_pool2d(x, 3, 2, 1); 106 | x = conv2->forward(x); 107 | x = torch::avg_pool2d(x, {8, 4}, 1); 108 | x = x.view({x.size(0), -1}); 109 | x.div_(x.norm(2, 1, true)); 110 | return x; 111 | } 112 | 113 | void NetImpl::load_form(const std::string &bin_path) { 114 | ifstream fs(bin_path, ios_base::binary); 115 | 116 | load_Sequential(conv1, fs); 117 | 118 | for (auto &m:conv2->children()) { 119 | auto b = static_pointer_cast(m); 120 | load_Sequential(b->conv, fs); 121 | load_Sequential(b->downsample, fs); 122 | } 123 | 124 | fs.close(); 125 | } 126 | 127 | Extractor::Extractor() { 128 | net->load_form("weights/ckpt.bin"); 129 | net->to(torch::kCUDA); 130 | net->eval(); 131 | } 132 | 133 | torch::Tensor Extractor::extract(vector input) { 134 | if (input.empty()) { 135 | return torch::empty({0, 512}); 136 | } 137 | 138 | torch::NoGradGuard no_grad; 139 | 140 | static const auto MEAN = torch::tensor({0.485f, 0.456f, 0.406f}).view({1, -1, 1, 1}).cuda(); 141 | static const auto STD = torch::tensor({0.229f, 0.224f, 0.225f}).view({1, -1, 1, 1}).cuda(); 142 | 143 | vector resized; 144 | for (auto &x:input) { 145 | cv::resize(x, x, {64, 128}); 146 | cv::cvtColor(x, x, cv::COLOR_RGB2BGR); 147 | x.convertTo(x, CV_32F, 1.0 / 255); 148 | 
resized.push_back(torch::from_blob(x.data, {128, 64, 3})); 149 | } 150 | auto tensor = torch::stack(resized).cuda().permute({0, 3, 1, 2}).sub_(MEAN).div_(STD); 151 | return net(tensor); 152 | } -------------------------------------------------------------------------------- /tracking/src/Extractor.h: -------------------------------------------------------------------------------- 1 | #ifndef EXTRACTOR_H 2 | #define EXTRACTOR_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | struct NetImpl : torch::nn::Module { 10 | public: 11 | NetImpl(); 12 | 13 | torch::Tensor forward(torch::Tensor x); 14 | 15 | void load_form(const std::string &bin_path); 16 | 17 | private: 18 | torch::nn::Sequential conv1{nullptr}, conv2{nullptr}; 19 | }; 20 | 21 | TORCH_MODULE(Net); 22 | 23 | class Extractor { 24 | public: 25 | Extractor(); 26 | 27 | torch::Tensor extract(std::vector input); // return GPUTensor 28 | 29 | private: 30 | Net net; 31 | }; 32 | 33 | 34 | #endif //EXTRACTOR_H 35 | -------------------------------------------------------------------------------- /tracking/src/Hungarian.cpp: -------------------------------------------------------------------------------- 1 | // Hungarian.cpp: Implementation file for Class HungarianAlgorithm. 2 | // 3 | // This is a C++ wrapper with slight modification of a hungarian algorithm implementation by Markus Buehren. 4 | // The original implementation is a few mex-functions for use in MATLAB, found here: 5 | // http://www.mathworks.com/matlabcentral/fileexchange/6543-functions-for-the-rectangular-assignment-problem 6 | // 7 | // Both this code and the orignal code are published under the BSD license. 8 | // by Cong Ma, 2016 9 | 10 | #include 11 | #include 12 | 13 | #include "Hungarian.h" 14 | 15 | using namespace std; 16 | 17 | //********************************************************// 18 | // A single function wrapper for solving assignment problem. 
19 | //********************************************************// 20 | double HungarianAlgorithm::Solve(vector> &DistMatrix, vector &Assignment) { 21 | if (DistMatrix.empty()) { 22 | Assignment.clear(); 23 | return 0; 24 | } 25 | auto nRows = DistMatrix.size(); 26 | auto nCols = DistMatrix[0].size(); 27 | 28 | auto distMatrixIn = new double[nRows * nCols]; 29 | auto assignment = new int[nRows]; 30 | auto cost = 0.0; 31 | 32 | // Fill in the distMatrixIn. Mind the index is "i + nRows * j". 33 | // Here the cost matrix of size MxN is defined as a double precision array of N*M elements. 34 | // In the solving functions matrices are seen to be saved MATLAB-internally in row-order. 35 | // (i.e. the matrix [1 2; 3 4] will be stored as a vector [1 3 2 4], NOT [1 2 3 4]). 36 | for (unsigned int i = 0; i < nRows; i++) 37 | for (unsigned int j = 0; j < nCols; j++) 38 | distMatrixIn[i + nRows * j] = DistMatrix[i][j]; 39 | 40 | // call solving function 41 | assignmentoptimal(assignment, &cost, distMatrixIn, nRows, nCols); 42 | 43 | Assignment.clear(); 44 | for (unsigned int r = 0; r < nRows; r++) 45 | Assignment.push_back(assignment[r]); 46 | 47 | delete[] distMatrixIn; 48 | delete[] assignment; 49 | return cost; 50 | } 51 | 52 | 53 | //********************************************************// 54 | // Solve optimal solution for assignment problem using Munkres algorithm, also known as Hungarian Algorithm. 
//********************************************************//
// Core Munkres solver. Works on a column-major working copy of the cost
// matrix; assignment[row] receives the assigned column (-1 if none) and
// *cost the total cost. The step2a..step5 functions below implement the
// classic state machine of the algorithm.
void HungarianAlgorithm::assignmentoptimal(int *assignment, double *cost, double *distMatrixIn, int nOfRows,
                                           int nOfColumns) {
    double *distMatrix, *distMatrixTemp, *distMatrixEnd, *columnEnd, value, minValue;
    bool *coveredColumns, *coveredRows, *starMatrix, *newStarMatrix, *primeMatrix;
    int nOfElements, minDim, row, col;

    /* initialization */
    *cost = 0;
    for (row = 0; row < nOfRows; row++)
        assignment[row] = -1;

    /* generate working copy of distance Matrix */
    /* check if all matrix elements are positive */
    nOfElements = nOfRows * nOfColumns;
    distMatrix = (double *) malloc(nOfElements * sizeof(double));
    distMatrixEnd = distMatrix + nOfElements;

    for (row = 0; row < nOfElements; row++) {
        value = distMatrixIn[row];
        if (value < 0)
            // Note: only warns; the algorithm still proceeds with the value.
            cerr << "All matrix elements have to be non-negative." << endl;
        distMatrix[row] = value;
    }


    /* memory allocation */
    coveredColumns = (bool *) calloc(nOfColumns, sizeof(bool));
    coveredRows = (bool *) calloc(nOfRows, sizeof(bool));
    starMatrix = (bool *) calloc(nOfElements, sizeof(bool));
    primeMatrix = (bool *) calloc(nOfElements, sizeof(bool));
    newStarMatrix = (bool *) calloc(nOfElements, sizeof(bool)); /* used in step4 */

    /* preliminary steps */
    if (nOfRows <= nOfColumns) {
        minDim = nOfRows;

        for (row = 0; row < nOfRows; row++) {
            /* find the smallest element in the row */
            distMatrixTemp = distMatrix + row;
            minValue = *distMatrixTemp;
            distMatrixTemp += nOfRows;
            while (distMatrixTemp < distMatrixEnd) {
                value = *distMatrixTemp;
                if (value < minValue)
                    minValue = value;
                distMatrixTemp += nOfRows;
            }

            /* subtract the smallest element from each element of the row */
            distMatrixTemp = distMatrix + row;
            while (distMatrixTemp < distMatrixEnd) {
                *distMatrixTemp -= minValue;
                distMatrixTemp += nOfRows;
            }
        }

        /* Steps 1 and 2a: star one zero per row/column */
        for (row = 0; row < nOfRows; row++)
            for (col = 0; col < nOfColumns; col++)
                if (fabs(distMatrix[row + nOfRows * col]) < DBL_EPSILON)
                    if (!coveredColumns[col]) {
                        starMatrix[row + nOfRows * col] = true;
                        coveredColumns[col] = true;
                        break;
                    }
    } else /* if(nOfRows > nOfColumns) */
    {
        minDim = nOfColumns;

        for (col = 0; col < nOfColumns; col++) {
            /* find the smallest element in the column */
            distMatrixTemp = distMatrix + nOfRows * col;
            columnEnd = distMatrixTemp + nOfRows;

            minValue = *distMatrixTemp++;
            while (distMatrixTemp < columnEnd) {
                value = *distMatrixTemp++;
                if (value < minValue)
                    minValue = value;
            }

            /* subtract the smallest element from each element of the column */
            distMatrixTemp = distMatrix + nOfRows * col;
            while (distMatrixTemp < columnEnd)
                *distMatrixTemp++ -= minValue;
        }

        /* Steps 1 and 2a */
        for (col = 0; col < nOfColumns; col++)
            for (row = 0; row < nOfRows; row++)
                if (fabs(distMatrix[row + nOfRows * col]) < DBL_EPSILON)
                    if (!coveredRows[row]) {
                        starMatrix[row + nOfRows * col] = true;
                        coveredColumns[col] = true;
                        coveredRows[row] = true;
                        break;
                    }
        for (row = 0; row < nOfRows; row++)
            coveredRows[row] = false;

    }

    /* move to step 2b */
    step2b(assignment, distMatrix, starMatrix, newStarMatrix, primeMatrix, coveredColumns, coveredRows, nOfRows,
           nOfColumns, minDim);

    /* compute cost and remove invalid assignments */
    computeassignmentcost(assignment, cost, distMatrixIn, nOfRows);

    /* free allocated memory */
    free(distMatrix);
    free(coveredColumns);
    free(coveredRows);
    free(starMatrix);
    free(primeMatrix);
    free(newStarMatrix);

    return;
}

// Read the final assignment out of the star matrix: assignment[row] = starred column.
void HungarianAlgorithm::buildassignmentvector(int *assignment, bool *starMatrix, int nOfRows, int nOfColumns) {
    int row, col;

    for (row = 0; row < nOfRows; row++)
        for (col = 0; col < nOfColumns; col++)
            if (starMatrix[row + nOfRows * col]) {
#ifdef ONE_INDEXING
                assignment[row] = col + 1; /* MATLAB-Indexing */
#else
                assignment[row] = col;
#endif
                break;
            }
}

// Sum the ORIGINAL (unreduced) costs of the chosen assignment into *cost.
void HungarianAlgorithm::computeassignmentcost(int *assignment, double *cost, double *distMatrix, int nOfRows) {
    int row, col;

    for (row = 0; row < nOfRows; row++) {
        col = assignment[row];
        if (col >= 0)
            *cost += distMatrix[row + nOfRows * col];
    }
}

// Step 2a: cover every column that contains a starred zero, then go to 2b.
void HungarianAlgorithm::step2a(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix,
                                bool *primeMatrix, bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns,
                                int minDim) {
    bool *starMatrixTemp, *columnEnd;
    int col;

    /* cover every column containing a starred zero */
    for (col = 0; col < nOfColumns; col++) {
        starMatrixTemp = starMatrix + nOfRows * col;
        columnEnd = starMatrixTemp + nOfRows;
        while (starMatrixTemp < columnEnd) {
            if (*starMatrixTemp++) {
                coveredColumns[col] = true;
                break;
            }
        }
    }

    /* move to step 3 */
    step2b(assignment, distMatrix, starMatrix, newStarMatrix, primeMatrix, coveredColumns, coveredRows, nOfRows,
           nOfColumns, minDim);
}

// Step 2b: if all minDim columns are covered we are done; otherwise continue.
void HungarianAlgorithm::step2b(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix,
                                bool *primeMatrix, bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns,
                                int minDim) {
    int col, nOfCoveredColumns;

    /* count covered columns */
    nOfCoveredColumns = 0;
    for (col = 0; col < nOfColumns; col++)
        if (coveredColumns[col])
            nOfCoveredColumns++;

    if (nOfCoveredColumns == minDim) {
        /* algorithm finished */
        buildassignmentvector(assignment, starMatrix, nOfRows, nOfColumns);
    } else {
        /* move to step 3 */
        step3(assignment, distMatrix, starMatrix, newStarMatrix, primeMatrix, coveredColumns, coveredRows, nOfRows,
              nOfColumns, minDim);
    }

}

// Step 3: prime uncovered zeros; either find an augmenting path start (step 4)
// or run out of zeros and adjust the matrix (step 5).
void
HungarianAlgorithm::step3(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix, bool *primeMatrix,
                          bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns, int minDim) {
    bool zerosFound;
    int row, col, starCol;

    zerosFound = true;
    while (zerosFound) {
        zerosFound = false;
        for (col = 0; col < nOfColumns; col++)
            if (!coveredColumns[col])
                for (row = 0; row < nOfRows; row++)
                    if ((!coveredRows[row]) && (fabs(distMatrix[row + nOfRows * col]) < DBL_EPSILON)) {
                        /* prime zero */
                        primeMatrix[row + nOfRows * col] = true;

                        /* find starred zero in current row */
                        for (starCol = 0; starCol < nOfColumns; starCol++)
                            if (starMatrix[row + nOfRows * starCol])
                                break;

                        if (starCol == nOfColumns) /* no starred zero found */
                        {
                            /* move to step 4 */
                            step4(assignment, distMatrix, starMatrix, newStarMatrix, primeMatrix, coveredColumns,
                                  coveredRows, nOfRows, nOfColumns, minDim, row, col);
                            return;
                        } else {
                            coveredRows[row] = true;
                            coveredColumns[starCol] = false;
                            zerosFound = true;
                            break;
                        }
                    }
    }

    /* move to step 5 */
    step5(assignment, distMatrix, starMatrix, newStarMatrix, primeMatrix, coveredColumns, coveredRows, nOfRows,
          nOfColumns, minDim);
}

// Step 4: build the alternating star/prime augmenting path starting at
// (row, col), flip it, then clear primes and row covers and go back to 2a.
void
HungarianAlgorithm::step4(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix, bool *primeMatrix,
                          bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns, int minDim, int row,
                          int col) {
    int n, starRow, starCol, primeRow, primeCol;
    int nOfElements = nOfRows * nOfColumns;

    /* generate temporary copy of starMatrix */
    for (n = 0; n < nOfElements; n++)
        newStarMatrix[n] = starMatrix[n];

    /* star current zero */
    newStarMatrix[row + nOfRows * col] = true;

    /* find starred zero in current column */
    starCol = col;
    for (starRow = 0; starRow < nOfRows; starRow++)
        if (starMatrix[starRow + nOfRows * starCol])
            break;

    while (starRow < nOfRows) {
        /* unstar the starred zero */
        newStarMatrix[starRow + nOfRows * starCol] = false;

        /* find primed zero in current row */
        primeRow = starRow;
        for (primeCol = 0; primeCol < nOfColumns; primeCol++)
            if (primeMatrix[primeRow + nOfRows * primeCol])
                break;

        /* star the primed zero */
        newStarMatrix[primeRow + nOfRows * primeCol] = true;

        /* find starred zero in current column */
        starCol = primeCol;
        for (starRow = 0; starRow < nOfRows; starRow++)
            if (starMatrix[starRow + nOfRows * starCol])
                break;
    }

    /* use temporary copy as new starMatrix */
    /* delete all primes, uncover all rows */
    for (n = 0; n < nOfElements; n++) {
        primeMatrix[n] = false;
        starMatrix[n] = newStarMatrix[n];
    }
    for (n = 0; n < nOfRows; n++)
        coveredRows[n] = false;

    /* move to step 2a */
    step2a(assignment, distMatrix, starMatrix, newStarMatrix, primeMatrix, coveredColumns, coveredRows, nOfRows,
           nOfColumns, minDim);
}

// Step 5: no uncovered zeros left — shift the matrix by the smallest uncovered
// value to create new zeros, then return to step 3.
void
HungarianAlgorithm::step5(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix, bool *primeMatrix,
                          bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns, int minDim) {
    double h, value;
    int row, col;

    /* find smallest uncovered element h */
    h = DBL_MAX;
    for (row = 0; row < nOfRows; row++)
        if (!coveredRows[row])
            for (col = 0; col < nOfColumns; col++)
                if (!coveredColumns[col]) {
                    value = distMatrix[row + nOfRows * col];
                    if (value < h)
                        h = value;
                }

    /* add h to each covered row */
    for (row = 0; row < nOfRows; row++)
        if (coveredRows[row])
            for (col = 0; col < nOfColumns; col++)
                distMatrix[row + nOfRows * col] += h;

    /* subtract h from each uncovered column */
    for (col = 0; col < nOfColumns; col++)
        if (!coveredColumns[col])
            for (row = 0; row < nOfRows; row++)
                distMatrix[row + nOfRows * col] -= h;

    /* move to step 3 */
    step3(assignment, distMatrix, starMatrix, newStarMatrix, primeMatrix, coveredColumns, coveredRows, nOfRows,
          nOfColumns, minDim);
}

// ==================== /tracking/src/Hungarian.h ====================
// Hungarian.h: Header file for Class HungarianAlgorithm.
//
// This is a C++ wrapper with slight modification of a hungarian algorithm implementation by Markus Buehren.
// The original implementation is a few mex-functions for use in MATLAB, found here:
// http://www.mathworks.com/matlabcentral/fileexchange/6543-functions-for-the-rectangular-assignment-problem
//
// Both this code and the original code are published under the BSD license.
8 | // by Cong Ma, 2016 9 | 10 | #include 11 | #include 12 | 13 | 14 | class HungarianAlgorithm { 15 | public: 16 | double Solve(std::vector> &DistMatrix, std::vector &Assignment); 17 | 18 | private: 19 | void assignmentoptimal(int *assignment, double *cost, double *distMatrix, int nOfRows, int nOfColumns); 20 | 21 | void buildassignmentvector(int *assignment, bool *starMatrix, int nOfRows, int nOfColumns); 22 | 23 | void computeassignmentcost(int *assignment, double *cost, double *distMatrix, int nOfRows); 24 | 25 | void step2a(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix, bool *primeMatrix, 26 | bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns, int minDim); 27 | 28 | void step2b(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix, bool *primeMatrix, 29 | bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns, int minDim); 30 | 31 | void step3(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix, bool *primeMatrix, 32 | bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns, int minDim); 33 | 34 | void step4(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix, bool *primeMatrix, 35 | bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns, int minDim, int row, int col); 36 | 37 | void step5(int *assignment, double *distMatrix, bool *starMatrix, bool *newStarMatrix, bool *primeMatrix, 38 | bool *coveredColumns, bool *coveredRows, int nOfRows, int nOfColumns, int minDim); 39 | }; 40 | -------------------------------------------------------------------------------- /tracking/src/KalmanTracker.cpp: -------------------------------------------------------------------------------- 1 | #include "KalmanTracker.h" 2 | 3 | using namespace cv; 4 | 5 | namespace { 6 | // Convert bounding box from [cx,cy,s,r] to [x,y,w,h] style. 
7 | cv::Rect2f get_rect_xysr(const Mat &xysr) { 8 | auto cx = xysr.at(0, 0), cy = xysr.at(1, 0), s = xysr.at(2, 0), r = xysr.at(3, 0); 9 | float w = sqrt(s * r); 10 | float h = s / w; 11 | float x = (cx - w / 2); 12 | float y = (cy - h / 2); 13 | 14 | return cv::Rect2f(x, y, w, h); 15 | } 16 | } 17 | 18 | int KalmanTracker::count = 0; 19 | 20 | KalmanTracker::KalmanTracker() { 21 | int stateNum = 7; 22 | int measureNum = 4; 23 | kf = KalmanFilter(stateNum, measureNum, 0); 24 | 25 | measurement = Mat::zeros(measureNum, 1, CV_32F); 26 | 27 | kf.transitionMatrix = (Mat_(stateNum, stateNum) 28 | << 29 | 1, 0, 0, 0, 1, 0, 0, 30 | 0, 1, 0, 0, 0, 1, 0, 31 | 0, 0, 1, 0, 0, 0, 1, 32 | 0, 0, 0, 1, 0, 0, 0, 33 | 0, 0, 0, 0, 1, 0, 0, 34 | 0, 0, 0, 0, 0, 1, 0, 35 | 0, 0, 0, 0, 0, 0, 1); 36 | 37 | setIdentity(kf.measurementMatrix); 38 | setIdentity(kf.processNoiseCov, Scalar::all(1e-2)); 39 | setIdentity(kf.measurementNoiseCov, Scalar::all(1e-1)); 40 | setIdentity(kf.errorCovPost, Scalar::all(1)); 41 | } 42 | 43 | void KalmanTracker::init(cv::Rect2f initRect) { 44 | // initialize state vector with bounding box in [cx,cy,s,r] style 45 | kf.statePost.at(0, 0) = initRect.x + initRect.width / 2; 46 | kf.statePost.at(1, 0) = initRect.y + initRect.height / 2; 47 | kf.statePost.at(2, 0) = initRect.area(); 48 | kf.statePost.at(3, 0) = initRect.width / initRect.height; 49 | } 50 | 51 | // Predict the estimated bounding box. 52 | void KalmanTracker::predict() { 53 | ++time_since_update; 54 | 55 | kf.predict(); 56 | } 57 | 58 | // Update the state vector with observed bounding box. 
59 | void KalmanTracker::update(cv::Rect2f stateMat) { 60 | time_since_update = 0; 61 | ++hits; 62 | 63 | if (_state == TrackState::Tentative && hits > n_init) { 64 | _state = TrackState::Confirmed; 65 | _id = count++; 66 | } 67 | 68 | // measurement 69 | measurement.at(0, 0) = stateMat.x + stateMat.width / 2; 70 | measurement.at(1, 0) = stateMat.y + stateMat.height / 2; 71 | measurement.at(2, 0) = stateMat.area(); 72 | measurement.at(3, 0) = stateMat.width / stateMat.height; 73 | 74 | // update 75 | kf.correct(measurement); 76 | } 77 | 78 | void KalmanTracker::miss() { 79 | if (_state == TrackState::Tentative) { 80 | _state = TrackState::Deleted; 81 | } else if (time_since_update > max_age) { 82 | _state = TrackState::Deleted; 83 | } 84 | } 85 | 86 | // Return the current state vector 87 | cv::Rect2f KalmanTracker::rect() const { 88 | return get_rect_xysr(kf.statePost); 89 | } 90 | -------------------------------------------------------------------------------- /tracking/src/KalmanTracker.h: -------------------------------------------------------------------------------- 1 | #ifndef KALMAN_H 2 | #define KALMAN_H 3 | 4 | #include "opencv2/video/tracking.hpp" 5 | 6 | enum class TrackState { 7 | Tentative, 8 | Confirmed, 9 | Deleted 10 | }; 11 | 12 | 13 | // This class represents the internel state of individual tracked objects observed as bounding box. 
14 | class KalmanTracker { 15 | public: 16 | KalmanTracker(); 17 | 18 | explicit KalmanTracker(cv::Rect2f initRect) : KalmanTracker() { init(initRect); } 19 | 20 | void init(cv::Rect2f initRect); 21 | 22 | void predict(); 23 | 24 | void update(cv::Rect2f stateMat); 25 | 26 | void miss(); 27 | 28 | cv::Rect2f rect() const; 29 | 30 | TrackState state() const { return _state; } 31 | 32 | int id() const { return _id; } 33 | 34 | private: 35 | static const auto max_age = 30; 36 | static const auto n_init = 3; 37 | 38 | static int count; 39 | 40 | TrackState _state = TrackState::Tentative; 41 | 42 | int _id = -1; 43 | 44 | int time_since_update = 0; 45 | int hits = 0; 46 | 47 | cv::KalmanFilter kf; 48 | cv::Mat measurement; 49 | }; 50 | 51 | #endif //KALMAN_H -------------------------------------------------------------------------------- /tracking/src/SORT.cpp: -------------------------------------------------------------------------------- 1 | #include "SORT.h" 2 | #include "TrackerManager.h" 3 | #include "KalmanTracker.h" 4 | #include "nn_matching.h" 5 | 6 | using namespace std; 7 | 8 | struct SORT::TrackData { 9 | KalmanTracker kalman; 10 | }; 11 | 12 | SORT::SORT(const array &dim) 13 | : manager(make_unique>(data, dim)) {} 14 | 15 | SORT::~SORT() = default; 16 | 17 | vector SORT::update(const vector &detections) { 18 | manager->predict(); 19 | manager->remove_nan(); 20 | 21 | auto metric = [this, &detections](const std::vector &trk_ids, const std::vector &det_ids) { 22 | vector trks; 23 | for (auto t : trk_ids) { 24 | trks.push_back(data[t].kalman.rect()); 25 | } 26 | vector dets; 27 | for (auto &d:det_ids) { 28 | dets.push_back(detections[d]); 29 | } 30 | auto iou_mat = iou_dist(dets, trks); 31 | iou_mat.masked_fill_(iou_mat > 0.7f, INVALID_DIST); 32 | return iou_mat; 33 | }; 34 | manager->update(detections, metric, metric); 35 | manager->remove_deleted(); 36 | 37 | return manager->visible_tracks(); 38 | } 39 | 
-------------------------------------------------------------------------------- /tracking/src/TrackerManager.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "TrackerManager.h" 5 | #include "Hungarian.h" 6 | 7 | using namespace std; 8 | using namespace cv; 9 | 10 | void associate_detections_to_trackers_idx(const DistanceMetricFunc &metric, 11 | vector &unmatched_trks, 12 | vector &unmatched_dets, 13 | vector> &matched) { 14 | auto dist = metric(unmatched_trks, unmatched_dets); 15 | auto dist_a = dist.accessor(); 16 | auto dist_v = vector>(dist.size(0), vector(dist.size(1))); 17 | for (size_t i = 0; i < dist.size(0); ++i) { 18 | for (size_t j = 0; j < dist.size(1); ++j) { 19 | dist_v[i][j] = dist_a[i][j]; 20 | } 21 | } 22 | 23 | vector assignment; 24 | HungarianAlgorithm().Solve(dist_v, assignment); 25 | 26 | // filter out matched with low IOU 27 | for (size_t i = 0; i < assignment.size(); ++i) { 28 | if (assignment[i] == -1) // pass over invalid values 29 | continue; 30 | if (dist_v[i][assignment[i]] > INVALID_DIST / 10) { 31 | assignment[i] = -1; 32 | } else { 33 | matched.emplace_back(make_tuple(unmatched_trks[i], unmatched_dets[assignment[i]])); 34 | } 35 | } 36 | 37 | for (size_t i = 0; i < assignment.size(); ++i) { 38 | if (assignment[i] != -1) { 39 | unmatched_trks[i] = -1; 40 | } 41 | } 42 | unmatched_trks.erase(remove_if(unmatched_trks.begin(), unmatched_trks.end(), 43 | [](int i) { return i == -1; }), 44 | unmatched_trks.end()); 45 | 46 | sort(assignment.begin(), assignment.end()); 47 | vector unmatched_dets_new; 48 | set_difference(unmatched_dets.begin(), unmatched_dets.end(), 49 | assignment.begin(), assignment.end(), 50 | inserter(unmatched_dets_new, unmatched_dets_new.begin())); 51 | unmatched_dets = move(unmatched_dets_new); 52 | } 53 | -------------------------------------------------------------------------------- /tracking/src/TrackerManager.h: 
-------------------------------------------------------------------------------- 1 | #ifndef TRACKER_H 2 | #define TRACKER_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "Track.h" 9 | #include "KalmanTracker.h" 10 | 11 | using DistanceMetricFunc = std::function< 12 | torch::Tensor(const std::vector &trk_ids, const std::vector &det_ids)>; 13 | 14 | const float INVALID_DIST = 1E3f; 15 | 16 | void associate_detections_to_trackers_idx(const DistanceMetricFunc &metric, 17 | std::vector &unmatched_trks, 18 | std::vector &unmatched_dets, 19 | std::vector> &matched); 20 | 21 | template 22 | class TrackerManager { 23 | public: 24 | explicit TrackerManager(std::vector &data, const std::array &dim) 25 | : data(data), img_box(0, 0, dim[1], dim[0]) {} 26 | 27 | void predict() { 28 | for (auto &t:data) { 29 | t.kalman.predict(); 30 | } 31 | } 32 | 33 | void remove_nan() { 34 | data.erase(remove_if(data.begin(), data.end(), 35 | [](const TrackData &t) { 36 | auto bbox = t.kalman.rect(); 37 | return std::isnan(bbox.x) || std::isnan(bbox.y) || 38 | std::isnan(bbox.width) || std::isnan(bbox.height); 39 | }), 40 | data.end()); 41 | } 42 | 43 | void remove_deleted() { 44 | data.erase(remove_if(data.begin(), data.end(), 45 | [this](const TrackData &t) { 46 | return t.kalman.state() == TrackState::Deleted; 47 | }), data.end()); 48 | } 49 | 50 | std::vector> 51 | update(const std::vector &dets, 52 | const DistanceMetricFunc &confirmed_metric, const DistanceMetricFunc &unconfirmed_metric) { 53 | std::vector unmatched_trks; 54 | for (size_t i = 0; i < data.size(); ++i) { 55 | if (data[i].kalman.state() == TrackState::Confirmed) { 56 | unmatched_trks.emplace_back(i); 57 | } 58 | } 59 | 60 | std::vector unmatched_dets(dets.size()); 61 | iota(unmatched_dets.begin(), unmatched_dets.end(), 0); 62 | 63 | std::vector> matched; 64 | 65 | associate_detections_to_trackers_idx(confirmed_metric, unmatched_trks, unmatched_dets, matched); 66 | 67 | for (size_t i = 0; i < data.size(); ++i) 
{ 68 | if (data[i].kalman.state() == TrackState::Tentative) { 69 | unmatched_trks.emplace_back(i); 70 | } 71 | } 72 | 73 | associate_detections_to_trackers_idx(unconfirmed_metric, unmatched_trks, unmatched_dets, matched); 74 | 75 | for (auto i : unmatched_trks) { 76 | data[i].kalman.miss(); 77 | } 78 | 79 | // update matched trackers with assigned detections. 80 | // each prediction is corresponding to a manager 81 | for (auto[x, y] : matched) { 82 | data[x].kalman.update(dets[y]); 83 | } 84 | 85 | // create and initialise new trackers for unmatched detections 86 | for (auto umd : unmatched_dets) { 87 | matched.emplace_back(data.size(), umd); 88 | auto t = TrackData{}; 89 | t.kalman.init(dets[umd]); 90 | data.emplace_back(t); 91 | } 92 | 93 | return matched; 94 | } 95 | 96 | std::vector visible_tracks() { 97 | std::vector ret; 98 | for (auto &t : data) { 99 | auto bbox = t.kalman.rect(); 100 | if (t.kalman.state() == TrackState::Confirmed && 101 | img_box.contains(bbox.tl()) && img_box.contains(bbox.br())) { 102 | Track res{t.kalman.id(), bbox}; 103 | ret.push_back(res); 104 | } 105 | } 106 | return ret; 107 | } 108 | 109 | private: 110 | std::vector &data; 111 | const cv::Rect2f img_box; 112 | }; 113 | 114 | #endif //TRACKER_H 115 | -------------------------------------------------------------------------------- /tracking/src/nn_matching.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "nn_matching.h" 4 | 5 | using namespace std; 6 | using namespace cv; 7 | 8 | namespace { 9 | float iou(const Rect2f &bb_test, const Rect2f &bb_gt) { 10 | auto in = (bb_test & bb_gt).area(); 11 | auto un = bb_test.area() + bb_gt.area() - in; 12 | 13 | if (un < DBL_EPSILON) 14 | return 0; 15 | 16 | return in / un; 17 | } 18 | } 19 | 20 | torch::Tensor iou_dist(const vector &dets, const vector &trks) { 21 | auto trk_num = trks.size(); 22 | auto det_num = dets.size(); 23 | auto dist = torch::empty({int64_t(trk_num), 
int64_t(det_num)}); 24 | for (size_t i = 0; i < trk_num; i++) // compute iou matrix as a distance matrix 25 | { 26 | for (size_t j = 0; j < det_num; j++) { 27 | dist[i][j] = 1 - iou(trks[i], dets[j]); 28 | } 29 | } 30 | return dist; 31 | } 32 | -------------------------------------------------------------------------------- /tracking/src/nn_matching.h: -------------------------------------------------------------------------------- 1 | #ifndef NN_MATCHING_H 2 | #define NN_MATCHING_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | torch::Tensor iou_dist(const std::vector &dets, const std::vector &trks); 10 | 11 | // save features of the track in GPU 12 | class FeatureBundle { 13 | public: 14 | FeatureBundle() : full(false), next(0), store(torch::empty({budget, feat_dim}).cuda()) {} 15 | 16 | void clear() { 17 | next = 0; 18 | full = false; 19 | } 20 | 21 | bool empty() const { 22 | return next == 0 && !full; 23 | } 24 | 25 | void add(torch::Tensor feat) { 26 | if (next == budget) { 27 | full = true; 28 | next = 0; 29 | } 30 | store[next++] = feat; 31 | } 32 | 33 | torch::Tensor get() const { 34 | return full ? 
store : store.slice(0, 0, next); 35 | } 36 | 37 | private: 38 | static const int64_t budget = 100, feat_dim = 512; 39 | 40 | torch::Tensor store; 41 | 42 | bool full; 43 | int64_t next; 44 | }; 45 | 46 | template 47 | class FeatureMetric { 48 | public: 49 | explicit FeatureMetric(std::vector &data) : data(data) {} 50 | 51 | torch::Tensor distance(torch::Tensor features, const std::vector &targets) { 52 | auto dist = torch::empty({int64_t(targets.size()), features.size(0)}); 53 | if (features.size(0)) { 54 | for (size_t i = 0; i < targets.size(); ++i) { 55 | dist[i] = nn_cosine_distance(data[targets[i]].feats.get(), features); 56 | } 57 | } 58 | 59 | return dist; 60 | } 61 | 62 | void update(torch::Tensor feats, const std::vector &targets) { 63 | for (size_t i = 0; i < targets.size(); ++i) { 64 | data[targets[i]].feats.add(feats[i]); 65 | } 66 | } 67 | 68 | private: 69 | std::vector &data; 70 | 71 | torch::Tensor nn_cosine_distance(torch::Tensor x, torch::Tensor y) { 72 | return std::get<0>(torch::min(1 - torch::matmul(x, y.t()), 0)).cpu(); 73 | } 74 | }; 75 | 76 | #endif //NN_MATCHING_H 77 | --------------------------------------------------------------------------------