├── asset ├── teaser.png ├── ov_ch_50k.png ├── ov_ch_gt.png ├── ov_ch_in.png ├── path_ch_50k.png ├── path_ch_gt.png ├── path_ch_in.png ├── vec_39693_t.png ├── vec_39693_in.png ├── vec_39693_out.png ├── vec_39693_overlap.png └── vec_39693.svg ├── .gitignore ├── gco ├── build_linux.sh ├── build_win.bat ├── CMakeLists.txt ├── instances.inc ├── CHANGES.TXT ├── LinkedBlockList.cpp ├── LinkedBlockList.h ├── graph.cpp ├── main.cpp ├── QPBO_postprocessing.cpp ├── energy.h ├── block.h ├── QPBO_maxflow.cpp ├── graph.h └── maxflow.cpp ├── train.bat ├── models.py ├── main.py ├── test.bat ├── config.py ├── utils.py ├── README.md ├── trainer.py ├── ops.py ├── data_qdraw.py ├── data_ch.py ├── data_kanji.py └── data_line.py /asset/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/teaser.png -------------------------------------------------------------------------------- /asset/ov_ch_50k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/ov_ch_50k.png -------------------------------------------------------------------------------- /asset/ov_ch_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/ov_ch_gt.png -------------------------------------------------------------------------------- /asset/ov_ch_in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/ov_ch_in.png -------------------------------------------------------------------------------- /asset/path_ch_50k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/path_ch_50k.png 
-------------------------------------------------------------------------------- /asset/path_ch_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/path_ch_gt.png -------------------------------------------------------------------------------- /asset/path_ch_in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/path_ch_in.png -------------------------------------------------------------------------------- /asset/vec_39693_t.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/vec_39693_t.png -------------------------------------------------------------------------------- /asset/vec_39693_in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/vec_39693_in.png -------------------------------------------------------------------------------- /asset/vec_39693_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/vec_39693_out.png -------------------------------------------------------------------------------- /asset/vec_39693_overlap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/vectornet/HEAD/asset/vec_39693_overlap.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.suo 3 | *.o 4 | *~ 5 | *.sql 6 | *.db 7 | *.opendb 8 | *.user 9 | log/ 10 | data/ 11 | potrace/ 12 | gco/build -------------------------------------------------------------------------------- 
/gco/build_linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p build 4 | cd build 5 | 6 | cmake -G"Unix Makefiles" -D CMAKE_BUILD_TYPE="Release" .. 7 | 8 | make -j4 -------------------------------------------------------------------------------- /gco/build_win.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set CMAKE_GENERATOR="Visual Studio 14 2015 Win64" 4 | set MSBUILD="%ProgramFiles(x86)%\MSBuild\14.0\Bin\amd64\MSBuild.exe" 5 | 6 | mkdir build 7 | cd build 8 | 9 | cmake -G%CMAKE_GENERATOR% .. 10 | 11 | %MSBUILD% /p:Configuration=Release gco.sln -------------------------------------------------------------------------------- /gco/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | 3 | project(gco) 4 | 5 | SET(SOURCES 6 | LinkedBlockList.cpp 7 | graph.cpp 8 | maxflow.cpp 9 | QPBO.cpp 10 | QPBO_extra.cpp 11 | QPBO_maxflow.cpp 12 | QPBO_postprocessing.cpp 13 | GCoptimization.cpp 14 | main.cpp 15 | ) 16 | 17 | SET(HEADERS 18 | LinkedBlockList.h 19 | graph.h 20 | block.h 21 | QPBO.h 22 | GCoptimization.h 23 | ) 24 | 25 | include_directories("./") 26 | add_executable(gco ${SOURCES} ${HEADERS}) 27 | -------------------------------------------------------------------------------- /train.bat: -------------------------------------------------------------------------------- 1 | REM pathnet 2 | python main.py --archi=path --tag=win --dataset=line 3 | python main.py --archi=path --tag=win --dataset=ch 4 | python main.py --archi=path --tag=win --dataset=kanji 5 | python main.py --archi=path --tag=win --dataset=baseball --height=128 --width=128 --lr=0.002 6 | python main.py --archi=path --tag=win --dataset=cat --height=128 --width=128 --lr=0.002 7 | 8 | REM overlapnet 9 | python main.py --archi=overlap --tag=win --dataset=line 10 | python main.py 
--archi=overlap --tag=win --dataset=ch 11 | python main.py --archi=overlap --tag=win --dataset=kanji 12 | python main.py --archi=overlap --tag=win --dataset=baseball --height=128 --width=128 --lr=0.002 13 | python main.py --archi=overlap --tag=win --dataset=cat --height=128 --width=128 --lr=0.002 -------------------------------------------------------------------------------- /gco/instances.inc: -------------------------------------------------------------------------------- 1 | #include "QPBO.h" 2 | 3 | #ifdef _MSC_VER 4 | #pragma warning(disable: 4661) 5 | #endif 6 | 7 | // Instantiations 8 | 9 | template class QPBO; 10 | template class QPBO; 11 | template class QPBO; 12 | 13 | template <> 14 | inline void QPBO::get_type_information(const char*& type_name, const char*& type_format) 15 | { 16 | type_name = "int"; 17 | type_format = "d"; 18 | } 19 | 20 | template <> 21 | inline void QPBO::get_type_information(const char*& type_name, const char*& type_format) 22 | { 23 | type_name = "float"; 24 | type_format = "f"; 25 | } 26 | 27 | template <> 28 | inline void QPBO::get_type_information(const char*& type_name, const char*& type_format) 29 | { 30 | type_name = "double"; 31 | type_format = "lf"; 32 | } 33 | 34 | 35 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from ops import * 4 | slim = tf.contrib.slim 5 | 6 | def VDSR(x, hidden_num, repeat_num, data_format, use_norm, name='VDSR', 7 | k=3, train=True, reuse=False): 8 | with tf.variable_scope(name, reuse=reuse) as vs: 9 | for i in range(repeat_num-1): 10 | x = conv2d(x, hidden_num, data_format, k=k, s=1, act=tf.nn.relu) 11 | if use_norm: 12 | x = batch_norm(x, train, data_format, act=tf.nn.relu) 13 | 14 | x = conv2d(x, 1, data_format, k=k, s=1) 15 | if use_norm: 16 | x = batch_norm(x, train, data_format) 17 | out = tf.nn.relu(x) 18 | 
variables = tf.contrib.framework.get_variables(vs) 19 | return out, variables 20 | 21 | def main(_): 22 | b_num = 16 23 | h = 64 24 | w = 64 25 | ch_num = 2 26 | 27 | data_format = 'NCHW' 28 | 29 | x = tf.placeholder(dtype=tf.float32, shape=[b_num, h, w, ch_num]) 30 | if data_format == 'NCHW': 31 | x = nhwc_to_nchw(x) 32 | 33 | model = 1 34 | if model == 1: 35 | hidden_num = 64 36 | repeat_num = 20 37 | use_norm = True 38 | y = VDSR(x, hidden_num, repeat_num, data_format, use_norm) 39 | else: 40 | hidden_num = 128 # 128 41 | repeat_num = 16 # 16 42 | y = EDSR(x, hidden_num, repeat_num, data_format) 43 | show_all_variables() 44 | 45 | if __name__ == '__main__': 46 | tf.app.run() -------------------------------------------------------------------------------- /gco/CHANGES.TXT: -------------------------------------------------------------------------------- 1 | QPBO, version 1.4. 2 | 3 | Changes from version 1.32: 4 | - put under GPL license 5 | 6 | Changes from version 1.31: 7 | - made it compile without warnings (on g++ 4.6.3) 8 | - fixed an issue in Save() and Load() 9 | 10 | Changes from version 1.3: 11 | - fixed a bug in Improve(): the value INFTY used for 'fixing' nodes could have 12 | been underestimated. 13 | Thanks to Yu Miao for pointing this out. 14 | 15 | Changes from version 1.2: 16 | - fixed a bug: MergeParallelEdges() followed by Probe() may have worked incorrectly. 17 | Details: edges freed by MergeParallelEdges() are added to a list of "free arcs". 18 | These free arcs may then be used when Probe() needs to add a new pairwise term. 19 | However, there was an inconsistency between how MergeParallelEdges() marks free arcs, 20 | and how free arcs are treated in AddPairwiseTerm(). The result may have been a segmentation fault. 21 | Thanks to Lena Gorelick for pointing this out. 22 | 23 | Changes from version 1.1: 24 | - updated to make it compile under gcc 4.1.2. 25 | 26 | Changes from version 1.0: 27 | - fixed a bug in Probe(). 
(Thanks to Tian Taipeng for noticing that there is a bug). 28 | 29 | Details: In version 1.1 the transformed energy after calling Probe() was incorrect. 30 | As a result, the option ProbeOptions::weak_persistencies=1 was not working correctly, 31 | since it called the main probing function iteratively. 32 | 33 | - Added new function Improve() (without arguments), which generates a random permutation itself. 34 | 35 | 36 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from config import get_config 5 | from utils import prepare_dirs_and_logger, save_config 6 | 7 | def main(config): 8 | prepare_dirs_and_logger(config) 9 | save_config(config) 10 | 11 | if config.is_train: 12 | from trainer import Trainer 13 | if config.dataset == 'line': 14 | from data_line import BatchManager 15 | elif config.dataset == 'ch': 16 | from data_ch import BatchManager 17 | elif config.dataset == 'kanji': 18 | from data_kanji import BatchManager 19 | elif config.dataset == 'baseball' or\ 20 | config.dataset == 'cat': 21 | from data_qdraw import BatchManager 22 | 23 | batch_manager = BatchManager(config) 24 | trainer = Trainer(config, batch_manager) 25 | trainer.train() 26 | else: 27 | from tester import Tester 28 | if config.dataset == 'line': 29 | from data_line import BatchManager 30 | elif config.dataset == 'ch': 31 | from data_ch import BatchManager 32 | elif config.dataset == 'kanji': 33 | from data_kanji import BatchManager 34 | elif config.dataset == 'baseball' or\ 35 | config.dataset == 'cat': 36 | from data_qdraw import BatchManager 37 | 38 | batch_manager = BatchManager(config) 39 | tester = Tester(config, batch_manager) 40 | tester.test() 41 | 42 | if __name__ == "__main__": 43 | config, unparsed = get_config() 44 | main(config) 45 | 
-------------------------------------------------------------------------------- /test.bat: -------------------------------------------------------------------------------- 1 | python main.py --is_train=False --dataset=ch --load_pathnet=log/path/ch_1231_170036_win --find_overlap=False --tag=nv 2 | python main.py --is_train=False --dataset=ch --load_pathnet=log/path/ch_1231_170036_win --load_overlapnet=log/overlap/ch_0101_012450_win --tag=ov 3 | 4 | python main.py --is_train=False --dataset=kanji --load_pathnet=log/path/kanji_1231_175226_win --find_overlap=False --tag=nv 5 | python main.py --is_train=False --dataset=kanji --load_pathnet=log/path/kanji_1231_175226_win --load_overlapnet=log/overlap/kanji_0101_035424_win --tag=ov 6 | 7 | python main.py --is_train=False --dataset=line --load_pathnet=log/path/line_1231_162901_win --find_overlap=False --tag=nv 8 | python main.py --is_train=False --dataset=line --load_pathnet=log/path/line_1231_162901_win --load_overlapnet=log/overlap/line_0101_003631_win --tag=ov 9 | 10 | python main.py --is_train=False --dataset=baseball --load_pathnet=log/path/baseball_1231_185540_win --find_overlap=False --tag=nv --height=128 --width=128 --neighbor_sample=0.02 --test_batch_size=256 11 | python main.py --is_train=False --dataset=baseball --load_pathnet=log/path/baseball_1231_185540_win --load_overlapnet=log/overlap/baseball_0101_084611_win --tag=ov --height=128 --width=128 --neighbor_sample=0.02 --test_batch_size=256 12 | 13 | python main.py --is_train=False --dataset=cat --load_pathnet=log/path/cat_1231_205036_win --find_overlap=False --tag=nv --height=128 --width=128 --neighbor_sample=0.02 --test_batch_size=256 14 | python main.py --is_train=False --dataset=cat --load_pathnet=log/path/cat_1231_205036_win --load_overlapnet=log/overlap/cat_0101_104304_win --tag=ov --height=128 --width=128 --neighbor_sample=0.02 --test_batch_size=256 -------------------------------------------------------------------------------- 
/gco/LinkedBlockList.cpp: -------------------------------------------------------------------------------- 1 | #include "LinkedBlockList.h" 2 | #include 3 | #include 4 | 5 | /*********************************************************************/ 6 | 7 | void LinkedBlockList::addFront(ListType item) { 8 | 9 | if ( m_head_block_size == GCLL_BLOCK_SIZE ) 10 | { 11 | LLBlock *tmp = (LLBlock *) new LLBlock; 12 | if ( !tmp ) {printf("\nOut of memory");exit(1);} 13 | tmp -> m_next = m_head; 14 | m_head = tmp; 15 | m_head_block_size = 0; 16 | } 17 | 18 | m_head ->m_item[m_head_block_size] = item; 19 | m_head_block_size++; 20 | } 21 | 22 | /*********************************************************************/ 23 | 24 | ListType LinkedBlockList::next() 25 | { 26 | ListType toReturn = m_cursor -> m_item[m_cursor_ind]; 27 | 28 | m_cursor_ind++; 29 | 30 | if ( m_cursor == m_head && m_cursor_ind >= m_head_block_size ) 31 | { 32 | m_cursor = m_cursor ->m_next; 33 | m_cursor_ind = 0; 34 | } 35 | else if ( m_cursor_ind == GCLL_BLOCK_SIZE ) 36 | { 37 | m_cursor = m_cursor ->m_next; 38 | m_cursor_ind = 0; 39 | } 40 | return(toReturn); 41 | } 42 | 43 | /*********************************************************************/ 44 | 45 | bool LinkedBlockList::hasNext() 46 | { 47 | if ( m_cursor != 0 ) return (true); 48 | else return(false); 49 | } 50 | 51 | 52 | /*********************************************************************/ 53 | 54 | LinkedBlockList::~LinkedBlockList() 55 | { 56 | LLBlock *tmp; 57 | 58 | while ( m_head != 0 ) 59 | { 60 | tmp = m_head; 61 | m_head = m_head->m_next; 62 | delete tmp; 63 | } 64 | }; 65 | 66 | /*********************************************************************/ 67 | 68 | -------------------------------------------------------------------------------- /gco/LinkedBlockList.h: -------------------------------------------------------------------------------- 1 | /* Singly Linked List of Blocks */ 2 | // This data structure should be used only for the 
GCoptimization class implementation 3 | // because it lucks some important general functions for general list, like remove_item() 4 | // The head block may be not full 5 | // For regular 2D grids, it's better to set GCLL_BLOCK_SIZE to 2 6 | // For other graphs, it should be set to the average expected number of neighbors 7 | // Data in linked list for the neighborhood system is allocated in blocks of size GCLL_BLOCK_SIZE 8 | 9 | #ifndef __LINKEDBLOCKLIST_H__ 10 | #define __LINKEDBLOCKLIST_H__ 11 | 12 | #define GCLL_BLOCK_SIZE 4 13 | // GCLL_BLOCKSIZE should "fit" into the type BlockType. That is 14 | // if GCLL_BLOCKSIZE is larger than 255 but smaller than largest short integer 15 | // then BlockType should be set to short 16 | typedef char BlockType; 17 | 18 | //The type of data stored in the linked list 19 | typedef void * ListType; 20 | 21 | class LinkedBlockList{ 22 | 23 | public: 24 | void addFront(ListType item); 25 | inline bool isEmpty(){if (m_head == 0) return(true); else return(false);}; 26 | inline LinkedBlockList(){m_head = 0; m_head_block_size = GCLL_BLOCK_SIZE;}; 27 | ~LinkedBlockList(); 28 | 29 | // Next three functins are for the linked list traversal 30 | inline void setCursorFront(){m_cursor = m_head; m_cursor_ind = 0;}; 31 | ListType next(); 32 | bool hasNext(); 33 | 34 | private: 35 | typedef struct LLBlockStruct{ 36 | ListType m_item[GCLL_BLOCK_SIZE]; 37 | struct LLBlockStruct *m_next; 38 | } LLBlock; 39 | 40 | LLBlock *m_head; 41 | // Remembers the number of elements in the head block, since it may not be full 42 | BlockType m_head_block_size; 43 | // For block traversal, points to current element in the current block 44 | BlockType m_cursor_ind; 45 | // For block traversal, points to current block in the linked list 46 | LLBlock *m_cursor; 47 | }; 48 | 49 | #endif 50 | 51 | -------------------------------------------------------------------------------- /gco/graph.cpp: 
-------------------------------------------------------------------------------- 1 | /* graph.cpp */ 2 | 3 | 4 | #include 5 | #include 6 | #include 7 | #include "graph.h" 8 | 9 | 10 | template 11 | Graph::Graph(int node_num_max, int edge_num_max, void (*err_function)(const char *)) 12 | : node_num(0), 13 | nodeptr_block(NULL), 14 | error_function(err_function) 15 | { 16 | if (node_num_max < 16) node_num_max = 16; 17 | if (edge_num_max < 16) edge_num_max = 16; 18 | 19 | nodes = (node*) malloc(node_num_max*sizeof(node)); 20 | arcs = (arc*) malloc(2*edge_num_max*sizeof(arc)); 21 | if (!nodes || !arcs) { if (error_function) (*error_function)("Not enough memory!"); exit(1); } 22 | 23 | node_last = nodes; 24 | node_max = nodes + node_num_max; 25 | arc_last = arcs; 26 | arc_max = arcs + 2*edge_num_max; 27 | 28 | maxflow_iteration = 0; 29 | flow = 0; 30 | } 31 | 32 | template 33 | Graph::~Graph() 34 | { 35 | if (nodeptr_block) 36 | { 37 | delete nodeptr_block; 38 | nodeptr_block = NULL; 39 | } 40 | free(nodes); 41 | free(arcs); 42 | } 43 | 44 | template 45 | void Graph::reset() 46 | { 47 | node_last = nodes; 48 | arc_last = arcs; 49 | node_num = 0; 50 | 51 | if (nodeptr_block) 52 | { 53 | delete nodeptr_block; 54 | nodeptr_block = NULL; 55 | } 56 | 57 | maxflow_iteration = 0; 58 | flow = 0; 59 | } 60 | 61 | template 62 | void Graph::reallocate_nodes(int num) 63 | { 64 | int node_num_max = (int)(node_max - nodes); 65 | node* nodes_old = nodes; 66 | 67 | node_num_max += node_num_max / 2; 68 | if (node_num_max < node_num + num) node_num_max = node_num + num; 69 | nodes = (node*) realloc(nodes_old, node_num_max*sizeof(node)); 70 | if (!nodes) { if (error_function) (*error_function)("Not enough memory!"); exit(1); } 71 | 72 | node_last = nodes + node_num; 73 | node_max = nodes + node_num_max; 74 | 75 | if (nodes != nodes_old) 76 | { 77 | arc* a; 78 | for (a=arcs; ahead = (node*) ((char*)a->head + (((char*) nodes) - ((char*) nodes_old))); 81 | } 82 | } 83 | } 84 | 85 | template 
86 | void Graph::reallocate_arcs() 87 | { 88 | int arc_num_max = (int)(arc_max - arcs); 89 | int arc_num = (int)(arc_last - arcs); 90 | arc* arcs_old = arcs; 91 | 92 | arc_num_max += arc_num_max / 2; if (arc_num_max & 1) arc_num_max ++; 93 | arcs = (arc*) realloc(arcs_old, arc_num_max*sizeof(arc)); 94 | if (!arcs) { if (error_function) (*error_function)("Not enough memory!"); exit(1); } 95 | 96 | arc_last = arcs + arc_num; 97 | arc_max = arcs + arc_num_max; 98 | 99 | if (arcs != arcs_old) 100 | { 101 | node* i; 102 | arc* a; 103 | for (i=nodes; ifirst) i->first = (arc*) ((char*)i->first + (((char*) arcs) - ((char*) arcs_old))); 106 | } 107 | for (a=arcs; anext) a->next = (arc*) ((char*)a->next + (((char*) arcs) - ((char*) arcs_old))); 110 | a->sister = (arc*) ((char*)a->sister + (((char*) arcs) - ((char*) arcs_old))); 111 | } 112 | } 113 | } 114 | 115 | -------------------------------------------------------------------------------- /gco/main.cpp: -------------------------------------------------------------------------------- 1 | ////////////////////////////////////////////////////////////////////////////// 2 | // Example illustrating the use of GCoptimization.cpp 3 | // 4 | ///////////////////////////////////////////////////////////////////////////// 5 | 6 | #include 7 | #include 8 | #include 9 | #include "GCoptimization.h" 10 | 11 | float smoothFn(int p1, int p2, int l1, int l2, void *data) 12 | { 13 | float **pred = reinterpret_cast(data); 14 | //float avg_pred = 0.5 * (pred[p1][p2] + pred[p2][p1]); 15 | if (p1 > p2) { 16 | int tmp = p2; 17 | p2 = p1; 18 | p1 = tmp; 19 | } 20 | float avg_pred = pred[p1][p2]; 21 | float pred_distance = (l1 == l2) ? 
(1 - avg_pred) : avg_pred; 22 | //return int(pred_distance * 1000); 23 | return pred_distance; 24 | } 25 | 26 | int main(int argc, char **argv) 27 | { 28 | //std::cout << argv[1] << std::endl; 29 | std::ifstream is(argv[1]); 30 | 31 | if (!is.is_open()) { 32 | std::cout << "Unable to open pred file" << std::endl; 33 | return -1; 34 | } 35 | 36 | std::string pred_file_path, data_dir; 37 | int n_labels, n_sites, label_cost; 38 | float neighbor_sigma, prediction_sigma; 39 | is >> pred_file_path; 40 | is >> data_dir; 41 | is >> n_labels; 42 | is >> label_cost; 43 | is >> neighbor_sigma; 44 | is >> prediction_sigma; 45 | is >> n_sites; 46 | 47 | //std::cout << "pred_file_path:" << pred_file_path << std::endl; 48 | //std::cout << "data_dir:" << data_dir << std::endl; 49 | //std::cout << "n_labels:" << n_labels << std::endl; 50 | //std::cout << "label_cost:" << label_cost << std::endl; 51 | //std::cout << "neighbor_sigma:" << neighbor_sigma << std::endl; 52 | //std::cout << "pred_sigma:" << pred_sigma << std::endl; 53 | //std::cout << "n_sites:" << n_sites << std::endl; 54 | 55 | float **pred = new float*[n_sites]; 56 | for (int i = 0; i < n_sites; ++i) { 57 | pred[i] = new float[n_sites](); 58 | } 59 | float **w = new float*[n_sites]; 60 | for (int i = 0; i < n_sites; ++i) { 61 | w[i] = new float[n_sites](); 62 | } 63 | 64 | while (is.good()) { 65 | int i, j; 66 | float p, spatial; 67 | is >> i >> j >> p >> spatial; 68 | //std::cout << i << " " << j << " " << p << " " << spatial << std::endl; 69 | pred[i][j] = p; 70 | w[i][j] = spatial; 71 | } 72 | 73 | // std::cout << "0 1 " << pred[0][1] << " " << w[0][1] << std::endl; 74 | // std::cout << "0 2 " << pred[0][2] << " " << w[0][2] << std::endl; 75 | 76 | // is.close(); 77 | // return 0; 78 | 79 | 80 | int n_iters = 3; 81 | float *data = new float[n_sites*n_labels](); 82 | int *labels = new int[n_sites](); 83 | 84 | try { 85 | GCoptimizationGeneralGraph *gc = new GCoptimizationGeneralGraph(n_sites, n_labels); 86 | 
gc->setDataCost(data); 87 | gc->setSmoothCost(smoothFn, (void*)pred); 88 | for (int i = 0; i < n_sites - 1; ++i) { 89 | for (int j = i + 1; j < n_sites; ++j) { 90 | gc->setNeighbors(i, j, w[i][j]); 91 | } 92 | } 93 | gc->setLabelCost(label_cost); 94 | gc->setLabelOrder(true); 95 | 96 | std::string label_file_path = argv[1]; 97 | label_file_path.replace(label_file_path.end() - 5, label_file_path.end(), ".label"); 98 | std::ofstream os(label_file_path.c_str()); 99 | if (!os.is_open()) { 100 | std::cout << "Unable to open label file" << std::endl; 101 | return -1; 102 | } 103 | 104 | // printf("\nBefore optimization energy is %f", gc->compute_energy()); 105 | os << gc->compute_energy() << std::endl; 106 | // gc->expansion(n_iters); 107 | gc->swap(n_iters); 108 | // gc->fusion(n_iters); 109 | os << gc->compute_energy() << std::endl; 110 | // printf("\nAfter optimization energy is %f", gc->compute_energy()); 111 | 112 | for (int i = 0; i < n_sites; i++) { 113 | labels[i] = gc->whatLabel(i); 114 | //printf("\nLabel %d: %d", i, labels[i]); 115 | os << labels[i] << " "; 116 | } 117 | //printf("\n"); 118 | 119 | os.close(); 120 | delete gc; 121 | } 122 | catch (GCException e) { 123 | e.Report(); 124 | } 125 | 126 | delete[] labels; 127 | for (int i = 0; i < n_sites; ++i) { 128 | delete[] pred[i]; 129 | delete[] w[i]; 130 | } 131 | delete[] pred; 132 | delete[] w; 133 | delete[] data; 134 | 135 | return 0; 136 | } -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | import argparse 3 | 4 | def str2bool(v): 5 | return v.lower() in ('true', '1') 6 | 7 | arg_lists = [] 8 | parser = argparse.ArgumentParser() 9 | 10 | def add_argument_group(name): 11 | arg = parser.add_argument_group(name) 12 | arg_lists.append(arg) 13 | return arg 14 | 15 | # Network 16 | net_arg = add_argument_group('Network') 17 | 
net_arg.add_argument('--width', type=int, default=64) 18 | net_arg.add_argument('--height', type=int, default=64) 19 | net_arg.add_argument('--conv_hidden_num', type=int, default=64, 20 | choices=[64, 128, 256]) 21 | net_arg.add_argument('--repeat_num', type=int, default=20, 22 | choices=[16, 20, 32]) 23 | net_arg.add_argument('--use_l2', type=str2bool, default=True) 24 | net_arg.add_argument('--use_norm', type=str2bool, default=True) 25 | net_arg.add_argument('--archi', type=str, default='path', 26 | choices=['path','overlap']) 27 | 28 | # Data 29 | data_arg = add_argument_group('Data') 30 | data_arg.add_argument('--data_dir', type=str, default='data') 31 | data_arg.add_argument('--dataset', type=str, default='line', 32 | choices=['line','ch','kanji','baseball','cat']) 33 | data_arg.add_argument('--batch_size', type=int, default=8) 34 | data_arg.add_argument('--num_worker', type=int, default=16) 35 | # line 36 | data_arg.add_argument('--num_strokes', type=int, default=4) 37 | data_arg.add_argument('--stroke_type', type=int, default=2) 38 | data_arg.add_argument('--min_length', type=int, default=10) 39 | data_arg.add_argument('--max_stroke_width', type=int, default=2) # 4 for varying w. 
40 | 41 | # Training / test parameters 42 | train_arg = add_argument_group('Training') 43 | train_arg.add_argument('--is_train', type=str2bool, default=True) 44 | train_arg.add_argument('--use_gpu', type=str2bool, default=True) 45 | train_arg.add_argument('--gpu_id', type=str, default='0') 46 | train_arg.add_argument('--start_step', type=int, default=0) 47 | train_arg.add_argument('--max_step', type=int, default=50000) # 2000 48 | train_arg.add_argument('--lr_update_step', type=int, default=20000) 49 | train_arg.add_argument('--lr', type=float, default=0.005) 50 | train_arg.add_argument('--lr_lower_boundary', type=float, default=0.00001) 51 | train_arg.add_argument('--optimizer', type=str, default='adam') 52 | train_arg.add_argument('--beta1', type=float, default=0.5) 53 | train_arg.add_argument('--beta2', type=float, default=0.999) 54 | 55 | # vectorize 56 | vect_arg = add_argument_group('Vectorize') 57 | vect_arg.add_argument('--load_pathnet', type=str, default='') 58 | vect_arg.add_argument('--load_overlapnet', type=str, default='') 59 | vect_arg.add_argument('--num_test', type=int, default=100) 60 | vect_arg.add_argument('--max_label', type=int, default=128) 61 | vect_arg.add_argument('--label_cost', type=int, default=0) 62 | vect_arg.add_argument('--sigma_neighbor', type=float, default=8.0) 63 | vect_arg.add_argument('--sigma_predict', type=float, default=0.7) 64 | vect_arg.add_argument('--neighbor_sample', type=float, default=1) 65 | vect_arg.add_argument('--find_overlap', type=str2bool, default=True) 66 | vect_arg.add_argument('--overlap_threshold', type=float, default=0.5) 67 | vect_arg.add_argument('--test_batch_size', type=int, default=512) 68 | vect_arg.add_argument('--mp', type=str2bool, default=True) 69 | 70 | # Misc 71 | misc_arg = add_argument_group('Misc') 72 | misc_arg.add_argument('--log_step', type=int, default=100) 73 | misc_arg.add_argument('--test_step', type=int, default=10000) # 1000 74 | misc_arg.add_argument('--save_sec', type=int, 
default=900) 75 | misc_arg.add_argument('--log_dir', type=str, default='log') 76 | misc_arg.add_argument('--tag', type=str, default='test') 77 | misc_arg.add_argument('--random_seed', type=int, default=123) 78 | 79 | 80 | def get_config(): 81 | config, unparsed = parser.parse_known_args() 82 | 83 | # import os 84 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # so the IDs match nvidia-smi 85 | # os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id # "0, 1" for multiple 86 | 87 | if config.use_gpu: 88 | data_format = 'NCHW' 89 | else: 90 | data_format = 'NHWC' 91 | # data_format = 'NHWC' # for debug 92 | setattr(config, 'data_format', data_format) 93 | return config, unparsed 94 | -------------------------------------------------------------------------------- /gco/QPBO_postprocessing.cpp: -------------------------------------------------------------------------------- 1 | /* QPBO_postprocessing.cpp */ 2 | /* 3 | Copyright 2006-2008 Vladimir Kolmogorov (vnk@ist.ac.at). 4 | 5 | This file is part of QPBO. 6 | 7 | QPBO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | QPBO is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with QPBO. If not, see . 
19 | */ 20 | 21 | 22 | #include 23 | #include 24 | #include 25 | #include "QPBO.h" 26 | 27 | 28 | template 29 | void QPBO::ComputeWeakPersistencies() 30 | { 31 | if (stage == 0) return; 32 | 33 | Node* i; 34 | Node* j; 35 | Node* stack = NULL; 36 | int component; 37 | 38 | for (i=nodes[0]; ilabel>=-1 && i->label<=1); 41 | 42 | Node* i1 = GetMate0(i); 43 | 44 | if (i->label >= 0) 45 | { 46 | i->dfs_parent = i; 47 | i1->dfs_parent = i1; 48 | i->region = i1->region = 0; 49 | } 50 | else 51 | { 52 | i->dfs_parent = i1->dfs_parent = NULL; 53 | i->region = i1->region = -1; 54 | } 55 | } 56 | 57 | // first DFS 58 | for (i=nodes[0]; idfs_parent) continue; 62 | 63 | // DFS starting from i 64 | i->dfs_parent = i; 65 | i->dfs_current = i->first; 66 | while ( 1 ) 67 | { 68 | if (!i->dfs_current) 69 | { 70 | i->next = stack; 71 | stack = i; 72 | 73 | if (i->dfs_parent == i) break; 74 | i = i->dfs_parent; 75 | i->dfs_current = i->dfs_current->next; 76 | continue; 77 | } 78 | 79 | j = i->dfs_current->head; 80 | if (!(i->dfs_current->r_cap>0) || j->dfs_parent) 81 | { 82 | i->dfs_current = i->dfs_current->next; 83 | continue; 84 | } 85 | 86 | j->dfs_parent = i; 87 | i = j; 88 | i->dfs_current = i->first; 89 | } 90 | } 91 | 92 | // second DFS 93 | component = 0; 94 | while ( stack ) 95 | { 96 | i = stack; 97 | stack = i->next; 98 | if (i->region > 0) continue; 99 | 100 | i->region = ++ component; 101 | i->dfs_parent = i; 102 | i->dfs_current = i->first; 103 | while ( 1 ) 104 | { 105 | if (!i->dfs_current) 106 | { 107 | if (i->dfs_parent == i) break; 108 | i = i->dfs_parent; 109 | i->dfs_current = i->dfs_current->next; 110 | continue; 111 | } 112 | 113 | j = i->dfs_current->head; 114 | if (!(i->dfs_current->sister->r_cap>0) || j->region>=0) 115 | { 116 | i->dfs_current = i->dfs_current->next; 117 | continue; 118 | } 119 | 120 | j->dfs_parent = i; 121 | i = j; 122 | i->dfs_current = i->first; 123 | i->region = component; 124 | } 125 | } 126 | 127 | // assigning labels 128 | for 
(i=nodes[0]; ilabel < 0) 131 | { 132 | code_assert(i->region > 0); 133 | if (i->region > GetMate0(i)->region) { i->label = 0; i->region = 0; } 134 | else if (i->region < GetMate0(i)->region) { i->label = 1; i->region = 0; } 135 | } 136 | else code_assert(i->region == 0); 137 | } 138 | } 139 | 140 | template 141 | void QPBO::Stitch() 142 | { 143 | if (stage == 0) return; 144 | 145 | Node* i; 146 | Node* i_mate; 147 | Node* j; 148 | Arc* a; 149 | Arc* a_mate; 150 | 151 | for (a=arcs[0], a_mate=arcs[1]; asister) 153 | { 154 | a->r_cap = a_mate->r_cap = a->r_cap + a_mate->r_cap; 155 | 156 | i = a->sister->head; 157 | j = a->head; 158 | 159 | if (i->region==0 || i->region != j->region) continue; 160 | if (IsNode0(i)) 161 | { 162 | if (i->user_label != 0) continue; 163 | } 164 | else 165 | { 166 | if (GetMate1(i)->user_label != 1) continue; 167 | } 168 | if (IsNode0(j)) 169 | { 170 | if (j->user_label != 1) continue; 171 | } 172 | else 173 | { 174 | if (GetMate1(j)->user_label != 0) continue; 175 | } 176 | 177 | a->r_cap = a_mate->r_cap = 0; 178 | } 179 | 180 | for (i=nodes[0], i_mate=nodes[1]; itr_cap = i->tr_cap - i_mate->tr_cap; 183 | i_mate->tr_cap = -i->tr_cap; 184 | } 185 | 186 | ComputeWeakPersistencies(); 187 | } 188 | 189 | #include "instances.inc" 190 | -------------------------------------------------------------------------------- /asset/vec_39693.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 7 | 8 | Created by potrace 1.15, written by Peter Selinger 2001-2017 9 | 10 | 12 | 14 | 15 | 17 | 19 | 20 | 22 | 24 | 25 | 27 | 30 | 31 | 33 | 36 | 37 | 39 | 40 | 41 | 43 | 45 | 46 | 48 | 50 | 51 | 53 | 55 | 56 | 58 | 59 | 60 | 62 | 64 | 65 | 67 | 69 | 70 | 72 | 74 | 75 | 77 | 79 | 80 | 82 | 84 | 85 | 87 | 89 | 90 | 92 | 94 | 95 | 97 | 99 | 100 | 102 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import math 5 | import json 6 | import logging 7 | import numpy as np 8 | from PIL import Image 9 | from datetime import datetime 10 | import imageio 11 | from glob import glob 12 | import shutil 13 | 14 | def prepare_dirs_and_logger(config): 15 | # print(__file__) 16 | os.chdir(os.path.dirname(__file__)) 17 | 18 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") 19 | logger = logging.getLogger() 20 | 21 | for hdlr in logger.handlers: 22 | logger.removeHandler(hdlr) 23 | 24 | handler = logging.StreamHandler() 25 | handler.setFormatter(formatter) 26 | 27 | logger.addHandler(handler) 28 | 29 | # data path 30 | config.data_path = os.path.join(config.data_dir, config.dataset) 31 | 32 | # model path 33 | if config.is_train: 34 | model_name = os.path.join(config.archi, '{}_{}_{}'.format( 35 | config.dataset, get_time(), config.tag)) 36 | config.model_dir = os.path.join(config.log_dir, model_name) 37 | else: 38 | model_name = os.path.join('vec', '{}_{}_{}'.format( 39 | config.dataset, get_time(), config.tag)) 40 | config.model_dir = os.path.join(config.log_dir, model_name) 41 | 42 | if not os.path.exists(config.model_dir): 43 | os.makedirs(config.model_dir) 44 | 45 | def get_time(): 46 | return datetime.now().strftime("%m%d_%H%M%S") 47 | 48 | def save_config(config): 49 | param_path = os.path.join(config.model_dir, "params.json") 50 | 51 | print("[*] MODEL dir: %s" % config.model_dir) 52 | print("[*] PARAM path: %s" % param_path) 53 | 54 | with open(param_path, 'w') as fp: 55 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 56 | 57 | def rank(array): 58 | return len(array.shape) 59 | 60 | def make_grid(tensor, nrow=8, padding=2, 61 | normalize=False, scale_each=False): 62 | """Code based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py""" 63 | nmaps = tensor.shape[0] 64 | 
xmaps = min(nrow, nmaps) 65 | ymaps = int(math.ceil(float(nmaps) / xmaps)) 66 | height, width = int(tensor.shape[1] + padding), int(tensor.shape[2] + padding) 67 | grid = np.ones([height * ymaps + 1 + padding // 2, width * xmaps + 1 + padding // 2, 3], dtype=np.uint8)*255 68 | k = 0 69 | for y in range(ymaps): 70 | for x in range(xmaps): 71 | if k >= nmaps: 72 | break 73 | h, h_width = y * height + 1 + padding // 2, height - padding 74 | w, w_width = x * width + 1 + padding // 2, width - padding 75 | 76 | grid[h:h+h_width, w:w+w_width] = tensor[k] 77 | k = k + 1 78 | return grid 79 | 80 | def save_image(tensor, filename, nrow=8, padding=2, 81 | normalize=False, scale_each=False, single=False): 82 | if not single: 83 | ndarr = make_grid(tensor, nrow=nrow, padding=padding, 84 | normalize=normalize, scale_each=scale_each) 85 | else: 86 | h, w = tensor.shape[0], tensor.shape[1] 87 | ndarr = np.zeros([h,w,3], dtype=np.uint8) 88 | ndarr[:,:] = tensor[:,:] 89 | 90 | im = Image.fromarray(ndarr) 91 | im.save(filename) 92 | 93 | def convert_png2mp4(imgdir, filename, fps, delete_imgdir=False): 94 | dirname = os.path.dirname(filename) 95 | if not os.path.exists(dirname): 96 | os.makedirs(dirname) 97 | 98 | try: 99 | writer = imageio.get_writer(filename, fps=fps) 100 | except Exception: 101 | imageio.plugins.ffmpeg.download() 102 | writer = imageio.get_writer(filename, fps=fps) 103 | 104 | imgs = sorted(glob("{}/*.png".format(imgdir))) 105 | # print(imgs) 106 | for img in imgs: 107 | im = imageio.imread(img) 108 | writer.append_data(im) 109 | 110 | writer.close() 111 | 112 | if delete_imgdir: shutil.rmtree(imgdir) 113 | 114 | def rf(o, k, stride): # input size from output size 115 | return (o-1)*stride + k 116 | 117 | def receptive_field_size(c, k, s): 118 | if c == 0: 119 | return rf(rf(1, k, 1), k, 1) 120 | else: 121 | rfs = receptive_field_size(c-1, k, s) 122 | print('%d: %d' % (c-1, rfs)) 123 | return rf(rfs, k, s) 124 | 125 | if __name__ == '__main__': 126 | c, k, s = 4, 
3, 2 127 | rfs = receptive_field_size(c, k, s) 128 | print('c{}k{}s{} receptive field size'.format(c, k, s), rfs) 129 | 130 | c, k = 3, 3 131 | rfs = receptive_field_size(c, k, s) 132 | print('c{}k{}s{} receptive field size'.format(c, k, s), rfs) 133 | 134 | c, k = 5, 3 135 | rfs = receptive_field_size(c, k, s) 136 | print('c{}k{}s{} receptive field size'.format(c, k, s), rfs) 137 | 138 | c, k = 4, 4 139 | rfs = receptive_field_size(c, k, s) 140 | print('c{}k{}s{} receptive field size'.format(c, k, s), rfs) 141 | 142 | c, k = 3, 4 143 | rfs = receptive_field_size(c, k, s) 144 | print('c{}k{}s{} receptive field size'.format(c, k, s), rfs) 145 | 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Segmentation for Line Drawing Vectorization Using Neural Networks 2 | 3 | Tensorflow implementation of [Semantic Segmentation for Line Drawing Vectorization Using Neural Networks](http://www.byungsoo.me/project/vectornet). 
4 | 5 | [Byungsoo Kim¹](http://www.byungsoo.me), [Oliver Wang²](http://www.oliverwang.info), [Cengiz Öztireli¹](https://graphics.ethz.ch/~cengizo), [Markus Gross¹](https://graphics.ethz.ch/people/grossm) 6 | 7 | ¹ETH Zurich, ²Adobe Research 8 | 9 | Computer Graphics Forum (Proceedings of Eurographics 2018) 10 | 11 | ![vectornet](./asset/teaser.png) 12 | 13 | 14 | ## Requirements 15 | 16 | This code is tested on Windows 10 and Ubuntu 16.04 with the following requirements: 17 | 18 | - [anaconda3 / python3.6](https://www.anaconda.com/download/#linux) 19 | - [TensorFlow 1.4](https://github.com/tensorflow/tensorflow) 20 | - [CairoSVG 2.1.2](http://cairosvg.org/) 21 | - [Matplotlib 2.1.0](https://matplotlib.org/) 22 | - [imageio 2.2.0](https://pypi.python.org/pypi/imageio) 23 | - [tqdm](https://github.com/tqdm/tqdm) 24 | - [Potrace](http://potrace.sourceforge.net/) 25 | 26 | After installing anaconda, run `pip install tensorflow-gpu cairosvg matplotlib imageio tqdm`. In case of Potrace, unzip it `(i.e. potrace/potrace.exe)` on Windows or run `sudo apt-get install potrace` on Ubuntu. 27 | 28 | ## Usage 29 | 30 | Download a preprocessed dataset first and unzip it `(i.e. data/ch/train)`. 
31 | 32 | - [Chinese](http://gofile.me/6tGZC/zbws8gqEK) [(source)](https://github.com/skishore/makemeahanzi) 33 | - [Kanji](http://gofile.me/6tGZC/R7FWjODa2) [(source)](https://github.com/KanjiVG/kanjivg/releases) 34 | - [Quick Draw](http://gofile.me/6tGZC/VIH81NZJH) [(source)](https://github.com/googlecreativelab/quickdraw-dataset) 35 | - [Random Lines](http://gofile.me/6tGZC/GEoKSdiDc) 36 | 37 | To train PathNet on Chinese characters: 38 | 39 | $ python main.py --is_train=True --archi=path --dataset=ch 40 | 41 | To train OverlapNet on Chinese characters: 42 | 43 | $ python main.py --is_train=True --archi=overlap --dataset=ch 44 | 45 | To vectorize Chinese characters: 46 | 47 | $ .\build_win.bat or ./build_linux.sh 48 | $ python main.py --is_train=False --dataset=ch --load_pathnet=log/path/MODEL_DIR --load_overlapnet=log/overlap/MODEL_DIR 49 | 50 | ## Results 51 | 52 | ### PathNet output (64x64) after 50k steps (From top to bottom: input / output / ground truth) 53 | 54 | ![path_ch_in](./asset/path_ch_in.png) 55 | 56 | ![path_ch_50k](./asset/path_ch_50k.png) 57 | 58 | ![path_ch_gt](./asset/path_ch_gt.png) 59 | 60 | 61 | ### OverlapNet output (64x64) after 50k steps (From top to bottom: input / output / ground truth) 62 | 63 | ![ov_ch_in](./asset/ov_ch_in.png) 64 | 65 | ![ov_ch_50k](./asset/ov_ch_50k.png) 66 | 67 | ![ov_ch_gt](./asset/ov_ch_gt.png) 68 | 69 | 70 | ### Vectorization output (64x64) 71 | 72 | From left to right: input / raster / transparent / overlap / vector 73 | 74 | ![vec_39693_in](./asset/vec_39693_in.png) 75 | ![vec_39693_out](./asset/vec_39693_out.png) 76 | ![vec_39693_t](./asset/vec_39693_t.png) 77 | ![vec_39693_overlap](./asset/vec_39693_overlap.png) 78 | ![vec_39693](./asset/vec_39693.svg) 79 | 80 | 81 | ## Reference 82 | 83 | - Multi-label Optimization: [gco](http://vision.csd.uwo.ca/code/gco-v3.0.zip), [qpbo](http://pub.ist.ac.at/~vnk/software/QPBO-v1.3.src.tar.gz) 84 | - Tensorflow Framework: 
[carpedm20](https://github.com/carpedm20/BEGAN-tensorflow) 85 | 86 | 134 | 135 | 136 | 137 | 211 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import numpy as np 5 | from tqdm import trange 6 | 7 | from models import * 8 | from utils import save_image 9 | 10 | class Trainer(object): 11 | def __init__(self, config, batch_manager): 12 | tf.set_random_seed(config.random_seed) 13 | self.config = config 14 | self.batch_manager = batch_manager 15 | self.x, self.y = batch_manager.batch() 16 | self.xt = tf.placeholder(tf.float32, shape=int_shape(self.x)) 17 | self.yt = tf.placeholder(tf.float32, shape=int_shape(self.y)) 18 | self.dataset = config.dataset 19 | 20 | self.beta1 = config.beta1 21 | self.beta2 = config.beta2 22 | self.optimizer = config.optimizer 23 | self.batch_size = config.batch_size 24 | 25 | self.lr = tf.Variable(config.lr, name='lr') 26 | self.lr_update = tf.assign(self.lr, tf.maximum(self.lr*0.1, config.lr_lower_boundary), name='lr_update') 27 | 28 | self.height = config.height 29 | self.width = config.width 30 | self.b_num = config.batch_size 31 | self.conv_hidden_num = config.conv_hidden_num 32 | self.repeat_num = config.repeat_num 33 | self.use_l2 = config.use_l2 34 | self.use_norm = config.use_norm 35 | 36 | self.model_dir = config.model_dir 37 | 38 | self.use_gpu = config.use_gpu 39 | self.data_format = config.data_format 40 | if self.data_format == 'NCHW': 41 | self.x = nhwc_to_nchw(self.x) 42 | self.y = nhwc_to_nchw(self.y) 43 | self.xt = nhwc_to_nchw(self.xt) 44 | self.yt = nhwc_to_nchw(self.yt) 45 | 46 | self.start_step = config.start_step 47 | self.log_step = config.log_step 48 | self.test_step = config.test_step 49 | self.max_step = config.max_step 50 | self.save_sec = config.save_sec 51 | self.lr_update_step = config.lr_update_step 52 | 53 | self.step 
= tf.Variable(self.start_step, name='step', trainable=False) 54 | 55 | self.is_train = config.is_train 56 | self.build_model() 57 | 58 | self.saver = tf.train.Saver() 59 | self.summary_writer = tf.summary.FileWriter(self.model_dir) 60 | 61 | sv = tf.train.Supervisor(logdir=self.model_dir, 62 | is_chief=True, 63 | saver=self.saver, 64 | summary_op=None, 65 | summary_writer=self.summary_writer, 66 | save_model_secs=self.save_sec, 67 | global_step=self.step, 68 | ready_for_local_init_op=None) 69 | 70 | gpu_options = tf.GPUOptions(allow_growth=True) 71 | sess_config = tf.ConfigProto(allow_soft_placement=True, 72 | gpu_options=gpu_options) 73 | 74 | self.sess = sv.prepare_or_wait_for_session(config=sess_config) 75 | if self.is_train: 76 | self.batch_manager.start_thread(self.sess) 77 | 78 | def build_model(self): 79 | self.y_, self.var = VDSR( 80 | self.x, self.conv_hidden_num, self.repeat_num, self.data_format, self.use_norm) 81 | self.y_img = denorm_img(self.y_, self.data_format) # for debug 82 | 83 | self.yt_, _ = VDSR( 84 | self.xt, self.conv_hidden_num, self.repeat_num, self.data_format, self.use_norm, 85 | train=False, reuse=True) 86 | self.yt_ = tf.clip_by_value(self.yt_, 0, 1) 87 | self.yt_img = denorm_img(self.yt_, self.data_format) 88 | 89 | show_all_variables() 90 | 91 | if self.optimizer == 'adam': 92 | optimizer = tf.train.AdamOptimizer 93 | else: 94 | raise Exception("[!] Caution! 
Paper didn't use {} opimizer other than Adam".format(self.config.optimizer)) 95 | 96 | optimizer = optimizer(self.lr, beta1=self.beta1, beta2=self.beta2) 97 | 98 | # losses 99 | # l1 and l2 100 | self.loss_l1 = tf.reduce_mean(tf.abs(self.y_ - self.y)) 101 | self.loss_l2 = tf.reduce_mean(tf.squared_difference(self.y_, self.y)) 102 | 103 | # total 104 | if self.use_l2: 105 | self.loss = self.loss_l2 106 | else: 107 | self.loss = self.loss_l1 108 | 109 | # test loss 110 | self.tl1 = 1 - tf.reduce_mean(tf.abs(self.yt_ - self.yt)) 111 | self.tl2 = 1 - tf.reduce_mean(tf.squared_difference(self.yt_, self.yt)) 112 | self.test_acc_l1 = tf.placeholder(tf.float32) 113 | self.test_acc_l2 = tf.placeholder(tf.float32) 114 | self.test_acc_iou = tf.placeholder(tf.float32) 115 | 116 | self.optim = optimizer.minimize(self.loss, global_step=self.step, var_list=self.var) 117 | 118 | summary = [ 119 | tf.summary.image("y", self.y_img), 120 | 121 | tf.summary.scalar("loss/loss", self.loss), 122 | tf.summary.scalar("loss/loss_l1", self.loss_l1), 123 | tf.summary.scalar("loss/loss_l2", self.loss_l2), 124 | 125 | tf.summary.scalar("misc/lr", self.lr), 126 | tf.summary.scalar('misc/q', self.batch_manager.q.size()) 127 | ] 128 | 129 | self.summary_op = tf.summary.merge(summary) 130 | 131 | summary = [ 132 | tf.summary.image("x_sample", denorm_img(self.x, self.data_format)), 133 | tf.summary.image("y_sample", denorm_img(self.y, self.data_format)), 134 | ] 135 | 136 | self.summary_once = tf.summary.merge(summary) # call just once 137 | 138 | summary = [ 139 | tf.summary.scalar("loss/test_acc_l1", self.test_acc_l1), 140 | tf.summary.scalar("loss/test_acc_l2", self.test_acc_l2), 141 | tf.summary.scalar("loss/test_acc_iou", self.test_acc_iou), 142 | ] 143 | 144 | self.summary_test = tf.summary.merge(summary) 145 | 146 | def train(self): 147 | x_list, xs, ys, sample_list = self.batch_manager.random_list(self.b_num) 148 | save_image(xs, '{}/x_gt.png'.format(self.model_dir)) 149 | save_image(ys, 
'{}/y_gt.png'.format(self.model_dir)) 150 | 151 | with open('{}/gt.txt'.format(self.model_dir), 'w') as f: 152 | for sample in sample_list: 153 | f.write(sample + '\n') 154 | 155 | # call once 156 | summary_once = self.sess.run(self.summary_once) 157 | self.summary_writer.add_summary(summary_once, 0) 158 | self.summary_writer.flush() 159 | 160 | for step in trange(self.start_step, self.max_step): 161 | fetch_dict = { 162 | "optim": self.optim, 163 | "loss": self.loss, 164 | } 165 | 166 | if step % self.log_step == 0 or step == self.max_step-1: 167 | fetch_dict.update({ 168 | "summary": self.summary_op, 169 | }) 170 | 171 | if step % self.test_step == self.test_step-1 or step == self.max_step-1: 172 | l1, l2, iou, nb = 0, 0, 0, 0 173 | for x, y in self.batch_manager.test_batch(): 174 | if self.data_format == 'NCHW': 175 | x = to_nchw_numpy(x) 176 | y = to_nchw_numpy(y) 177 | tl1, tl2, y_ = self.sess.run([self.tl1, self.tl2, self.yt_], {self.xt: x, self.yt: y}) 178 | l1 += tl1 179 | l2 += tl2 180 | nb += 1 181 | 182 | # iou 183 | y_I = np.logical_and(y>0, y_>0) 184 | y_I_sum = np.sum(y_I, axis=(1, 2, 3)) 185 | y_U = np.logical_or(y>0, y_>0) 186 | y_U_sum = np.sum(y_U, axis=(1, 2, 3)) 187 | # print(y_I_sum, y_U_sum) 188 | nonzero_id = np.where(y_U_sum != 0)[0] 189 | if nonzero_id.shape[0] == 0: 190 | acc = 1.0 191 | else: 192 | acc = np.average(y_I_sum[nonzero_id] / y_U_sum[nonzero_id]) 193 | iou += acc 194 | 195 | if nb > 500: 196 | break 197 | 198 | l1 /= float(nb) 199 | l2 /= float(nb) 200 | iou /= float(nb) 201 | 202 | summary_test = self.sess.run(self.summary_test, 203 | {self.test_acc_l1: l1, self.test_acc_l2: l2, self.test_acc_iou: iou}) 204 | self.summary_writer.add_summary(summary_test, step) 205 | self.summary_writer.flush() 206 | 207 | result = self.sess.run(fetch_dict) 208 | 209 | if step % self.log_step == 0 or step == self.max_step-1: 210 | self.summary_writer.add_summary(result['summary'], step) 211 | self.summary_writer.flush() 212 | 213 | loss = 
result['loss'] 214 | assert not np.isnan(loss), 'Model diverged with loss = NaN' 215 | 216 | print("\n[{}/{}] Loss: {:.6f}".format(step, self.max_step, loss)) 217 | 218 | if step % (self.log_step * 10) == 0 or step == self.max_step-1: 219 | self.generate(x_list, self.model_dir, idx=step) 220 | 221 | if step % self.lr_update_step == self.lr_update_step - 1: 222 | self.sess.run(self.lr_update) 223 | 224 | # save last checkpoint.. 225 | save_path = os.path.join(self.model_dir, 'model.ckpt') 226 | self.saver.save(self.sess, save_path, global_step=self.step) 227 | self.batch_manager.stop_thread() 228 | 229 | def generate(self, x_samples, root_path=None, idx=None): 230 | if self.data_format == 'NCHW': 231 | x_samples = to_nchw_numpy(x_samples) 232 | generated = self.sess.run(self.yt_img, {self.xt: x_samples}) 233 | y_path = os.path.join(root_path, 'y_{}.png'.format(idx)) 234 | save_image(generated, y_path, nrow=self.b_num) 235 | print("[*] Samples saved: {}".format(y_path)) -------------------------------------------------------------------------------- /gco/energy.h: -------------------------------------------------------------------------------- 1 | /* energy.h */ 2 | /* Vladimir Kolmogorov (vnk@cs.cornell.edu), 2003. */ 3 | 4 | /* 5 | This software implements an energy minimization technique described in 6 | 7 | What Energy Functions can be Minimized via Graph Cuts? 8 | Vladimir Kolmogorov and Ramin Zabih. 9 | To appear in IEEE Transactions on Pattern Analysis and Machine Intelligence (PAMI). 10 | Earlier version appeared in European Conference on Computer Vision (ECCV), May 2002. 11 | 12 | More specifically, it computes the global minimum of a function E of binary 13 | variables x_1, ..., x_n which can be written as a sum of terms involving 14 | at most three variables at a time: 15 | 16 | E(x_1, ..., x_n) = \sum_{i} E^{i} (x_i) 17 | + \sum_{i,j} E^{i,j} (x_i, x_j) 18 | + \sum_{i,j,k} E^{i,j,k}(x_i, x_j, x_k) 19 | 20 | The method works only if each term is "regular". 
Definitions of regularity 21 | for terms E^{i}, E^{i,j}, E^{i,j,k} are given below as comments to functions 22 | add_term1(), add_term2(), add_term3(). 23 | 24 | This software can be used only for research purposes. IF YOU USE THIS SOFTWARE, 25 | YOU SHOULD CITE THE AFOREMENTIONED PAPER IN ANY RESULTING PUBLICATION. 26 | 27 | In order to use it, you will also need a MAXFLOW software which can be 28 | obtained from http://www.cs.cornell.edu/People/vnk/software.html 29 | 30 | 31 | Example usage 32 | (Minimizes the following function of 3 binary variables: 33 | E(x, y, z) = x - 2*y + 3*(1-z) - 4*x*y + 5*|y-z|): 34 | 35 | /////////////////////////////////////////////////// 36 | 37 | #include 38 | #include "energy.h" 39 | 40 | void main() 41 | { 42 | // Minimize the following function of 3 binary variables: 43 | // E(x, y, z) = x - 2*y + 3*(1-z) - 4*x*y + 5*|y-z| 44 | 45 | Energy::Var varx, vary, varz; 46 | Energy *e = new Energy(); 47 | 48 | varx = e -> add_variable(); 49 | vary = e -> add_variable(); 50 | varz = e -> add_variable(); 51 | 52 | e -> add_term1(varx, 0, 1); // add term x 53 | e -> add_term1(vary, 0, -2); // add term -2*y 54 | e -> add_term1(varz, 3, 0); // add term 3*(1-z) 55 | 56 | e -> add_term2(x, y, 0, 0, 0, -4); // add term -4*x*y 57 | e -> add_term2(y, z, 0, 5, 5, 0); // add term 5*|y-z| 58 | 59 | Energy::TotalValue Emin = e -> minimize(); 60 | 61 | printf("Minimum = %d\n", Emin); 62 | printf("Optimal solution:\n"); 63 | printf("x = %d\n", e->get_var(varx)); 64 | printf("y = %d\n", e->get_var(vary)); 65 | printf("z = %d\n", e->get_var(varz)); 66 | 67 | delete e; 68 | } 69 | 70 | /////////////////////////////////////////////////// 71 | */ 72 | 73 | #ifndef __ENERGY_H__ 74 | #define __ENERGY_H__ 75 | 76 | #include 77 | #include "graph.h" 78 | 79 | template class Energy: public Graph 80 | { 81 | typedef Graph GraphT; 82 | public: 83 | typedef typename GraphT::node_id Var; 84 | 85 | /* Types of energy values. 
86 | Value is a type of a value in a single term 87 | TotalValue is a type of a value of the total energy. 88 | By default Value = short, TotalValue = int. 89 | To change it, change the corresponding types in graph.h */ 90 | typedef captype Value; 91 | typedef flowtype TotalValue; 92 | 93 | /* interface functions */ 94 | 95 | /* Constructor. Optional argument is the pointer to the 96 | function which will be called if an error occurs; 97 | an error message is passed to this function. If this 98 | argument is omitted, exit(1) will be called. */ 99 | Energy(int var_num_max, int edge_num_max, void (*err_function)(const char *) = NULL); 100 | 101 | /* Destructor */ 102 | ~Energy(); 103 | 104 | /* Adds a new binary variable */ 105 | Var add_variable(int num=1); 106 | 107 | /* Adds a constant E to the energy function */ 108 | void add_constant(Value E); 109 | 110 | /* Adds a new term E(x) of one binary variable 111 | to the energy function, where 112 | E(0) = E0, E(1) = E1 113 | E0 and E1 can be arbitrary */ 114 | void add_term1(Var x, 115 | Value E0, Value E1); 116 | 117 | /* Adds a new term E(x,y) of two binary variables 118 | to the energy function, where 119 | E(0,0) = E00, E(0,1) = E01 120 | E(1,0) = E10, E(1,1) = E11 121 | The term must be regular, i.e. E00 + E11 <= E01 + E10 */ 122 | void add_term2(Var x, Var y, 123 | Value E00, Value E01, 124 | Value E10, Value E11); 125 | 126 | /* Adds a new term E(x,y,z) of three binary variables 127 | to the energy function, where 128 | E(0,0,0) = E000, E(0,0,1) = E001 129 | E(0,1,0) = E010, E(0,1,1) = E011 130 | E(1,0,0) = E100, E(1,0,1) = E101 131 | E(1,1,0) = E110, E(1,1,1) = E111 132 | The term must be regular. It means that if one 133 | of the variables is fixed (for example, y=1), then 134 | the resulting function of two variables must be regular. 
135 | Since there are 6 ways to fix one variable 136 | (3 variables times 2 binary values - 0 and 1), 137 | this is equivalent to 6 inequalities */ 138 | void add_term3(Var x, Var y, Var z, 139 | Value E000, Value E001, 140 | Value E010, Value E011, 141 | Value E100, Value E101, 142 | Value E110, Value E111); 143 | 144 | /* After the energy function has been constructed, 145 | call this function to minimize it. 146 | Returns the minimum of the function */ 147 | TotalValue minimize(); 148 | 149 | /* After 'minimize' has been called, this function 150 | can be used to determine the value of variable 'x' 151 | in the optimal solution. 152 | Returns either 0 or 1 */ 153 | int get_var(Var x); 154 | 155 | /***********************************************************************/ 156 | /***********************************************************************/ 157 | /***********************************************************************/ 158 | 159 | private: 160 | /* internal variables and functions */ 161 | 162 | TotalValue Econst; 163 | void (*error_function)(const char *); /* this function is called if a error occurs, 164 | with a corresponding error message 165 | (or exit(1) is called if it's NULL) */ 166 | }; 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | /***********************************************************************/ 183 | /************************ Implementation ******************************/ 184 | /***********************************************************************/ 185 | 186 | template 187 | inline Energy::Energy(int var_num_max, int edge_num_max, void (*err_function)(const char *)) : Graph(var_num_max, edge_num_max, err_function) 188 | { 189 | Econst = 0; 190 | error_function = err_function; 191 | } 192 | 193 | template 194 | inline Energy::~Energy() {} 195 | 196 | template 197 | inline typename Energy::Var Energy::add_variable(int num) 198 | { return GraphT::add_node(num); } 199 | 200 | template 201 
| inline void Energy::add_constant(Value A) { Econst += A; } 202 | 203 | template 204 | inline void Energy::add_term1(Var x, 205 | Value A, Value B) 206 | { 207 | this->add_tweights(x, B, A); 208 | } 209 | 210 | template 211 | inline void Energy::add_term2(Var x, Var y, 212 | Value A, Value B, 213 | Value C, Value D) 214 | { 215 | /* 216 | E = A A + 0 B-A 217 | D D C-D 0 218 | Add edges for the first term 219 | */ 220 | this->add_tweights(x, D, A); 221 | B -= A; C -= D; 222 | 223 | /* now need to represent 224 | 0 B 225 | C 0 226 | */ 227 | 228 | //assert(B + C >= 0); /* check regularity */ 229 | if (B < 0) 230 | { 231 | /* Write it as 232 | B B + -B 0 + 0 0 233 | 0 0 -B 0 B+C 0 234 | */ 235 | this->add_tweights(x, 0, B); /* first term */ 236 | this->add_tweights(y, 0, -B); /* second term */ 237 | this->add_edge(x, y, 0, B+C); /* third term */ 238 | } 239 | else if (C < 0) 240 | { 241 | /* Write it as 242 | -C -C + C 0 + 0 B+C 243 | 0 0 C 0 0 0 244 | */ 245 | this->add_tweights(x, 0, -C); /* first term */ 246 | this->add_tweights(y, 0, C); /* second term */ 247 | this->add_edge(x, y, B+C, 0); /* third term */ 248 | } 249 | else /* B >= 0, C >= 0 */ 250 | { 251 | this->add_edge(x, y, B, C); 252 | } 253 | } 254 | 255 | template 256 | inline void Energy::add_term3(Var x, Var y, Var z, 257 | Value E000, Value E001, 258 | Value E010, Value E011, 259 | Value E100, Value E101, 260 | Value E110, Value E111) 261 | { 262 | register Value pi = (E000 + E011 + E101 + E110) - (E100 + E010 + E001 + E111); 263 | register Value delta; 264 | register Var u; 265 | 266 | if (pi >= 0) 267 | { 268 | Econst += E111 - (E011 + E101 + E110); 269 | 270 | add_tweights(x, E101, E001); 271 | add_tweights(y, E110, E100); 272 | add_tweights(z, E011, E010); 273 | 274 | delta = (E010 + E001) - (E000 + E011); /* -pi(E[x=0]) */ 275 | assert(delta >= 0); /* check regularity */ 276 | add_edge(y, z, delta, 0); 277 | 278 | delta = (E100 + E001) - (E000 + E101); /* -pi(E[y=0]) */ 279 | assert(delta >= 0); 
/* check regularity */ 280 | add_edge(z, x, delta, 0); 281 | 282 | delta = (E100 + E010) - (E000 + E110); /* -pi(E[z=0]) */ 283 | assert(delta >= 0); /* check regularity */ 284 | add_edge(x, y, delta, 0); 285 | 286 | if (pi > 0) 287 | { 288 | u = add_variable(); 289 | add_edge(x, u, pi, 0); 290 | add_edge(y, u, pi, 0); 291 | add_edge(z, u, pi, 0); 292 | add_tweights(u, 0, pi); 293 | } 294 | } 295 | else 296 | { 297 | Econst += E000 - (E100 + E010 + E001); 298 | 299 | add_tweights(x, E110, E010); 300 | add_tweights(y, E011, E001); 301 | add_tweights(z, E101, E100); 302 | 303 | delta = (E110 + E101) - (E100 + E111); /* -pi(E[x=1]) */ 304 | assert(delta >= 0); /* check regularity */ 305 | add_edge(z, y, delta, 0); 306 | 307 | delta = (E110 + E011) - (E010 + E111); /* -pi(E[y=1]) */ 308 | assert(delta >= 0); /* check regularity */ 309 | add_edge(x, z, delta, 0); 310 | 311 | delta = (E101 + E011) - (E001 + E111); /* -pi(E[z=1]) */ 312 | assert(delta >= 0); /* check regularity */ 313 | add_edge(y, x, delta, 0); 314 | 315 | u = add_variable(); 316 | add_edge(u, x, -pi, 0); 317 | add_edge(u, y, -pi, 0); 318 | add_edge(u, z, -pi, 0); 319 | this->add_tweights(u, -pi, 0); 320 | } 321 | } 322 | 323 | template 324 | inline typename Energy::TotalValue Energy::minimize() { 325 | return Econst + GraphT::maxflow(); } 326 | 327 | template 328 | inline int Energy::get_var(Var x) { return (int) this->what_segment(x); } 329 | 330 | #endif 331 | -------------------------------------------------------------------------------- /gco/block.h: -------------------------------------------------------------------------------- 1 | /* block.h */ 2 | /* Copyright Vladimir Kolmogorov vnk@ist.ac.at */ 3 | /* Last modified May 2013 */ 4 | /* 5 | Template classes Block and DBlock 6 | Implement adding and deleting items of the same type in blocks. 
7 | 8 | If there are many items then using Block or DBlock 9 | is more efficient than using 'new' and 'delete' both in terms 10 | of memory and time since 11 | (1) On some systems there is some minimum amount of memory 12 | that 'new' can allocate (e.g., 64), so if items are 13 | small, a lot of memory is wasted. 14 | (2) 'new' and 'delete' are designed for items of varying size. 15 | If all items have the same size, then an algorithm for 16 | adding and deleting can be made more efficient. 17 | (3) All Block and DBlock functions are inline, so there are 18 | no extra function calls. 19 | 20 | Differences between Block and DBlock: 21 | (1) DBlock allows both adding and deleting items, 22 | whereas Block allows only adding items. 23 | (2) Block has an additional operation of scanning 24 | items added so far (in the order in which they were added). 25 | (3) Block allows allocating several consecutive 26 | items at a time, whereas DBlock can add only a single item. 27 | 28 | Note that no constructors or destructors are called for items. 29 | 30 | Example usage for items of type 'MyType': 31 | 32 | /////////////////////////////////////////////////// 33 | #include "block.h" 34 | #define BLOCK_SIZE 1024 35 | typedef struct { int a, b; } MyType; 36 | MyType *ptr, *array[10000]; 37 | 38 | ... 39 | 40 | Block<MyType> *block = new Block<MyType>(BLOCK_SIZE); 41 | 42 | // adding items 43 | for (int i=0; i<sizeof(array)/sizeof(ptr); i++) 44 | { 45 | ptr = block -> New(); 46 | ptr -> a = ptr -> b = rand(); 47 | } 48 | 49 | // reading items 50 | for (ptr=block->ScanFirst(); ptr; ptr=block->ScanNext()) 51 | { 52 | printf("%d %d\n", ptr->a, ptr->b); 53 | } 54 | 55 | delete block; 56 | 57 | ... 
58 | 59 | DBlock<MyType> *dblock = new DBlock<MyType>(BLOCK_SIZE); 60 | 61 | // adding items 62 | for (int i=0; i<sizeof(array)/sizeof(ptr); i++) 63 | { 64 | array[i] = dblock -> New(); 65 | } 66 | 67 | // deleting items 68 | for (int i=0; i<sizeof(array)/sizeof(ptr); i++) 69 | { 70 | dblock -> Delete(array[i]); 71 | } 72 | 73 | // adding items 74 | for (int i=0; i<sizeof(array)/sizeof(ptr); i++) 75 | { 76 | array[i] = dblock -> New(); 77 | } 78 | 79 | delete dblock; 80 | 81 | /////////////////////////////////////////////////// 82 | 83 | Note that DBlock deletes items by marking them as 84 | empty (i.e., by adding them to the list of free items), 85 | so that this memory could be used for subsequently 86 | added items. Thus, at each moment the memory allocated 87 | is determined by the maximum number of items allocated 88 | simultaneously at earlier moments. All memory is 89 | deallocated only when the destructor is called. 90 | */ 91 | 92 | #ifndef __BLOCK_H__ 93 | #define __BLOCK_H__ 94 | 95 | #include <stdlib.h> 96 | 97 | /***********************************************************************/ 98 | /***********************************************************************/ 99 | /***********************************************************************/ 100 | 101 | template <class Type> class Block 102 | { 103 | public: 104 | /* Constructor. Arguments are the block size and 105 | (optionally) the pointer to the function which 106 | will be called if allocation failed; the message 107 | passed to this function is "Not enough memory!" */ 108 | Block(int size, void (*err_function)(const char *) = NULL) { first = last = NULL; block_size = size; error_function = err_function; } 109 | 110 | /* Destructor. Deallocates all items added so far */ 111 | ~Block() { while (first) { block *next = first -> next; delete[] ((char*)first); first = next; } } 112 | 113 | /* Allocates 'num' consecutive items; returns pointer 114 | to the first item. 
'num' cannot be greater than the 115 | block size since items must fit in one block */ 116 | Type *New(int num = 1) 117 | { 118 | Type *t; 119 | 120 | if (!last || last->current + num > last->last) 121 | { 122 | if (last && last->next) last = last -> next; 123 | else 124 | { 125 | block *next = (block *) new char [sizeof(block) + (block_size-1)*sizeof(Type)]; 126 | if (!next) { if (error_function) (*error_function)("Not enough memory!"); exit(1); } 127 | if (last) last -> next = next; 128 | else first = next; 129 | last = next; 130 | last -> current = & ( last -> data[0] ); 131 | last -> last = last -> current + block_size; 132 | last -> next = NULL; 133 | } 134 | } 135 | 136 | t = last -> current; 137 | last -> current += num; 138 | return t; 139 | } 140 | 141 | /* Returns the first item (or NULL, if no items were added) */ 142 | Type *ScanFirst() 143 | { 144 | for (scan_current_block=first; scan_current_block; scan_current_block = scan_current_block->next) 145 | { 146 | scan_current_data = & ( scan_current_block -> data[0] ); 147 | if (scan_current_data < scan_current_block -> current) return scan_current_data ++; 148 | } 149 | return NULL; 150 | } 151 | 152 | /* Returns the next item (or NULL, if all items have been read) 153 | Can be called only if previous ScanFirst() or ScanNext() 154 | call returned not NULL. 
*/ 155 | Type *ScanNext() 156 | { 157 | while (scan_current_data >= scan_current_block -> current) 158 | { 159 | scan_current_block = scan_current_block -> next; 160 | if (!scan_current_block) return NULL; 161 | scan_current_data = & ( scan_current_block -> data[0] ); 162 | } 163 | return scan_current_data ++; 164 | } 165 | 166 | struct iterator; // for overlapping scans 167 | Type *ScanFirst(iterator& i) 168 | { 169 | for (i.scan_current_block=first; i.scan_current_block; i.scan_current_block = i.scan_current_block->next) 170 | { 171 | i.scan_current_data = & ( i.scan_current_block -> data[0] ); 172 | if (i.scan_current_data < i.scan_current_block -> current) return i.scan_current_data ++; 173 | } 174 | return NULL; 175 | } 176 | Type *ScanNext(iterator& i) 177 | { 178 | while (i.scan_current_data >= i.scan_current_block -> current) 179 | { 180 | i.scan_current_block = i.scan_current_block -> next; 181 | if (!i.scan_current_block) return NULL; 182 | i.scan_current_data = & ( i.scan_current_block -> data[0] ); 183 | } 184 | return i.scan_current_data ++; 185 | } 186 | 187 | /* Marks all elements as empty */ 188 | void Reset() 189 | { 190 | block *b; 191 | if (!first) return; 192 | for (b=first; ; b=b->next) 193 | { 194 | b -> current = & ( b -> data[0] ); 195 | if (b == last) break; 196 | } 197 | last = first; 198 | } 199 | 200 | /***********************************************************************/ 201 | 202 | private: 203 | 204 | typedef struct block_st 205 | { 206 | Type *current, *last; 207 | struct block_st *next; 208 | Type data[1]; 209 | } block; 210 | 211 | int block_size; 212 | block *first; 213 | block *last; 214 | public: 215 | struct iterator 216 | { 217 | block *scan_current_block; 218 | Type *scan_current_data; 219 | }; 220 | private: 221 | block *scan_current_block; 222 | Type *scan_current_data; 223 | 224 | void (*error_function)(const char *); 225 | }; 226 | 227 | /***********************************************************************/ 228 | 
/***********************************************************************/
/***********************************************************************/

template <class Type> class DBlock
{
public:
	/* Constructor. Arguments are the block size and
	   (optionally) the pointer to the function which
	   will be called if allocation failed; the message
	   passed to this function is "Not enough memory!" */
	DBlock(int size, void (*err_function)(const char *) = NULL) { first = NULL; first_free = NULL; block_size = size; error_function = err_function; }

	/* Destructor. Deallocates all items added so far */
	~DBlock() { while (first) { block *next = first -> next; delete[] ((char*)first); first = next; } }

	/* Allocates one item */
	Type *New()
	{
		block_item *item;

		if (!first_free)
		{
			/* free list exhausted: allocate a new block and chain every
			   slot of it into the free list.
			   NOTE(review): with a standard (throwing) operator new the
			   NULL check below is dead code; kept for fidelity with the
			   original library. */
			block *next = first;
			first = (block *) new char [sizeof(block) + (block_size-1)*sizeof(block_item)];
			if (!first) { if (error_function) (*error_function)("Not enough memory!"); exit(1); }
			first_free = & (first -> data[0] );
			for (item=first_free; item<first_free+block_size-1; item++)
				item -> next_free = item + 1;
			item -> next_free = NULL;
			first -> next = next;
		}

		item = first_free;
		first_free = item -> next_free;
		return (Type *) item;
	}

	/* Deletes an item allocated previously.
	   The slot is pushed onto the free list and reused by later New()
	   calls; memory is only released by the destructor. */
	void Delete(Type *t)
	{
		((block_item *) t) -> next_free = first_free;
		first_free = (block_item *) t;
	}

	/***********************************************************************/

private:

	/* A slot either holds a live Type or a free-list link. */
	typedef union block_item_st
	{
		Type			t;
		block_item_st	*next_free;
	} block_item;

	typedef struct block_st
	{
		struct block_st	*next;
		block_item		data[1]; /* flexible-array idiom; real size is block_size */
	} block;

	int			block_size;
	block		*first;
	block_item	*first_free;

	void	(*error_function)(const char *);
};


/***********************************************************************/
/***********************************************************************/
/***********************************************************************/

// there is no Free() function, just Alloc() that could return the same pointer.
// The allocated space grows as needed.
class ReusableBuffer
{
public:
	/* Constructor.
	   Initializer list is in declaration order (buf, then size_max) --
	   members are initialized in declaration order regardless of list
	   order, so the original (size_max first) triggered -Wreorder. */
	ReusableBuffer(void (*err_function)(const char *) = NULL) : buf(NULL), size_max(0), error_function(err_function) {}
	~ReusableBuffer() { if (buf) free(buf); }

	/* Returns a buffer of at least 'size' bytes; previous contents are
	   NOT preserved when the buffer has to grow. */
	void* Alloc(int size)
	{
		if (size <= size_max) return buf;
		size_max = (int)(1.2*size_max) + size; /* over-allocate to amortize growth */
		if (buf) free(buf);
		buf = (char*)malloc(size_max);
		if (!buf) { if (error_function) (*error_function)("Not enough memory!"); exit(1); }
		return buf;
	}
	/* Like Alloc(), but previous contents ARE preserved on growth. */
	void* Realloc(int size)
	{
		if (size <= size_max) return buf;
		size_max = (int)(1.2*size_max) + size;
		if (buf) buf = (char*)realloc(buf, size_max);
		else     buf = (char*)malloc(size_max);
		if (!buf) { if (error_function) (*error_function)("Not enough memory!"); exit(1); }
		return buf;
	}

private:
	char* buf;
	int size_max;

	void (*error_function)(const char *);
};


/***********************************************************************/
/***********************************************************************/
/***********************************************************************/

// Again, no Free() function but different calls to Alloc() return pointers
// to disjoint chunks of memory (unlike ReusableBuffer). Convenient to avoid
// explicit garbage collection.
343 | class Buffer 344 | { 345 | public: 346 | Buffer(int _default_size, void (*err_function)(const char *) = NULL) 347 | : default_size(_default_size), buf_first(NULL), error_function(err_function) {} 348 | ~Buffer() 349 | { 350 | while (buf_first) 351 | { 352 | char* b = (char*) buf_first; 353 | buf_first = buf_first->next; 354 | delete [] b; 355 | } 356 | } 357 | 358 | void* Alloc(int size) 359 | { 360 | if (!buf_first || buf_first->size+size>buf_first->size_max) 361 | { 362 | int size_max = 2*size + default_size; 363 | Buf* b = (Buf*)(new char[sizeof(Buf)+size_max]); 364 | if (!b) { if (error_function) (*error_function)("Not enough memory!"); exit(1); } 365 | b->next = buf_first; 366 | buf_first = b; 367 | b->size = 0; 368 | b->size_max = size_max; 369 | b->arr = (char*)(b+1); 370 | } 371 | char* ptr = &buf_first->arr[buf_first->size]; 372 | buf_first->size += size; 373 | return ptr; 374 | } 375 | private: 376 | struct Buf 377 | { 378 | int size, size_max; 379 | char* arr; 380 | Buf* next; 381 | }; 382 | int default_size; 383 | Buf* buf_first; 384 | 385 | void (*error_function)(const char *); 386 | }; 387 | 388 | #endif 389 | 390 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | import numpy as np 8 | 9 | def lrelu(x, leak=0.2): 10 | return tf.maximum(x, leak*x) 11 | 12 | def conv2d(x, o_dim, data_format='NHWC', name=None, k=4, s=2, act=None): 13 | return slim.conv2d(x, o_dim, k, stride=s, activation_fn=act, scope=name, data_format=data_format) 14 | 15 | def deconv2d(x, o_dim, data_format='NHWC', name=None, k=4, s=2, act=None): 16 | return slim.conv2d_transpose(x, o_dim, k, stride=s, activation_fn=act, scope=name, data_format=data_format) 17 | 
18 | def linear(x, o_dim, name=None, act=None): 19 | return slim.fully_connected(x, o_dim, activation_fn=act, scope=name) 20 | 21 | def batch_norm(x, train, data_format='NHWC', name=None, act=lrelu, epsilon=1e-5, momentum=0.9): 22 | return slim.batch_norm(x, 23 | decay=momentum, 24 | updates_collections=None, 25 | epsilon=epsilon, 26 | scale=True, 27 | fused=True, 28 | is_training=train, 29 | activation_fn=act, 30 | data_format=data_format, 31 | scope=name) 32 | 33 | def inst_norm(x, train, data_format='NHWC', name=None, affine=False, act=lrelu, epsilon=1e-5): 34 | with tf.variable_scope(name, default_name='Inst', reuse=None) as vs: 35 | if x.get_shape().ndims == 4 and data_format == 'NCHW': 36 | x = nchw_to_nhwc(x) 37 | 38 | if x.get_shape().ndims == 4: 39 | mean_dim = [1,2] 40 | else: # 2 41 | mean_dim = [1] 42 | 43 | mu, sigma_sq = tf.nn.moments(x, mean_dim, keep_dims=True) 44 | inv = tf.rsqrt(sigma_sq+epsilon) 45 | normalized = (x-mu)*inv 46 | 47 | if affine: 48 | var_shape = [x.get_shape()[-1]] 49 | shift = slim.model_variable('shift', shape=var_shape, initializer=tf.zeros_initializer) 50 | scale = slim.model_variable('scale', shape=var_shape, initializer=tf.ones_initializer) 51 | out = scale*normalized + shift 52 | else: 53 | out = normalized 54 | 55 | if x.get_shape().ndims == 4 and data_format == 'NCHW': 56 | out = nhwc_to_nchw(out) 57 | 58 | if act is None: return out 59 | else: return act(out) 60 | 61 | def resize_nearest_neighbor(x, new_size, data_format): 62 | if data_format == 'NCHW': 63 | x = nchw_to_nhwc(x) 64 | x = tf.image.resize_nearest_neighbor(x, new_size) 65 | x = nhwc_to_nchw(x) 66 | else: 67 | x = tf.image.resize_nearest_neighbor(x, new_size) 68 | return x 69 | 70 | def upscale(x, scale, data_format): 71 | _, h, w, _ = get_conv_shape(x, data_format) 72 | return resize_nearest_neighbor(x, (h*scale, w*scale), data_format) 73 | 74 | def var_on_cpu(name, shape, initializer, dtype=tf.float32): 75 | return slim.model_variable(name, shape, 
dtype=dtype, initializer=initializer, device='/CPU:0') 76 | 77 | def int_shape(tensor): 78 | shape = tensor.get_shape().as_list() 79 | return [num if num is not None else -1 for num in shape] 80 | 81 | def get_conv_shape(tensor, data_format): 82 | shape = int_shape(tensor) 83 | # always return [N, H, W, C] 84 | if data_format == 'NCHW': 85 | return [shape[0], shape[2], shape[3], shape[1]] 86 | elif data_format == 'NHWC': 87 | return shape 88 | 89 | def nchw_to_nhwc(x): 90 | return tf.transpose(x, [0, 2, 3, 1]) 91 | 92 | def nhwc_to_nchw(x): 93 | return tf.transpose(x, [0, 3, 1, 2]) 94 | 95 | def next(loader): 96 | return loader.next()[0].data.numpy() 97 | 98 | def to_nhwc(image, data_format): 99 | if data_format == 'NCHW': 100 | new_image = nchw_to_nhwc(image) 101 | else: 102 | new_image = image 103 | return new_image 104 | 105 | def to_nchw_numpy(image): 106 | if image.shape[3] in [1,2,3]: 107 | new_image = image.transpose([0, 3, 1, 2]) 108 | else: 109 | new_image = image 110 | return new_image 111 | 112 | def to_nhwc_numpy(image): 113 | if image.shape[1] in [1,2,3]: 114 | new_image = image.transpose([0, 2, 3, 1]) 115 | else: 116 | new_image = image 117 | return new_image 118 | 119 | def add_channels(x, num_ch=1, data_format='NHWC'): 120 | b, h, w, c = get_conv_shape(x, data_format) 121 | if data_format == 'NCHW': 122 | x = tf.concat([x, tf.zeros([b, num_ch, h, w])], axis=1) 123 | else: 124 | x = tf.concat([x, tf.zeros([b, h, w, num_ch])], axis=-1) 125 | return x 126 | 127 | def remove_channels(x, data_format='NHWC'): 128 | b, h, w, c = get_conv_shape(x, data_format) 129 | if data_format == 'NCHW': 130 | x, _ = tf.split(x, [3, -1], axis=1) 131 | else: 132 | x, _ = tf.split(x, [3, -1], axis=3) 133 | return x 134 | 135 | def denorm_img(norm, data_format): 136 | _, _, _, c = get_conv_shape(norm, data_format) 137 | if c == 2: 138 | norm = add_channels(norm, num_ch=1, data_format=data_format) 139 | elif c > 3: 140 | norm = remove_channels(norm, data_format=data_format) 
141 | return tf.clip_by_value(to_nhwc(norm*255, data_format), 0, 255) 142 | 143 | def reshape(x, h, w, c, data_format): 144 | if data_format == 'NCHW': 145 | x = tf.reshape(x, [-1, c, h, w]) 146 | else: 147 | x = tf.reshape(x, [-1, h, w, c]) 148 | return x 149 | 150 | def show_all_variables(): 151 | model_vars = tf.trainable_variables() 152 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 153 | 154 | 155 | # https://stackoverflow.com/questions/39051451/ssim-ms-ssim-for-tensorflow 156 | # https://github.com/tensorflow/models/blob/master/compression/image_encoder/msssim.py 157 | def fspecial_gauss(size, sigma, channels): 158 | """ 159 | Function to mimic the 'fspecial' gaussian MATLAB function 160 | """ 161 | radius = size // 2 162 | offset = 0.0 163 | start, stop = -radius, radius + 1 164 | if size % 2 == 0: 165 | offset = 0.5 166 | stop -= 1 167 | x, y = np.mgrid[offset + start:stop, offset + start:stop] 168 | assert len(x) == size 169 | 170 | x = x.reshape(x.shape+(1,1)) 171 | x = np.repeat(x, channels, axis=2) 172 | x = np.repeat(x, channels, axis=3) 173 | 174 | y = y.reshape(y.shape+(1,1)) 175 | y = np.repeat(y, channels, axis=2) 176 | y = np.repeat(y, channels, axis=3) 177 | 178 | x = tf.constant(x, dtype=tf.float32) 179 | y = tf.constant(y, dtype=tf.float32) 180 | 181 | g = tf.exp(-((x**2 + y**2)/(2.0*sigma**2))) 182 | return g / tf.reduce_sum(g) 183 | 184 | 185 | def ssim(img1, img2, mean_metric=True, 186 | filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, 187 | min_val=-1.0, max_val=1.0): 188 | 189 | # input should be rescaled to [-1,1] 190 | img_shape = img1.get_shape() 191 | height = img_shape[1].value 192 | width = img_shape[2].value 193 | channels = img_shape[3].value 194 | # print(img_shape) 195 | 196 | # Filter size can't be larger than height or width of images. 197 | size = min(filter_size, height, width) 198 | # print(size) 199 | 200 | # Scale down sigma if a smaller filter size is used. 
201 | sigma = filter_sigma * size / filter_size if filter_size else 0 202 | # print(sigma) 203 | 204 | # ! normalize image to [0,1] 205 | img1 = (img1 - min_val) / (max_val - min_val) 206 | img2 = (img2 - min_val) / (max_val - min_val) 207 | 208 | if filter_size: 209 | window = fspecial_gauss(size, sigma, channels) # window shape [size, size] 210 | mu1 = tf.nn.conv2d(img1, window, strides=[1,1,1,1], padding='VALID') 211 | mu2 = tf.nn.conv2d(img2, window, strides=[1,1,1,1], padding='VALID') 212 | sigma11 = tf.nn.conv2d(img1*img1, window, strides=[1,1,1,1], padding='VALID') 213 | sigma22 = tf.nn.conv2d(img2*img2, window, strides=[1,1,1,1], padding='VALID') 214 | sigma12 = tf.nn.conv2d(img1*img2, window, strides=[1,1,1,1], padding='VALID') 215 | else: 216 | mu1 = img1, mu2 = img2 217 | sigma11 = img1*img1 218 | sigma22 = img2*img2 219 | sigma12 = img1*img2 220 | 221 | mu11 = mu1*mu1 222 | mu22 = mu2*mu2 223 | mu12 = mu1*mu2 224 | sigma11 -= mu11 225 | sigma22 -= mu22 226 | sigma12 -= mu12 227 | 228 | L = 1.0 # max scale, already normalized to 1 229 | c1 = (k1*L)**2 230 | c2 = (k2*L)**2 231 | v1 = 2.0 * sigma12 + c2 232 | v2 = sigma11 + sigma22 + c2 233 | value = ((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2) 234 | if mean_metric: return tf.reduce_mean(value) 235 | 236 | result = {'ssim_map': value, 'cs_map': v1/v2, 'g': window} 237 | return result 238 | 239 | 240 | def ms_ssim(img1, img2, mean_metric=True, min_val=-1.0, max_val=1.0): 241 | # input should be rescaled to [-1,1] 242 | weight = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] 243 | mssim = [] 244 | mcs = [] 245 | 246 | for w in weight: 247 | result = ssim(img1, img2, mean_metric=False, min_val=min_val, max_val=max_val) 248 | mssim.append(tf.reduce_mean(result['ssim_map'])) 249 | mcs.append(tf.reduce_mean(result['cs_map'])) 250 | filtered_im1 = tf.nn.avg_pool(img1, [1,2,2,1], [1,2,2,1], padding='SAME') 251 | filtered_im2 = tf.nn.avg_pool(img2, [1,2,2,1], [1,2,2,1], padding='SAME') 252 | img1 = filtered_im1 
253 | img2 = filtered_im2 254 | 255 | # ! doesn't work 256 | # filter_sigmas = [0.5, 1, 2, 4, 8] 257 | # cs_map0 = None 258 | # for i, filter_sigma in enumerate(filter_sigmas): 259 | # result = ssim(img1, img2, filter_sigma=filter_sigma, 260 | # min_val=min_val, mean_metric=False) 261 | 262 | # if i == 0: cs_map0 = result['cs_map'] 263 | 264 | # mssim.append(tf.reduce_mean(result['ssim_map'])) 265 | # mcs.append(tf.reduce_mean(tf.nn.conv2d(cs_map0, result['g'], strides=[1,1,1,1], padding='VALID'))) 266 | 267 | # list to tensor of dim D+1 268 | mssim = tf.stack(mssim, axis=0) 269 | mcs = tf.stack(mcs, axis=0) 270 | level = len(weight) 271 | 272 | value = (tf.reduce_prod(mcs[0:level-1]**weight[0:level-1])* 273 | (mssim[level-1]**weight[level-1])) 274 | 275 | if mean_metric: value = tf.reduce_mean(value) 276 | return value 277 | 278 | 279 | def main(_): 280 | from skimage import data, transform, img_as_float 281 | import matplotlib.pyplot as plt 282 | 283 | color = False 284 | if color: 285 | image = data.astronaut() 286 | else: # [h,w] -> [h,w,1] 287 | image = data.camera() 288 | image = np.expand_dims(image, axis=-1) 289 | 290 | # image = transform.resize(image, output_shape=[128, 128]) 291 | 292 | img = img_as_float(image) 293 | print(img.shape) 294 | rows, cols, channels = img.shape 295 | 296 | noise = np.ones_like(img) * 0.2 * (img.max() - img.min()) 297 | noise[np.random.random(size=noise.shape) > 0.5] *= -1 298 | 299 | img_noise = img + noise 300 | img_noise = np.clip(img_noise, a_min=0, a_max=1) 301 | 302 | plt.figure() 303 | plt.subplot(121) 304 | if color: 305 | plt.imshow(img) 306 | plt.subplot(122) 307 | plt.imshow(img_noise) 308 | else: 309 | plt.imshow(img[:,:,0], cmap='gray') 310 | plt.subplot(122) 311 | plt.imshow(img_noise[:,:,0], cmap='gray') 312 | plt.show() 313 | 314 | 315 | ## TF CALC START 316 | image1 = tf.placeholder(tf.float32, shape=[rows, cols, channels]) 317 | image2 = tf.placeholder(tf.float32, shape=[rows, cols, channels]) 318 | 319 | def 
image_to_4d(image): 320 | image = tf.expand_dims(image, 0) 321 | return image 322 | 323 | image4d_1 = image_to_4d(image1) 324 | image4d_2 = image_to_4d(image2) 325 | 326 | print(img.min(), img.max(), img_noise.min(), img_noise.max()) 327 | ssim_index = ssim(image4d_1, image4d_2) #, min_val=0.0, max_val=1.0) 328 | msssim_index = ms_ssim(image4d_1, image4d_2) #, min_val=0.0, max_val=1.0) 329 | 330 | # img *= 255 331 | # img_noise *= 255 332 | # ssim_index = ssim(image4d_1, image4d_2, min_val=0.0, max_val=255.0) 333 | # msssim_index = ms_ssim(image4d_1, image4d_2, min_val=0.0, max_val=255.0) 334 | 335 | 336 | with tf.Session() as sess: 337 | sess.run(tf.global_variables_initializer()) 338 | 339 | tf_ssim_none = sess.run(ssim_index, 340 | feed_dict={image1: img, image2: img}) 341 | tf_ssim_noise = sess.run(ssim_index, 342 | feed_dict={image1: img, image2: img_noise}) 343 | 344 | tf_msssim_none = sess.run(msssim_index, 345 | feed_dict={image1: img, image2: img}) 346 | tf_msssim_noise = sess.run(msssim_index, 347 | feed_dict={image1: img, image2: img_noise}) 348 | ###TF CALC END 349 | 350 | print('tf_ssim_none', tf_ssim_none) 351 | print('tf_ssim_noise', tf_ssim_noise) 352 | print('tf_msssim_none', tf_msssim_none) 353 | print('tf_msssim_noise', tf_msssim_noise) 354 | 355 | 356 | if __name__ == '__main__': 357 | tf.app.run() -------------------------------------------------------------------------------- /data_qdraw.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import threading 4 | import multiprocessing 5 | import signal 6 | import sys 7 | from datetime import datetime 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | import cairosvg 12 | from PIL import Image 13 | import io 14 | import xml.etree.ElementTree as et 15 | import matplotlib.pyplot as plt 16 | 17 | from ops import * 18 | 19 | class BatchManager(object): 20 | def __init__(self, config): 21 | self.root = 
config.data_path 22 | self.rng = np.random.RandomState(config.random_seed) 23 | 24 | self.paths = sorted(glob("{}/train/*.{}".format(self.root, 'svg'))) 25 | self.test_paths = sorted(glob("{}/test/*.{}".format(self.root, 'svg'))) 26 | self.vec_paths = sorted(glob("{}/vec/*.{}".format(self.root, 'svg'))) 27 | assert(len(self.paths) > 0 and len(self.test_paths) > 0 and len(self.vec_paths) > 0) 28 | 29 | self.batch_size = config.batch_size 30 | self.height = config.height 31 | self.width = config.width 32 | 33 | self.is_pathnet = (config.archi == 'path') 34 | if self.is_pathnet: 35 | feature_dim = [self.height, self.width, 2] 36 | label_dim = [self.height, self.width, 1] 37 | else: 38 | feature_dim = [self.height, self.width, 1] 39 | label_dim = [self.height, self.width, 1] 40 | 41 | self.capacity = 10000 42 | self.q = tf.FIFOQueue(self.capacity, [tf.float32, tf.float32], [feature_dim, label_dim]) 43 | self.x = tf.placeholder(dtype=tf.float32, shape=feature_dim) 44 | self.y = tf.placeholder(dtype=tf.float32, shape=label_dim) 45 | self.enqueue = self.q.enqueue([self.x, self.y]) 46 | self.num_threads = config.num_worker 47 | # np.amin([config.num_worker, multiprocessing.cpu_count(), self.batch_size]) 48 | 49 | def __del__(self): 50 | try: 51 | self.stop_thread() 52 | except AttributeError: 53 | pass 54 | 55 | def start_thread(self, sess): 56 | print('%s: start to enque with %d threads' % (datetime.now(), self.num_threads)) 57 | 58 | # Main thread: create a coordinator. 
59 | self.sess = sess 60 | self.coord = tf.train.Coordinator() 61 | 62 | # Create a method for loading and enqueuing 63 | def load_n_enqueue(sess, enqueue, coord, paths, rng, 64 | x, y, w, h, is_pathnet): 65 | with coord.stop_on_exception(): 66 | while not coord.should_stop(): 67 | id = rng.randint(len(paths)) 68 | if is_pathnet: 69 | x_, y_ = preprocess_path(paths[id], w, h, rng) 70 | else: 71 | x_, y_ = preprocess_overlap(paths[id], w, h, rng) 72 | sess.run(enqueue, feed_dict={x: x_, y: y_}) 73 | 74 | # Create threads that enqueue 75 | self.threads = [threading.Thread(target=load_n_enqueue, 76 | args=(self.sess, 77 | self.enqueue, 78 | self.coord, 79 | self.paths, 80 | self.rng, 81 | self.x, 82 | self.y, 83 | self.width, 84 | self.height, 85 | self.is_pathnet) 86 | ) for i in range(self.num_threads)] 87 | 88 | # define signal handler 89 | def signal_handler(signum, frame): 90 | #print "stop training, save checkpoint..." 91 | #saver.save(sess, "./checkpoints/VDSR_norm_clip_epoch_%03d.ckpt" % epoch ,global_step=global_step) 92 | print('%s: canceled by SIGINT' % datetime.now()) 93 | self.coord.request_stop() 94 | self.sess.run(self.q.close(cancel_pending_enqueues=True)) 95 | self.coord.join(self.threads) 96 | sys.exit(1) 97 | signal.signal(signal.SIGINT, signal_handler) 98 | 99 | # Start the threads and wait for all of them to stop. 
100 | for t in self.threads: 101 | t.start() 102 | 103 | # dirty way to bypass graph finilization error 104 | g = tf.get_default_graph() 105 | g._finalized = False 106 | qs = 0 107 | while qs < (self.capacity*0.8): 108 | qs = self.sess.run(self.q.size()) 109 | print('%s: q size %d' % (datetime.now(), qs)) 110 | 111 | def stop_thread(self): 112 | # dirty way to bypass graph finilization error 113 | g = tf.get_default_graph() 114 | g._finalized = False 115 | 116 | self.coord.request_stop() 117 | self.sess.run(self.q.close(cancel_pending_enqueues=True)) 118 | self.coord.join(self.threads) 119 | 120 | def test_batch(self): 121 | x_list, y_list = [], [] 122 | for i, file_path in enumerate(self.test_paths): 123 | if self.is_pathnet: 124 | x_, y_ = preprocess_path(file_path, self.width, self.height, self.rng) 125 | else: 126 | x_, y_ = preprocess_overlap(file_path, self.width, self.height, self.rng) 127 | x_list.append(x_) 128 | y_list.append(y_) 129 | if i % self.batch_size == self.batch_size-1: 130 | yield np.array(x_list), np.array(y_list) 131 | x_list, y_list = [], [] 132 | 133 | def batch(self): 134 | return self.q.dequeue_many(self.batch_size) 135 | 136 | def sample(self, num): 137 | idx = self.rng.choice(len(self.paths), num).tolist() 138 | return [self.paths[i] for i in idx] 139 | 140 | def random_list(self, num): 141 | x_list = [] 142 | xs, ys = [], [] 143 | file_list = self.sample(num) 144 | for file_path in file_list: 145 | if self.is_pathnet: 146 | x, y = preprocess_path(file_path, self.width, self.height, self.rng) 147 | else: 148 | x, y = preprocess_overlap(file_path, self.width, self.height, self.rng) 149 | x_list.append(x) 150 | 151 | if self.is_pathnet: 152 | b_ch = np.zeros([self.height,self.width,1]) 153 | xs.append(np.concatenate((x*255, b_ch), axis=-1)) 154 | else: 155 | xs.append(x*255) 156 | ys.append(y*255) 157 | 158 | return np.array(x_list), np.array(xs), np.array(ys), file_list 159 | 160 | def read_svg(self, file_path): 161 | with 
open(file_path, 'r') as f: 162 | svg = f.read() 163 | 164 | img = cairosvg.svg2png(bytestring=svg.encode('utf-8')) 165 | img = Image.open(io.BytesIO(img)) 166 | s = np.array(img)[:,:,3].astype(np.float) # / 255.0 167 | max_intensity = np.amax(s) 168 | if max_intensity == 0: 169 | return s, 0, [] 170 | s = s / max_intensity 171 | 172 | path_list = [] 173 | num_paths = svg.count('polyline') 174 | 175 | for i in range(1,num_paths+1): 176 | svg_xml = et.fromstring(svg) 177 | svg_xml[1] = svg_xml[i] 178 | del svg_xml[2:] 179 | svg_one = et.tostring(svg_xml, method='xml') 180 | 181 | # leave only one path 182 | y_png = cairosvg.svg2png(bytestring=svg_one) 183 | y_img = Image.open(io.BytesIO(y_png)) 184 | path = (np.array(y_img)[:,:,3] > 0) 185 | path_list.append(path) 186 | 187 | return s, num_paths, path_list 188 | 189 | def preprocess_path(file_path, w, h, rng): 190 | with open(file_path, 'r') as f: 191 | svg = f.read() 192 | 193 | img = cairosvg.svg2png(bytestring=svg.encode('utf-8')) 194 | img = Image.open(io.BytesIO(img)) 195 | s = np.array(img)[:,:,3].astype(np.float) # / 255.0 196 | max_intensity = np.amax(s) 197 | if max_intensity == 0: 198 | x = np.zeros([h, w, 2]) 199 | y = np.zeros([h, w, 1]) 200 | return x, y 201 | s = s / max_intensity 202 | 203 | while True: 204 | svg_xml = et.fromstring(svg) 205 | num_paths = svg.count('polyline') 206 | path_id = rng.randint(1,num_paths+1) 207 | svg_xml[1] = svg_xml[path_id] 208 | del svg_xml[2:] 209 | svg_one = et.tostring(svg_xml, method='xml') 210 | 211 | # leave only one path 212 | y_png = cairosvg.svg2png(bytestring=svg_one) 213 | y_img = Image.open(io.BytesIO(y_png)) 214 | y = np.array(y_img)[:,:,3].astype(np.float) / max_intensity # [0,1] 215 | 216 | pixel_ids = np.nonzero(y) 217 | # assert len(pixel_ids[0]) > 0, '%s: no stroke px' % file_path 218 | if len(pixel_ids[0]) > 0: 219 | break 220 | 221 | # select arbitrary marking pixel 222 | point_id = rng.randint(len(pixel_ids[0])) 223 | px, py = pixel_ids[0][point_id], 
pixel_ids[1][point_id] 224 | 225 | y = np.reshape(y, [h, w, 1]) 226 | x = np.zeros([h, w, 2]) 227 | x[:,:,0] = s 228 | x[px,py,1] = 1.0 229 | 230 | # # debug 231 | # plt.figure() 232 | # plt.subplot(221) 233 | # plt.imshow(img) 234 | # plt.subplot(222) 235 | # plt.imshow(s, cmap=plt.cm.gray) 236 | # plt.subplot(223) 237 | # plt.imshow(np.concatenate((x, np.zeros([h, w, 1])), axis=-1)) 238 | # plt.subplot(224) 239 | # plt.imshow(y[:,:,0], cmap=plt.cm.gray) 240 | # plt.show() 241 | 242 | return x, y 243 | 244 | def preprocess_overlap(file_path, w, h, rng): 245 | with open(file_path, 'r') as f: 246 | svg = f.read() 247 | 248 | img = cairosvg.svg2png(bytestring=svg.encode('utf-8')) 249 | img = Image.open(io.BytesIO(img)) 250 | s = np.array(img)[:,:,3].astype(np.float) # / 255.0 251 | max_intensity = np.amax(s) 252 | if max_intensity == 0: 253 | x = np.zeros([h, w, 1]) 254 | y = np.zeros([h, w, 1]) 255 | return x, y 256 | s = s / max_intensity 257 | 258 | path_list = [] 259 | num_paths = svg.count('polyline') 260 | 261 | for i in range(1,num_paths+1): 262 | svg_xml = et.fromstring(svg) 263 | svg_xml[1] = svg_xml[i] 264 | del svg_xml[2:] 265 | svg_one = et.tostring(svg_xml, method='xml') 266 | 267 | # leave only one path 268 | y_png = cairosvg.svg2png(bytestring=svg_one) 269 | y_img = Image.open(io.BytesIO(y_png)) 270 | path = (np.array(y_img)[:,:,3] > 0) 271 | path_list.append(path) 272 | 273 | y = np.zeros([h, w], dtype=np.int) 274 | for i in range(num_paths-1): 275 | for j in range(i+1, num_paths): 276 | intersect = np.logical_and(path_list[i], path_list[j]) 277 | y = np.logical_or(intersect, y) 278 | 279 | x = np.expand_dims(s, axis=-1) 280 | y = np.expand_dims(y, axis=-1) 281 | 282 | # # debug 283 | # plt.figure() 284 | # plt.subplot(131) 285 | # plt.imshow(img) 286 | # plt.subplot(132) 287 | # plt.imshow(s, cmap=plt.cm.gray) 288 | # plt.subplot(133) 289 | # plt.imshow(y[:,:,0], cmap=plt.cm.gray) 290 | # plt.show() 291 | 292 | return x, y 293 | 294 | def 
main(config): 295 | prepare_dirs_and_logger(config) 296 | batch_manager = BatchManager(config) 297 | preprocess_path('data/qdraw/baseball/train/4503641325043712.svg', 128, 128, batch_manager.rng) 298 | preprocess_overlap('data/qdraw/baseball/train/4503641325043712.svg', 128, 128, batch_manager.rng) 299 | 300 | # thread test 301 | sess_config = tf.ConfigProto() 302 | sess_config.gpu_options.allow_growth = True 303 | sess_config.allow_soft_placement = True 304 | sess_config.log_device_placement = False 305 | sess = tf.Session(config=sess_config) 306 | batch_manager.start_thread(sess) 307 | 308 | x, y = batch_manager.batch() 309 | if config.data_format == 'NCHW': 310 | x = nhwc_to_nchw(x) 311 | x_, y_ = sess.run([x, y]) 312 | batch_manager.stop_thread() 313 | 314 | if config.data_format == 'NCHW': 315 | x_ = x_.transpose([0, 2, 3, 1]) 316 | 317 | if config.archi == 'path': 318 | b_ch = np.zeros([config.batch_size,config.height,config.width,1]) 319 | x_ = np.concatenate((x_*255, b_ch), axis=-1) 320 | else: 321 | x_ = x_*255 322 | y_ = y_*255 323 | 324 | save_image(x_, '{}/x_fixed.png'.format(config.model_dir)) 325 | save_image(y_, '{}/y_fixed.png'.format(config.model_dir)) 326 | 327 | 328 | # random pick from parameter space 329 | x_samples, x_gt, y_gt, sample_list = batch_manager.random_list(8) 330 | save_image(x_gt, '{}/x_gt.png'.format(config.model_dir)) 331 | save_image(y_gt, '{}/y_gt.png'.format(config.model_dir)) 332 | 333 | with open('{}/sample_list.txt'.format(config.model_dir), 'w') as f: 334 | for sample in sample_list: 335 | f.write(sample+'\n') 336 | 337 | print('batch manager test done') 338 | 339 | if __name__ == "__main__": 340 | from config import get_config 341 | from utils import prepare_dirs_and_logger, save_config, save_image 342 | 343 | config, unparsed = get_config() 344 | setattr(config, 'archi', 'path') # overlap 345 | setattr(config, 'dataset', 'baseball') # cat multi 346 | setattr(config, 'width', 128) 347 | setattr(config, 'height', 128) 348 
import os
from glob import glob
import threading
import multiprocessing
import signal
import sys
from datetime import datetime
import platform

import tensorflow as tf
import numpy as np
import cairosvg
from PIL import Image
import io
import xml.etree.ElementTree as et
import matplotlib.pyplot as plt

from ops import *


class BatchManager(object):
    """Streams (feature, label) pairs for the Chinese-character dataset
    through a TF FIFO queue filled by background worker threads.

    archi == 'path'  -> pathnet features [h, w, 2] (image + marked pixel)
    otherwise        -> overlap features [h, w, 1]
    Labels are always [h, w, 1].
    """

    def __init__(self, config):
        self.root = config.data_path
        self.rng = np.random.RandomState(config.random_seed)

        self.paths = sorted(glob("{}/train/*.{}".format(self.root, 'svg_pre')))
        self.test_paths = sorted(glob("{}/test/*.{}".format(self.root, 'svg_pre')))
        assert(len(self.paths) > 0 and len(self.test_paths) > 0)

        self.batch_size = config.batch_size
        self.height = config.height
        self.width = config.width

        self.is_pathnet = (config.archi == 'path')
        # pathnet input carries an extra channel for the marked pixel
        feature_dim = [self.height, self.width, 2 if self.is_pathnet else 1]
        label_dim = [self.height, self.width, 1]

        self.capacity = 10000
        self.q = tf.FIFOQueue(self.capacity, [tf.float32, tf.float32],
                              [feature_dim, label_dim])
        self.x = tf.placeholder(dtype=tf.float32, shape=feature_dim)
        self.y = tf.placeholder(dtype=tf.float32, shape=label_dim)
        self.enqueue = self.q.enqueue([self.x, self.y])
        self.num_threads = config.num_worker

    def __del__(self):
        try:
            self.stop_thread()
        except AttributeError:
            pass  # start_thread() was never called

    def start_thread(self, sess):
        """Spawn enqueue threads and block until the queue is ~80% full."""
        print('%s: start to enque with %d threads' %
              (datetime.now(), self.num_threads))

        self.sess = sess
        self.coord = tf.train.Coordinator()

        def load_n_enqueue(sess, enqueue, coord, paths, rng,
                           x, y, w, h, is_pathnet):
            with coord.stop_on_exception():
                while not coord.should_stop():
                    # sample a random training file each iteration
                    # (renamed from `id`, which shadowed the builtin)
                    file_id = rng.randint(len(paths))
                    if is_pathnet:
                        x_, y_ = preprocess_path(paths[file_id], w, h, rng)
                    else:
                        x_, y_ = preprocess_overlap(paths[file_id], w, h, rng)
                    sess.run(enqueue, feed_dict={x: x_, y: y_})

        self.threads = [threading.Thread(target=load_n_enqueue,
                                         args=(self.sess, self.enqueue,
                                               self.coord, self.paths,
                                               self.rng, self.x, self.y,
                                               self.width, self.height,
                                               self.is_pathnet))
                        for _ in range(self.num_threads)]

        # stop cleanly on Ctrl-C
        def signal_handler(signum, frame):
            print('%s: canceled by SIGINT' % datetime.now())
            self.coord.request_stop()
            self.sess.run(self.q.close(cancel_pending_enqueues=True))
            self.coord.join(self.threads)
            sys.exit(1)
        signal.signal(signal.SIGINT, signal_handler)

        for t in self.threads:
            t.start()

        # dirty way to bypass graph finalization error
        g = tf.get_default_graph()
        g._finalized = False
        qs = 0
        while qs < (self.capacity * 0.8):
            qs = self.sess.run(self.q.size())
        print('%s: q size %d' % (datetime.now(), qs))

    def stop_thread(self):
        """Request worker shutdown, cancel pending enqueues, join threads."""
        # dirty way to bypass graph finalization error
        g = tf.get_default_graph()
        g._finalized = False

        self.coord.request_stop()
        self.sess.run(self.q.close(cancel_pending_enqueues=True))
        self.coord.join(self.threads)

    def test_batch(self):
        """Yield (x, y) batches over the test set; drops the final partial batch."""
        x_list, y_list = [], []
        for i, file_path in enumerate(self.test_paths):
            if self.is_pathnet:
                x_, y_ = preprocess_path(file_path, self.width, self.height, self.rng)
            else:
                x_, y_ = preprocess_overlap(file_path, self.width, self.height, self.rng)
            x_list.append(x_)
            y_list.append(y_)
            if i % self.batch_size == self.batch_size - 1:
                yield np.array(x_list), np.array(y_list)
                x_list, y_list = [], []

    def batch(self):
        """Dequeue one training batch (as tensors)."""
        return self.q.dequeue_many(self.batch_size)

    def sample(self, num):
        """Return `num` training file paths sampled with replacement."""
        idx = self.rng.choice(len(self.paths), num).tolist()
        return [self.paths[i] for i in idx]

    def random_list(self, num):
        """Preprocess `num` random files; also return display copies scaled to 0-255."""
        x_list = []
        xs, ys = [], []
        file_list = self.sample(num)
        for file_path in file_list:
            if self.is_pathnet:
                x, y = preprocess_path(file_path, self.width, self.height, self.rng)
            else:
                x, y = preprocess_overlap(file_path, self.width, self.height, self.rng)
            x_list.append(x)

            if self.is_pathnet:
                # pad a third channel so the 2-channel input renders as RGB
                b_ch = np.zeros([self.height, self.width, 1])
                xs.append(np.concatenate((x * 255, b_ch), axis=-1))
            else:
                xs.append(x * 255)
            ys.append(y * 255)

        return np.array(x_list), np.array(xs), np.array(ys), file_list

    def read_svg(self, file_path):
        """Rasterize a whole .svg_pre file and each of its paths separately.

        Returns (image in [0,1], num_paths, list of boolean per-path masks).
        """
        with open(file_path, 'r') as f:
            svg = f.read()

        # fixed transform (no augmentation): flip y and shift into view
        r = 0
        s = [1, -1]
        t = [0, -900]

        svg = svg.format(w=self.width, h=self.height, r=r,
                         sx=s[0], sy=s[1], tx=t[0], ty=t[1])
        img = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
        img = Image.open(io.BytesIO(img))
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype
        s = np.array(img)[:, :, 3].astype(np.float64)
        max_intensity = np.amax(s)
        s = s / max_intensity

        path_list = []
        svg_xml = et.fromstring(svg)

        # ElementTree internals differ on this legacy Windows setup
        sys_name = platform.system()
        if sys_name == 'Windows':
            num_paths = len(svg_xml[0]._children)
        else:
            num_paths = len(svg_xml[0])

        for i in range(num_paths):
            svg_xml = et.fromstring(svg)
            if sys_name == 'Windows':
                svg_xml[0]._children = [svg_xml[0]._children[i]]
            else:
                svg_xml[0][0] = svg_xml[0][i]
                del svg_xml[0][1:]
            svg_one = et.tostring(svg_xml, method='xml')

            # rasterize only path i; alpha > 0 marks covered pixels
            y_png = cairosvg.svg2png(bytestring=svg_one)
            y_img = Image.open(io.BytesIO(y_png))
            path = (np.array(y_img)[:, :, 3] > 0)
            path_list.append(path)

        return s, num_paths, path_list


def preprocess_path(file_path, w, h, rng):
    """Build a pathnet sample from one .svg_pre file.

    Returns (x, y): x is [h, w, 2] (normalized image + a single randomly
    marked pixel of a randomly chosen path), y is [h, w, 1] with that
    path rasterized alone, both in [0, 1].
    """
    with open(file_path, 'r') as f:
        svg = f.read()

    # fixed transform (no augmentation): flip y and shift into view
    r = 0
    s = [1, -1]
    t = [0, -900]

    svg = svg.format(w=w, h=h, r=r, sx=s[0], sy=s[1], tx=t[0], ty=t[1])
    img = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
    img = Image.open(io.BytesIO(img))
    s = np.array(img)[:, :, 3].astype(np.float64)  # np.float removed in NumPy 1.24
    max_intensity = np.amax(s)
    s = s / max_intensity

    svg_xml = et.fromstring(svg)

    sys_name = platform.system()
    if sys_name == 'Windows':
        path_id = rng.randint(len(svg_xml[0]._children))
        svg_xml[0]._children = [svg_xml[0]._children[path_id]]
    else:
        path_id = rng.randint(len(svg_xml[0]))
        svg_xml[0][0] = svg_xml[0][path_id]
        del svg_xml[0][1:]
    svg_one = et.tostring(svg_xml, method='xml')

    # rasterize the selected path only
    y_png = cairosvg.svg2png(bytestring=svg_one)
    y_img = Image.open(io.BytesIO(y_png))
    y = np.array(y_img)[:, :, 3].astype(np.float64) / max_intensity  # [0,1]

    pixel_ids = np.nonzero(y)

    # mark an arbitrary pixel belonging to the selected path
    point_id = rng.randint(len(pixel_ids[0]))
    px, py = pixel_ids[0][point_id], pixel_ids[1][point_id]

    y = np.reshape(y, [h, w, 1])
    x = np.zeros([h, w, 2])
    x[:, :, 0] = s
    x[px, py, 1] = 1.0

    return x, y


def preprocess_overlap(file_path, w, h, rng):
    """Build an overlap-net sample from one .svg_pre file.

    Returns (x, y): x is [h, w, 1] normalized image, y is [h, w, 1]
    boolean mask of pixels covered by two or more paths.
    """
    with open(file_path, 'r') as f:
        svg = f.read()

    # fixed transform (no augmentation): flip y and shift into view
    r = 0
    s = [1, -1]
    t = [0, -900]

    svg = svg.format(w=w, h=h, r=r, sx=s[0], sy=s[1], tx=t[0], ty=t[1])
    img = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
    img = Image.open(io.BytesIO(img))
    s = np.array(img)[:, :, 3].astype(np.float64)  # np.float removed in NumPy 1.24
    max_intensity = np.amax(s)
    s = s / max_intensity

    path_list = []
    svg_xml = et.fromstring(svg)

    sys_name = platform.system()
    if sys_name == 'Windows':
        num_paths = len(svg_xml[0]._children)
    else:
        num_paths = len(svg_xml[0])

    for i in range(num_paths):
        svg_xml = et.fromstring(svg)
        if sys_name == 'Windows':
            svg_xml[0]._children = [svg_xml[0]._children[i]]
        else:
            svg_xml[0][0] = svg_xml[0][i]
            del svg_xml[0][1:]
        svg_one = et.tostring(svg_xml, method='xml')

        # rasterize only path i
        y_png = cairosvg.svg2png(bytestring=svg_one)
        y_img = Image.open(io.BytesIO(y_png))
        path = (np.array(y_img)[:, :, 3] > 0)
        path_list.append(path)

    # overlap mask = union of all pairwise intersections
    # (np.int was removed in NumPy 1.24; the array is used as a boolean mask)
    y = np.zeros([h, w], dtype=bool)
    for i in range(num_paths - 1):
        for j in range(i + 1, num_paths):
            intersect = np.logical_and(path_list[i], path_list[j])
            y = np.logical_or(intersect, y)

    x = np.expand_dims(s, axis=-1)
    y = np.expand_dims(y, axis=-1)

    return x, y


def main(config):
    """Smoke test: run both preprocessors, the queue threads, and save samples."""
    prepare_dirs_and_logger(config)
    batch_manager = BatchManager(config)
    preprocess_path('data/ch/train/11904.svg_pre', 64, 64, batch_manager.rng)
    preprocess_overlap('data/ch/train/11904.svg_pre', 64, 64, batch_manager.rng)

    # thread test
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True
    sess_config.log_device_placement = False
    sess = tf.Session(config=sess_config)
    batch_manager.start_thread(sess)

    x, y = batch_manager.batch()
    if config.data_format == 'NCHW':
        x = nhwc_to_nchw(x)
    x_, y_ = sess.run([x, y])
    batch_manager.stop_thread()

    if config.data_format == 'NCHW':
        x_ = x_.transpose([0, 2, 3, 1])

    if config.archi == 'path':
        b_ch = np.zeros([config.batch_size, config.height, config.width, 1])
        x_ = np.concatenate((x_ * 255, b_ch), axis=-1)
    else:
        x_ = x_ * 255
    y_ = y_ * 255

    save_image(x_, '{}/x_fixed.png'.format(config.model_dir))
    save_image(y_, '{}/y_fixed.png'.format(config.model_dir))

    # random pick from parameter space
    x_samples, x_gt, y_gt, sample_list = batch_manager.random_list(8)
    save_image(x_gt, '{}/x_gt.png'.format(config.model_dir))
    save_image(y_gt, '{}/y_gt.png'.format(config.model_dir))

    with open('{}/sample_list.txt'.format(config.model_dir), 'w') as f:
        for sample in sample_list:
            f.write(sample + '\n')

    print('batch manager test done')


if __name__ == "__main__":
    from config import get_config
    from utils import prepare_dirs_and_logger, save_config, save_image

    config, unparsed = get_config()
    setattr(config, 'archi', 'path')  # or 'overlap'
    setattr(config, 'dataset', 'ch')
    setattr(config, 'width', 64)
    setattr(config, 'height', 64)

    main(config)
import os
from glob import glob
import threading
import multiprocessing
import signal
import sys
from datetime import datetime

import tensorflow as tf
import numpy as np
import cairosvg
from PIL import Image
import io
import xml.etree.ElementTree as et
import matplotlib.pyplot as plt

from ops import *


def _isolate_path(svg, keep, num_paths):
    """Return a copy of `svg` with every <path> element removed except
    the one at index `keep`.

    Paths are located by their 'path id' attribute and enumerated with
    rfind from the END of the document, so `keep` indexes paths in
    reverse document order (this matches the original scanning order and
    keeps deletions from shifting positions still to be visited).
    """
    svg_one = svg
    pid = len(svg_one)
    for j in range(num_paths):
        pid = svg_one.rfind('path id', 0, pid)
        if j != keep:
            # cut the whole element: from just after the previous '>'
            # up to and including its '/>' terminator
            id_start = svg_one.rfind('>', 0, pid) + 1
            id_end = svg_one.find('/>', id_start) + 2
            svg_one = svg_one[:id_start] + svg_one[id_end:]
    return svg_one


class BatchManager(object):
    """Streams (feature, label) pairs for the kanji dataset through a TF
    FIFO queue filled by background worker threads.

    archi == 'path'  -> pathnet features [h, w, 2] (image + marked pixel)
    otherwise        -> overlap features [h, w, 1]
    Labels are always [h, w, 1].
    """

    def __init__(self, config):
        self.root = config.data_path
        self.rng = np.random.RandomState(config.random_seed)

        self.paths = sorted(glob("{}/train/*.{}".format(self.root, 'svg_pre')))
        self.test_paths = sorted(glob("{}/test/*.{}".format(self.root, 'svg_pre')))
        assert(len(self.paths) > 0 and len(self.test_paths) > 0)

        self.batch_size = config.batch_size
        self.height = config.height
        self.width = config.width

        self.is_pathnet = (config.archi == 'path')
        # pathnet input carries an extra channel for the marked pixel
        feature_dim = [self.height, self.width, 2 if self.is_pathnet else 1]
        label_dim = [self.height, self.width, 1]

        self.capacity = 10000
        self.q = tf.FIFOQueue(self.capacity, [tf.float32, tf.float32],
                              [feature_dim, label_dim])
        self.x = tf.placeholder(dtype=tf.float32, shape=feature_dim)
        self.y = tf.placeholder(dtype=tf.float32, shape=label_dim)
        self.enqueue = self.q.enqueue([self.x, self.y])
        self.num_threads = config.num_worker

    def __del__(self):
        try:
            self.stop_thread()
        except AttributeError:
            pass  # start_thread() was never called

    def start_thread(self, sess):
        """Spawn enqueue threads and block until the queue is ~80% full."""
        print('%s: start to enque with %d threads' %
              (datetime.now(), self.num_threads))

        self.sess = sess
        self.coord = tf.train.Coordinator()

        def load_n_enqueue(sess, enqueue, coord, paths, rng,
                           x, y, w, h, is_pathnet):
            with coord.stop_on_exception():
                while not coord.should_stop():
                    # sample a random training file each iteration
                    file_id = rng.randint(len(paths))
                    if is_pathnet:
                        x_, y_ = preprocess_path(paths[file_id], w, h, rng)
                    else:
                        x_, y_ = preprocess_overlap(paths[file_id], w, h, rng)
                    sess.run(enqueue, feed_dict={x: x_, y: y_})

        self.threads = [threading.Thread(target=load_n_enqueue,
                                         args=(self.sess, self.enqueue,
                                               self.coord, self.paths,
                                               self.rng, self.x, self.y,
                                               self.width, self.height,
                                               self.is_pathnet))
                        for _ in range(self.num_threads)]

        # stop cleanly on Ctrl-C
        def signal_handler(signum, frame):
            print('%s: canceled by SIGINT' % datetime.now())
            self.coord.request_stop()
            self.sess.run(self.q.close(cancel_pending_enqueues=True))
            self.coord.join(self.threads)
            sys.exit(1)
        signal.signal(signal.SIGINT, signal_handler)

        for t in self.threads:
            t.start()

        # dirty way to bypass graph finalization error
        g = tf.get_default_graph()
        g._finalized = False
        qs = 0
        while qs < (self.capacity * 0.8):
            qs = self.sess.run(self.q.size())
        print('%s: q size %d' % (datetime.now(), qs))

    def stop_thread(self):
        """Request worker shutdown, cancel pending enqueues, join threads."""
        # dirty way to bypass graph finalization error
        g = tf.get_default_graph()
        g._finalized = False

        self.coord.request_stop()
        self.sess.run(self.q.close(cancel_pending_enqueues=True))
        self.coord.join(self.threads)

    def test_batch(self):
        """Yield (x, y) batches over the test set; drops the final partial batch."""
        x_list, y_list = [], []
        for i, file_path in enumerate(self.test_paths):
            if self.is_pathnet:
                x_, y_ = preprocess_path(file_path, self.width, self.height, self.rng)
            else:
                x_, y_ = preprocess_overlap(file_path, self.width, self.height, self.rng)
            x_list.append(x_)
            y_list.append(y_)
            if i % self.batch_size == self.batch_size - 1:
                yield np.array(x_list), np.array(y_list)
                x_list, y_list = [], []

    def batch(self):
        """Dequeue one training batch (as tensors)."""
        return self.q.dequeue_many(self.batch_size)

    def sample(self, num):
        """Return `num` training file paths sampled with replacement."""
        idx = self.rng.choice(len(self.paths), num).tolist()
        return [self.paths[i] for i in idx]

    def random_list(self, num):
        """Preprocess `num` random files; also return display copies scaled to 0-255."""
        x_list = []
        xs, ys = [], []
        file_list = self.sample(num)
        for file_path in file_list:
            if self.is_pathnet:
                x, y = preprocess_path(file_path, self.width, self.height, self.rng)
            else:
                x, y = preprocess_overlap(file_path, self.width, self.height, self.rng)
            x_list.append(x)

            if self.is_pathnet:
                # pad a third channel so the 2-channel input renders as RGB
                b_ch = np.zeros([self.height, self.width, 1])
                xs.append(np.concatenate((x * 255, b_ch), axis=-1))
            else:
                xs.append(x * 255)
            ys.append(y * 255)

        return np.array(x_list), np.array(xs), np.array(ys), file_list

    def read_svg(self, file_path):
        """Rasterize a whole .svg_pre file and each of its paths separately.

        Returns (image in [0,1], num_paths, list of boolean per-path masks).
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            svg = f.read()

        # identity transform (no augmentation)
        r = 0
        s = [1, 1]
        t = [0, 0]

        svg = svg.format(w=self.width, h=self.height, r=r,
                         sx=s[0], sy=s[1], tx=t[0], ty=t[1])
        img = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
        img = Image.open(io.BytesIO(img))
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype
        s = np.array(img)[:, :, 3].astype(np.float64)
        max_intensity = np.amax(s)
        s = s / max_intensity

        # one 'path id' marker per path element; count() also handles a
        # match at index 0, which the old find() loop would have missed
        num_paths = svg.count('path id')

        path_list = []
        for i in range(num_paths):
            svg_one = _isolate_path(svg, i, num_paths)

            # rasterize only path i; alpha > 0 marks covered pixels
            y_png = cairosvg.svg2png(bytestring=svg_one.encode('utf-8'))
            y_img = Image.open(io.BytesIO(y_png))
            path = (np.array(y_img)[:, :, 3] > 0)
            path_list.append(path)

        return s, num_paths, path_list


def preprocess_path(file_path, w, h, rng):
    """Build a pathnet sample from one kanji .svg_pre file.

    Returns (x, y): x is [h, w, 2] (normalized image + a single randomly
    marked pixel of a randomly chosen path), y is [h, w, 1] with that
    path rasterized alone, both in [0, 1].
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        svg = f.read()

    # identity transform (no augmentation)
    r = 0
    s = [1, 1]
    t = [0, 0]

    svg = svg.format(w=w, h=h, r=r, sx=s[0], sy=s[1], tx=t[0], ty=t[1])
    img = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
    img = Image.open(io.BytesIO(img))
    s = np.array(img)[:, :, 3].astype(np.float64)  # np.float removed in NumPy 1.24
    max_intensity = np.amax(s)
    s = s / max_intensity

    num_paths = svg.count('path id')
    path_id = rng.randint(num_paths)
    svg_one = _isolate_path(svg, path_id, num_paths)

    # rasterize the selected path only
    y_png = cairosvg.svg2png(bytestring=svg_one.encode('utf-8'))
    y_img = Image.open(io.BytesIO(y_png))
    y = np.array(y_img)[:, :, 3].astype(np.float64) / max_intensity  # [0,1]

    pixel_ids = np.nonzero(y)

    # mark an arbitrary pixel belonging to the selected path
    point_id = rng.randint(len(pixel_ids[0]))
    px, py = pixel_ids[0][point_id], pixel_ids[1][point_id]

    y = np.reshape(y, [h, w, 1])
    x = np.zeros([h, w, 2])
    x[:, :, 0] = s
    x[px, py, 1] = 1.0

    return x, y


def preprocess_overlap(file_path, w, h, rng):
    """Build an overlap-net sample from one kanji .svg_pre file.

    Returns (x, y): x is [h, w, 1] normalized image, y is [h, w, 1]
    boolean mask of pixels covered by two or more paths.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        svg = f.read()

    # identity transform (no augmentation)
    r = 0
    s = [1, 1]
    t = [0, 0]

    svg = svg.format(w=w, h=h, r=r, sx=s[0], sy=s[1], tx=t[0], ty=t[1])
    img = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
    img = Image.open(io.BytesIO(img))
    s = np.array(img)[:, :, 3].astype(np.float64)  # np.float removed in NumPy 1.24
    max_intensity = np.amax(s)
    s = s / max_intensity

    num_paths = svg.count('path id')

    path_list = []
    for i in range(num_paths):
        svg_one = _isolate_path(svg, i, num_paths)

        # rasterize only path i
        y_png = cairosvg.svg2png(bytestring=svg_one.encode('utf-8'))
        y_img = Image.open(io.BytesIO(y_png))
        path = (np.array(y_img)[:, :, 3] > 0)
        path_list.append(path)

    # overlap mask = union of all pairwise intersections
    # (np.int was removed in NumPy 1.24; the array is used as a boolean mask)
    y = np.zeros([h, w], dtype=bool)
    for i in range(num_paths - 1):
        for j in range(i + 1, num_paths):
            intersect = np.logical_and(path_list[i], path_list[j])
            y = np.logical_or(intersect, y)

    x = np.expand_dims(s, axis=-1)
    y = np.expand_dims(y, axis=-1)

    return x, y


def main(config):
    """Smoke test: run both preprocessors, the queue threads, and save samples."""
    prepare_dirs_and_logger(config)
    batch_manager = BatchManager(config)
    preprocess_path('data/kanji/train/0f9a8.svg_pre', 64, 64, batch_manager.rng)
    preprocess_overlap('data/kanji/train/0f9a8.svg_pre', 64, 64, batch_manager.rng)

    # thread test
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True
    sess_config.log_device_placement = False
    sess = tf.Session(config=sess_config)
    batch_manager.start_thread(sess)

    x, y = batch_manager.batch()
    if config.data_format == 'NCHW':
        x = nhwc_to_nchw(x)
    x_, y_ = sess.run([x, y])
    batch_manager.stop_thread()

    if config.data_format == 'NCHW':
        x_ = x_.transpose([0, 2, 3, 1])

    if config.archi == 'path':
        b_ch = np.zeros([config.batch_size, config.height, config.width, 1])
        x_ = np.concatenate((x_ * 255, b_ch), axis=-1)
    else:
        x_ = x_ * 255
    y_ = y_ * 255

    save_image(x_, '{}/x_fixed.png'.format(config.model_dir))
    save_image(y_, '{}/y_fixed.png'.format(config.model_dir))

    # random pick from parameter space
    x_samples, x_gt, y_gt, sample_list = batch_manager.random_list(8)
    save_image(x_gt, '{}/x_gt.png'.format(config.model_dir))
    save_image(y_gt, '{}/y_gt.png'.format(config.model_dir))

    with open('{}/sample_list.txt'.format(config.model_dir), 'w') as f:
        for sample in sample_list:
            f.write(sample + '\n')

    print('batch manager test done')


if __name__ == "__main__":
    from config import get_config
    from utils import prepare_dirs_and_logger, save_config, save_image

    config, unparsed = get_config()
    setattr(config, 'archi', 'path')  # or 'overlap'
    setattr(config, 'dataset', 'kanji')
    setattr(config, 'width', 64)
    setattr(config, 'height', 64)

    main(config)
import os
from glob import glob
import threading
import multiprocessing
import signal
import sys
from datetime import datetime

import tensorflow as tf
import numpy as np
import cairosvg
from PIL import Image
import io
import xml.etree.ElementTree as et
import matplotlib.pyplot as plt

from ops import *


# NOTE(review): the original template markup was stripped by the text dump
# this file was recovered from. The literals below are reconstructed from the
# .format() fields used by draw_line/draw_cubic_bezier_curve/gen_data —
# confirm against the upstream repository before relying on exact output.
SVG_START_TEMPLATE = """<?xml version="1.0" encoding="utf-8"?>
<svg width="{w}" height="{h}" xmlns="http://www.w3.org/2000/svg" version="1.1">
<g fill="none">\n"""
SVG_LINE_TEMPLATE = """<path id="{id}" d="M {x1} {y1} L {x2} {y2}" stroke="rgb({r},{g},{b})" stroke-width="{sw}"/>"""
SVG_CUBIC_BEZIER_TEMPLATE = """<path id="{id}" d="M {sx} {sy} C {cx1} {cy1} {cx2} {cy2} {tx} {ty}" stroke="rgb({r},{g},{b})" stroke-width="{sw}"/>"""
SVG_END_TEMPLATE = """</g>\n</svg>"""


class BatchManager(object):
    """Streams (feature, label) pairs for the synthetic line dataset through
    a TF FIFO queue filled by background worker threads; generates the
    dataset on first use if it does not exist yet.
    """

    def __init__(self, config):
        self.root = config.data_path
        self.rng = np.random.RandomState(config.random_seed)

        self.paths = sorted(glob("{}/train/*.{}".format(self.root, 'svg_pre')))
        if len(self.paths) == 0:
            # no data yet: synthesize the line dataset on disk
            data_dir = os.path.join(config.data_dir, config.dataset)
            train_dir = os.path.join(data_dir, 'train')
            if not os.path.exists(train_dir):
                os.makedirs(train_dir)
            test_dir = os.path.join(data_dir, 'test')
            if not os.path.exists(test_dir):
                os.makedirs(test_dir)

            self.paths = gen_data(data_dir, config, self.rng,
                                  num_train=45000, num_test=5000)

        self.test_paths = sorted(glob("{}/test/*.{}".format(self.root, 'svg_pre')))
        assert(len(self.paths) > 0 and len(self.test_paths) > 0)

        self.batch_size = config.batch_size
        self.height = config.height
        self.width = config.width

        self.is_pathnet = (config.archi == 'path')
        # pathnet input carries an extra channel for the marked pixel
        feature_dim = [self.height, self.width, 2 if self.is_pathnet else 1]
        label_dim = [self.height, self.width, 1]

        self.capacity = 10000
        self.q = tf.FIFOQueue(self.capacity, [tf.float32, tf.float32],
                              [feature_dim, label_dim])
        self.x = tf.placeholder(dtype=tf.float32, shape=feature_dim)
        self.y = tf.placeholder(dtype=tf.float32, shape=label_dim)
        self.enqueue = self.q.enqueue([self.x, self.y])
        self.num_threads = config.num_worker

    def __del__(self):
        try:
            self.stop_thread()
        except AttributeError:
            pass  # start_thread() was never called

    def start_thread(self, sess):
        """Spawn enqueue threads and block until the queue is ~80% full."""
        print('%s: start to enque with %d threads' %
              (datetime.now(), self.num_threads))

        self.sess = sess
        self.coord = tf.train.Coordinator()

        def load_n_enqueue(sess, enqueue, coord, paths, rng,
                           x, y, w, h, is_pathnet):
            with coord.stop_on_exception():
                while not coord.should_stop():
                    # sample a random training file each iteration
                    file_id = rng.randint(len(paths))
                    if is_pathnet:
                        x_, y_ = preprocess_path(paths[file_id], w, h, rng)
                    else:
                        x_, y_ = preprocess_overlap(paths[file_id], w, h, rng)
                    sess.run(enqueue, feed_dict={x: x_, y: y_})

        self.threads = [threading.Thread(target=load_n_enqueue,
                                         args=(self.sess, self.enqueue,
                                               self.coord, self.paths,
                                               self.rng, self.x, self.y,
                                               self.width, self.height,
                                               self.is_pathnet))
                        for _ in range(self.num_threads)]

        # stop cleanly on Ctrl-C
        def signal_handler(signum, frame):
            print('%s: canceled by SIGINT' % datetime.now())
            self.coord.request_stop()
            self.sess.run(self.q.close(cancel_pending_enqueues=True))
            self.coord.join(self.threads)
            sys.exit(1)
        signal.signal(signal.SIGINT, signal_handler)

        for t in self.threads:
            t.start()

        # dirty way to bypass graph finalization error
        g = tf.get_default_graph()
        g._finalized = False
        qs = 0
        while qs < (self.capacity * 0.8):
            qs = self.sess.run(self.q.size())
        print('%s: q size %d' % (datetime.now(), qs))

    def stop_thread(self):
        """Request worker shutdown, cancel pending enqueues, join threads."""
        # dirty way to bypass graph finalization error
        g = tf.get_default_graph()
        g._finalized = False

        self.coord.request_stop()
        self.sess.run(self.q.close(cancel_pending_enqueues=True))
        self.coord.join(self.threads)

    def test_batch(self):
        """Yield (x, y) batches over the test set; drops the final partial batch."""
        x_list, y_list = [], []
        for i, file_path in enumerate(self.test_paths):
            if self.is_pathnet:
                x_, y_ = preprocess_path(file_path, self.width, self.height, self.rng)
            else:
                x_, y_ = preprocess_overlap(file_path, self.width, self.height, self.rng)
            x_list.append(x_)
            y_list.append(y_)
            if i % self.batch_size == self.batch_size - 1:
                yield np.array(x_list), np.array(y_list)
                x_list, y_list = [], []

    def batch(self):
        """Dequeue one training batch (as tensors)."""
        return self.q.dequeue_many(self.batch_size)

    def sample(self, num):
        """Return `num` training file paths sampled with replacement."""
        idx = self.rng.choice(len(self.paths), num).tolist()
        return [self.paths[i] for i in idx]

    def random_list(self, num):
        """Preprocess `num` random files; also return display copies scaled to 0-255."""
        x_list = []
        xs, ys = [], []
        file_list = self.sample(num)
        for file_path in file_list:
            if self.is_pathnet:
                x, y = preprocess_path(file_path, self.width, self.height, self.rng)
            else:
                x, y = preprocess_overlap(file_path, self.width, self.height, self.rng)
            x_list.append(x)

            if self.is_pathnet:
                # pad a third channel so the 2-channel input renders as RGB
                b_ch = np.zeros([self.height, self.width, 1])
                xs.append(np.concatenate((x * 255, b_ch), axis=-1))
            else:
                xs.append(x * 255)
            ys.append(y * 255)

        return np.array(x_list), np.array(xs), np.array(ys), file_list

    def read_svg(self, file_path):
        """Rasterize a whole .svg_pre file and each of its paths separately.

        Returns (image in [0,1], num_paths, list of boolean per-path masks).
        """
        with open(file_path, 'r') as f:
            svg = f.read()

        svg = svg.format(w=self.width, h=self.height)
        img = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
        img = Image.open(io.BytesIO(img))
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype
        s = np.array(img)[:, :, 3].astype(np.float64)
        max_intensity = np.amax(s)
        s = s / max_intensity

        path_list = []
        svg_xml = et.fromstring(svg)
        num_paths = len(svg_xml[0])

        for i in range(num_paths):
            svg_xml = et.fromstring(svg)
            svg_xml[0][0] = svg_xml[0][i]
            del svg_xml[0][1:]
            svg_one = et.tostring(svg_xml, method='xml')

            # rasterize only path i; alpha > 0 marks covered pixels
            y_png = cairosvg.svg2png(bytestring=svg_one)
            y_img = Image.open(io.BytesIO(y_png))
            path = (np.array(y_img)[:, :, 3] > 0)
            path_list.append(path)

        return s, num_paths, path_list


def draw_line(id, w, h, min_length, max_stroke_width, rng):
    """Return SVG markup for one random line of Manhattan length >= min_length."""
    stroke_color = rng.randint(240, size=3)
    # stroke_width = rng.randint(low=1, high=max_stroke_width+1)
    stroke_width = max_stroke_width
    while True:
        x = rng.randint(w, size=2)
        y = rng.randint(h, size=2)
        # BUGFIX: the original tested the SIGNED sum x[0]-x[1]+y[0]-y[1],
        # which lets arbitrarily short (or zero-length) lines through
        # whenever the deltas happen to be positive; use Manhattan length.
        if abs(x[0] - x[1]) + abs(y[0] - y[1]) < min_length:
            continue
        break

    return SVG_LINE_TEMPLATE.format(
        id=id,
        x1=x[0], y1=y[0],
        x2=x[1], y2=y[1],
        r=stroke_color[0], g=stroke_color[1], b=stroke_color[2],
        sw=stroke_width
    )


def draw_cubic_bezier_curve(id, w, h, min_length, max_stroke_width, rng):
    """Return SVG markup for one random cubic Bezier curve."""
    stroke_color = rng.randint(240, size=3)
    # stroke_width = rng.randint(low=1, high=max_stroke_width+1)
    stroke_width = max_stroke_width
    x = rng.randint(w, size=4)
    y = rng.randint(h, size=4)

    return SVG_CUBIC_BEZIER_TEMPLATE.format(
        id=id,
        sx=x[0], sy=y[0],
        cx1=x[1], cy1=y[1],
        cx2=x[2], cy2=y[2],
        tx=x[3], ty=y[3],
        r=stroke_color[0], g=stroke_color[1], b=stroke_color[2],
        sw=stroke_width
    )


def draw_path(stroke_type, id, w, h, min_length, max_stroke_width, rng):
    """Draw one stroke: 0 = line, 1 = cubic Bezier, 2 = pick randomly."""
    if stroke_type == 2:
        stroke_type = rng.randint(2)

    path_selector = {
        0: draw_line,
        1: draw_cubic_bezier_curve
    }

    return path_selector[stroke_type](id, w, h, min_length, max_stroke_width, rng)


def gen_data(data_dir, config, rng, num_train, num_test):
    """Synthesize num_train+num_test random stroke images.

    Writes .svg_pre files (templated, for training), plus .svg/.jpg copies
    for reference; returns the list of TRAIN .svg_pre paths.
    """
    file_list = []
    num = num_train + num_test
    for file_id in range(num):
        while True:
            svg = SVG_START_TEMPLATE.format(
                w=config.width,
                h=config.height,
            )
            svgpre = SVG_START_TEMPLATE

            for i in range(config.num_strokes):
                path = draw_path(
                    stroke_type=config.stroke_type,
                    id=i,
                    w=config.width,
                    h=config.height,
                    min_length=config.min_length,
                    max_stroke_width=config.max_stroke_width,
                    rng=rng,
                )
                svg += path + '\n'
                svgpre += path + '\n'

            svg += SVG_END_TEMPLATE
            svgpre += SVG_END_TEMPLATE
            s_png = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
            s_img = Image.open(io.BytesIO(s_png))
            s = np.array(s_img)[:, :, 3].astype(np.float64)  # np.float removed in NumPy 1.24
            max_intensity = np.amax(s)

            # blank renders cannot be normalized; redraw
            if max_intensity == 0:
                continue
            else:
                s = s / max_intensity  # [0,1]
                break

        cat = 'train' if file_id < num_train else 'test'

        # svgpre (the templated file actually consumed by training)
        svgpre_file_path = os.path.join(data_dir, cat, '%d.svg_pre' % file_id)
        print(svgpre_file_path)
        with open(svgpre_file_path, 'w') as f:
            f.write(svgpre)

        # svg and jpg copies for visual reference
        svg_dir = os.path.join(data_dir, 'svg')
        if not os.path.exists(svg_dir):
            os.makedirs(svg_dir)
        jpg_dir = os.path.join(data_dir, 'jpg')
        if not os.path.exists(jpg_dir):
            os.makedirs(jpg_dir)

        svg_file_path = os.path.join(data_dir, 'svg', '%d.svg' % file_id)
        jpg_file_path = os.path.join(data_dir, 'jpg', '%d.jpg' % file_id)

        with open(svg_file_path, 'w') as f:
            f.write(svg)
        s_img.convert('RGB').save(jpg_file_path)

        if file_id < num_train:
            file_list.append(svgpre_file_path)

    return file_list


def preprocess_path(file_path, w, h, rng):
    """Build a pathnet sample from one line .svg_pre file.

    Returns (x, y): x is [h, w, 2] (normalized image + a single randomly
    marked pixel of a randomly chosen path), y is [h, w, 1] with that
    path rasterized alone, both in [0, 1].
    """
    with open(file_path, 'r') as f:
        svg = f.read()

    svg = svg.format(w=w, h=h)
    img = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
    img = Image.open(io.BytesIO(img))
    s = np.array(img)[:, :, 3].astype(np.float64)  # np.float removed in NumPy 1.24
    max_intensity = np.amax(s)
    s = s / max_intensity

    svg_xml = et.fromstring(svg)
    path_id = rng.randint(len(svg_xml[0]))
    svg_xml[0][0] = svg_xml[0][path_id]
    del svg_xml[0][1:]
    svg_one = et.tostring(svg_xml, method='xml')

    # rasterize the selected path only
    y_png = cairosvg.svg2png(bytestring=svg_one)
    y_img = Image.open(io.BytesIO(y_png))
    y = np.array(y_img)[:, :, 3].astype(np.float64) / max_intensity  # [0,1]

    pixel_ids = np.nonzero(y)

    # mark an arbitrary pixel belonging to the selected path
    point_id = rng.randint(len(pixel_ids[0]))
    px, py = pixel_ids[0][point_id], pixel_ids[1][point_id]

    y = np.reshape(y, [h, w, 1])
    x = np.zeros([h, w, 2])
    x[:, :, 0] = s
    x[px, py, 1] = 1.0

    # NOTE(review): the tail of this function lies just past the recovered
    # text; it is reconstructed to match the identical endings of
    # preprocess_path in data_ch.py / data_kanji.py — confirm upstream.
    return x, y
plt.imshow(s, cmap=plt.cm.gray) 370 | # plt.subplot(223) 371 | # plt.imshow(np.concatenate((x, np.zeros([h, w, 1])), axis=-1)) 372 | # plt.subplot(224) 373 | # plt.imshow(y[:,:,0], cmap=plt.cm.gray) 374 | # plt.show() 375 | 376 | return x, y 377 | 378 | def preprocess_overlap(file_path, w, h, rng): 379 | with open(file_path, 'r') as f: 380 | svg = f.read() 381 | svg = svg.format(w=w, h=h) 382 | img = cairosvg.svg2png(bytestring=svg.encode('utf-8')) 383 | img = Image.open(io.BytesIO(img)) 384 | s = np.array(img)[:,:,3].astype(np.float) # / 255.0 385 | max_intensity = np.amax(s) 386 | s = s / max_intensity 387 | 388 | # while True: 389 | path_list = [] 390 | svg_xml = et.fromstring(svg) 391 | num_paths = len(svg_xml[0]) 392 | 393 | for i in range(num_paths): 394 | svg_xml = et.fromstring(svg) 395 | svg_xml[0][0] = svg_xml[0][i] 396 | del svg_xml[0][1:] 397 | svg_one = et.tostring(svg_xml, method='xml') 398 | 399 | # leave only one path 400 | y_png = cairosvg.svg2png(bytestring=svg_one) 401 | y_img = Image.open(io.BytesIO(y_png)) 402 | path = (np.array(y_img)[:,:,3] > 0) 403 | path_list.append(path) 404 | 405 | y = np.zeros([h, w], dtype=np.int) 406 | for i in range(num_paths-1): 407 | for j in range(i+1, num_paths): 408 | intersect = np.logical_and(path_list[i], path_list[j]) 409 | y = np.logical_or(intersect, y) 410 | 411 | x = np.expand_dims(s, axis=-1) 412 | y = np.expand_dims(y, axis=-1) 413 | 414 | # # debug 415 | # plt.figure() 416 | # plt.subplot(131) 417 | # plt.imshow(img) 418 | # plt.subplot(132) 419 | # plt.imshow(s, cmap=plt.cm.gray) 420 | # plt.subplot(133) 421 | # plt.imshow(y[:,:,0], cmap=plt.cm.gray) 422 | # plt.show() 423 | 424 | return x, y 425 | 426 | def main(config): 427 | prepare_dirs_and_logger(config) 428 | batch_manager = BatchManager(config) 429 | # preprocess_path('data/line/train/0.svg_pre', 64, 64, batch_manager.rng) 430 | # preprocess_overlap('data/line/train/0.svg_pre', 64, 64, batch_manager.rng) 431 | 432 | # thread test 433 | 
sess_config = tf.ConfigProto() 434 | sess_config.gpu_options.allow_growth = True 435 | sess_config.allow_soft_placement = True 436 | sess_config.log_device_placement = False 437 | sess = tf.Session(config=sess_config) 438 | batch_manager.start_thread(sess) 439 | 440 | x, y = batch_manager.batch() 441 | if config.data_format == 'NCHW': 442 | x = nhwc_to_nchw(x) 443 | x_, y_ = sess.run([x, y]) 444 | batch_manager.stop_thread() 445 | 446 | if config.data_format == 'NCHW': 447 | x_ = x_.transpose([0, 2, 3, 1]) 448 | 449 | if config.archi == 'path': 450 | b_ch = np.zeros([config.batch_size,config.height,config.width,1]) 451 | x_ = np.concatenate((x_*255, b_ch), axis=-1) 452 | else: 453 | x_ = x_*255 454 | y_ = y_*255 455 | 456 | save_image(x_, '{}/x_fixed.png'.format(config.model_dir)) 457 | save_image(y_, '{}/y_fixed.png'.format(config.model_dir)) 458 | 459 | 460 | # random pick from parameter space 461 | x_samples, x_gt, y_gt, sample_list = batch_manager.random_list(8) 462 | save_image(x_gt, '{}/x_gt.png'.format(config.model_dir)) 463 | save_image(y_gt, '{}/y_gt.png'.format(config.model_dir)) 464 | 465 | with open('{}/sample_list.txt'.format(config.model_dir), 'w') as f: 466 | for sample in sample_list: 467 | f.write(sample+'\n') 468 | 469 | print('batch manager test done') 470 | 471 | if __name__ == "__main__": 472 | from config import get_config 473 | from utils import prepare_dirs_and_logger, save_config, save_image 474 | 475 | config, unparsed = get_config() 476 | setattr(config, 'archi', 'path') # overlap 477 | setattr(config, 'dataset', 'line') 478 | setattr(config, 'width', 64) 479 | setattr(config, 'height', 64) 480 | 481 | main(config) -------------------------------------------------------------------------------- /gco/QPBO_maxflow.cpp: -------------------------------------------------------------------------------- 1 | /* QPBO_maxflow.cpp */ 2 | /* 3 | Copyright 2006-2008 Vladimir Kolmogorov (vnk@ist.ac.at). 4 | 5 | This file is part of QPBO. 
6 | 7 | QPBO is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | QPBO is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with QPBO. If not, see . 19 | */ 20 | 21 | 22 | #include 23 | #include "QPBO.h" 24 | 25 | 26 | 27 | #define INFINITE_D ((int)(((unsigned)-1)/2)) /* infinite distance to the terminal */ 28 | 29 | /***********************************************************************/ 30 | 31 | /* 32 | Functions for processing active list. 33 | i->next points to the next node in the list 34 | (or to i, if i is the last node in the list). 35 | If i->next is NULL iff i is not in the list. 36 | 37 | There are two queues. Active nodes are added 38 | to the end of the second queue and read from 39 | the front of the first queue. If the first queue 40 | is empty, it is replaced by the second queue 41 | (and the second queue becomes empty). 42 | */ 43 | 44 | 45 | template 46 | inline void QPBO::set_active(Node *i) 47 | { 48 | if (!i->next) 49 | { 50 | /* it's not in the list yet */ 51 | if (queue_last[1]) queue_last[1] -> next = i; 52 | else queue_first[1] = i; 53 | queue_last[1] = i; 54 | i -> next = i; 55 | } 56 | } 57 | 58 | /* 59 | Returns the next active node. 
60 | If it is connected to the sink, it stays in the list, 61 | otherwise it is removed from the list 62 | */ 63 | template 64 | inline typename QPBO::Node* QPBO::next_active() 65 | { 66 | Node *i; 67 | 68 | while ( 1 ) 69 | { 70 | if (!(i=queue_first[0])) 71 | { 72 | queue_first[0] = i = queue_first[1]; 73 | queue_last[0] = queue_last[1]; 74 | queue_first[1] = NULL; 75 | queue_last[1] = NULL; 76 | if (!i) return NULL; 77 | } 78 | 79 | /* remove it from the active list */ 80 | if (i->next == i) queue_first[0] = queue_last[0] = NULL; 81 | else queue_first[0] = i -> next; 82 | i -> next = NULL; 83 | 84 | /* a node in the list is active iff it has a parent */ 85 | if (i->parent) return i; 86 | } 87 | } 88 | 89 | /***********************************************************************/ 90 | 91 | template 92 | inline void QPBO::set_orphan_front(Node *i) 93 | { 94 | nodeptr *np; 95 | i -> parent = QPBO_MAXFLOW_ORPHAN; 96 | np = nodeptr_block -> New(); 97 | np -> ptr = i; 98 | np -> next = orphan_first; 99 | orphan_first = np; 100 | } 101 | 102 | template 103 | inline void QPBO::set_orphan_rear(Node *i) 104 | { 105 | nodeptr *np; 106 | i -> parent = QPBO_MAXFLOW_ORPHAN; 107 | np = nodeptr_block -> New(); 108 | np -> ptr = i; 109 | if (orphan_last) orphan_last -> next = np; 110 | else orphan_first = np; 111 | orphan_last = np; 112 | np -> next = NULL; 113 | } 114 | 115 | /***********************************************************************/ 116 | 117 | template 118 | inline void QPBO::add_to_changed_list(Node *i) 119 | { 120 | if (keep_changed_list) 121 | { 122 | if (!IsNode0(i)) i = GetMate1(i); 123 | if (!i->is_in_changed_list) 124 | { 125 | Node** ptr = changed_list->New(); 126 | *ptr = i;; 127 | i->is_in_changed_list = true; 128 | } 129 | } 130 | } 131 | 132 | /***********************************************************************/ 133 | 134 | template 135 | void QPBO::maxflow_init() 136 | { 137 | Node *i; 138 | 139 | queue_first[0] = queue_last[0] = NULL; 140 | 
queue_first[1] = queue_last[1] = NULL; 141 | orphan_first = NULL; 142 | 143 | TIME = 0; 144 | 145 | for (i=nodes[0]; i next = NULL; 150 | i -> is_marked = 0; 151 | i -> is_in_changed_list = 0; 152 | i -> TS = TIME; 153 | if (i->tr_cap > 0) 154 | { 155 | /* i is connected to the source */ 156 | i -> is_sink = 0; 157 | i -> parent = QPBO_MAXFLOW_TERMINAL; 158 | set_active(i); 159 | i -> DIST = 1; 160 | } 161 | else if (i->tr_cap < 0) 162 | { 163 | /* i is connected to the sink */ 164 | i -> is_sink = 1; 165 | i -> parent = QPBO_MAXFLOW_TERMINAL; 166 | set_active(i); 167 | i -> DIST = 1; 168 | } 169 | else 170 | { 171 | i -> parent = NULL; 172 | } 173 | } 174 | } 175 | 176 | template 177 | void QPBO::maxflow_reuse_trees_init() 178 | { 179 | Node* i; 180 | Node* j; 181 | Node* queue = queue_first[1]; 182 | Arc* a; 183 | nodeptr* np; 184 | 185 | queue_first[0] = queue_last[0] = NULL; 186 | queue_first[1] = queue_last[1] = NULL; 187 | orphan_first = orphan_last = NULL; 188 | 189 | TIME ++; 190 | 191 | while ((i=queue)) 192 | { 193 | queue = i->next; 194 | if (queue == i) queue = NULL; 195 | if (IsNode0(i)) 196 | { 197 | if (i->is_removed) continue; 198 | } 199 | else 200 | { 201 | if (GetMate1(i)->is_removed) continue; 202 | } 203 | i->next = NULL; 204 | i->is_marked = 0; 205 | set_active(i); 206 | 207 | if (i->tr_cap == 0) 208 | { 209 | if (i->parent) set_orphan_rear(i); 210 | continue; 211 | } 212 | 213 | if (i->tr_cap > 0) 214 | { 215 | if (!i->parent || i->is_sink) 216 | { 217 | i->is_sink = 0; 218 | for (a=i->first; a; a=a->next) 219 | { 220 | j = a->head; 221 | if (!j->is_marked) 222 | { 223 | if (j->parent == a->sister) set_orphan_rear(j); 224 | if (j->parent && j->is_sink && a->r_cap > 0) set_active(j); 225 | } 226 | } 227 | add_to_changed_list(i); 228 | } 229 | } 230 | else 231 | { 232 | if (!i->parent || !i->is_sink) 233 | { 234 | i->is_sink = 1; 235 | for (a=i->first; a; a=a->next) 236 | { 237 | j = a->head; 238 | if (!j->is_marked) 239 | { 240 | if (j->parent 
== a->sister) set_orphan_rear(j); 241 | if (j->parent && !j->is_sink && a->sister->r_cap > 0) set_active(j); 242 | } 243 | } 244 | add_to_changed_list(i); 245 | } 246 | } 247 | i->parent = QPBO_MAXFLOW_TERMINAL; 248 | i -> TS = TIME; 249 | i -> DIST = 1; 250 | } 251 | 252 | code_assert(stage == 1); 253 | //test_consistency(); 254 | 255 | /* adoption */ 256 | while ((np=orphan_first)) 257 | { 258 | orphan_first = np -> next; 259 | i = np -> ptr; 260 | nodeptr_block -> Delete(np); 261 | if (!orphan_first) orphan_last = NULL; 262 | if (i->is_sink) process_sink_orphan(i); 263 | else process_source_orphan(i); 264 | } 265 | /* adoption end */ 266 | 267 | //test_consistency(); 268 | } 269 | 270 | template 271 | void QPBO::augment(Arc *middle_arc) 272 | { 273 | Node *i; 274 | Arc *a; 275 | REAL bottleneck; 276 | 277 | 278 | /* 1. Finding bottleneck capacity */ 279 | /* 1a - the source tree */ 280 | bottleneck = middle_arc -> r_cap; 281 | for (i=middle_arc->sister->head; ; i=a->head) 282 | { 283 | a = i -> parent; 284 | if (a == QPBO_MAXFLOW_TERMINAL) break; 285 | if (bottleneck > a->sister->r_cap) bottleneck = a -> sister -> r_cap; 286 | } 287 | if (bottleneck > i->tr_cap) bottleneck = i -> tr_cap; 288 | /* 1b - the sink tree */ 289 | for (i=middle_arc->head; ; i=a->head) 290 | { 291 | a = i -> parent; 292 | if (a == QPBO_MAXFLOW_TERMINAL) break; 293 | if (bottleneck > a->r_cap) bottleneck = a -> r_cap; 294 | } 295 | if (bottleneck > - i->tr_cap) bottleneck = - i -> tr_cap; 296 | 297 | 298 | /* 2. 
Augmenting */ 299 | /* 2a - the source tree */ 300 | middle_arc -> sister -> r_cap += bottleneck; 301 | middle_arc -> r_cap -= bottleneck; 302 | for (i=middle_arc->sister->head; ; i=a->head) 303 | { 304 | a = i -> parent; 305 | if (a == QPBO_MAXFLOW_TERMINAL) break; 306 | a -> r_cap += bottleneck; 307 | a -> sister -> r_cap -= bottleneck; 308 | if (!a->sister->r_cap) 309 | { 310 | set_orphan_front(i); // add i to the beginning of the adoption list 311 | } 312 | } 313 | i -> tr_cap -= bottleneck; 314 | if (!i->tr_cap) 315 | { 316 | set_orphan_front(i); // add i to the beginning of the adoption list 317 | } 318 | /* 2b - the sink tree */ 319 | for (i=middle_arc->head; ; i=a->head) 320 | { 321 | a = i -> parent; 322 | if (a == QPBO_MAXFLOW_TERMINAL) break; 323 | a -> sister -> r_cap += bottleneck; 324 | a -> r_cap -= bottleneck; 325 | if (!a->r_cap) 326 | { 327 | set_orphan_front(i); // add i to the beginning of the adoption list 328 | } 329 | } 330 | i -> tr_cap += bottleneck; 331 | if (!i->tr_cap) 332 | { 333 | set_orphan_front(i); // add i to the beginning of the adoption list 334 | } 335 | } 336 | 337 | /***********************************************************************/ 338 | 339 | template 340 | void QPBO::process_source_orphan(Node *i) 341 | { 342 | Node *j; 343 | Arc *a0, *a0_min = NULL, *a; 344 | int d, d_min = INFINITE_D; 345 | 346 | /* trying to find a new parent */ 347 | for (a0=i->first; a0; a0=a0->next) 348 | if (a0->sister->r_cap) 349 | { 350 | j = a0 -> head; 351 | if (!j->is_sink && (a=j->parent)) 352 | { 353 | /* checking the origin of j */ 354 | d = 0; 355 | while ( 1 ) 356 | { 357 | if (j->TS == TIME) 358 | { 359 | d += j -> DIST; 360 | break; 361 | } 362 | a = j -> parent; 363 | d ++; 364 | if (a==QPBO_MAXFLOW_TERMINAL) 365 | { 366 | j -> TS = TIME; 367 | j -> DIST = 1; 368 | break; 369 | } 370 | if (a==QPBO_MAXFLOW_ORPHAN) { d = INFINITE_D; break; } 371 | j = a -> head; 372 | } 373 | if (dhead; j->TS!=TIME; j=j->parent->head) 382 | { 383 | j 
-> TS = TIME; 384 | j -> DIST = d --; 385 | } 386 | } 387 | } 388 | } 389 | 390 | if (i->parent = a0_min) 391 | { 392 | i -> TS = TIME; 393 | i -> DIST = d_min + 1; 394 | } 395 | else 396 | { 397 | /* no parent is found */ 398 | add_to_changed_list(i); 399 | 400 | /* process neighbors */ 401 | for (a0=i->first; a0; a0=a0->next) 402 | { 403 | j = a0 -> head; 404 | if (!j->is_sink && (a=j->parent)) 405 | { 406 | if (a0->sister->r_cap) set_active(j); 407 | if (a!=QPBO_MAXFLOW_TERMINAL && a!=QPBO_MAXFLOW_ORPHAN && a->head==i) 408 | { 409 | set_orphan_rear(j); // add j to the end of the adoption list 410 | } 411 | } 412 | } 413 | } 414 | } 415 | 416 | template 417 | void QPBO::process_sink_orphan(Node *i) 418 | { 419 | Node *j; 420 | Arc *a0, *a0_min = NULL, *a; 421 | int d, d_min = INFINITE_D; 422 | 423 | /* trying to find a new parent */ 424 | for (a0=i->first; a0; a0=a0->next) 425 | if (a0->r_cap) 426 | { 427 | j = a0 -> head; 428 | if ((a=j->parent) && j->is_sink) 429 | { 430 | /* checking the origin of j */ 431 | d = 0; 432 | while ( 1 ) 433 | { 434 | if (j->TS == TIME) 435 | { 436 | d += j -> DIST; 437 | break; 438 | } 439 | a = j -> parent; 440 | d ++; 441 | if (a==QPBO_MAXFLOW_TERMINAL) 442 | { 443 | j -> TS = TIME; 444 | j -> DIST = 1; 445 | break; 446 | } 447 | if (a==QPBO_MAXFLOW_ORPHAN) { d = INFINITE_D; break; } 448 | j = a -> head; 449 | } 450 | if (dhead; j->TS!=TIME; j=j->parent->head) 459 | { 460 | j -> TS = TIME; 461 | j -> DIST = d --; 462 | } 463 | } 464 | } 465 | } 466 | 467 | if (i->parent = a0_min) 468 | { 469 | i -> TS = TIME; 470 | i -> DIST = d_min + 1; 471 | } 472 | else 473 | { 474 | /* no parent is found */ 475 | add_to_changed_list(i); 476 | 477 | /* process neighbors */ 478 | for (a0=i->first; a0; a0=a0->next) 479 | { 480 | j = a0 -> head; 481 | if ((a=j->parent) && j->is_sink) 482 | { 483 | if (a0->r_cap) set_active(j); 484 | if (a!=QPBO_MAXFLOW_TERMINAL && a!=QPBO_MAXFLOW_ORPHAN && a->head==i) 485 | { 486 | set_orphan_rear(j); // add j 
to the end of the adoption list 487 | } 488 | } 489 | } 490 | } 491 | } 492 | 493 | /***********************************************************************/ 494 | 495 | template 496 | void QPBO::maxflow(bool reuse_trees, bool _keep_changed_list) 497 | { 498 | Node *i, *j, *current_node = NULL; 499 | Arc *a; 500 | nodeptr *np, *np_next; 501 | 502 | if (!nodeptr_block) 503 | { 504 | nodeptr_block = new DBlock(NODEPTR_BLOCK_SIZE, error_function); 505 | } 506 | 507 | if (maxflow_iteration == 0) 508 | { 509 | reuse_trees = false; 510 | _keep_changed_list = false; 511 | } 512 | 513 | keep_changed_list = _keep_changed_list; 514 | if (keep_changed_list) 515 | { 516 | if (!changed_list) changed_list = new Block(NODEPTR_BLOCK_SIZE, error_function); 517 | } 518 | 519 | if (reuse_trees) maxflow_reuse_trees_init(); 520 | else maxflow_init(); 521 | 522 | // main loop 523 | while ( 1 ) 524 | { 525 | // test_consistency(current_node); 526 | 527 | if ((i=current_node)) 528 | { 529 | i -> next = NULL; /* remove active flag */ 530 | if (!i->parent) i = NULL; 531 | } 532 | if (!i) 533 | { 534 | if (!(i = next_active())) break; 535 | } 536 | 537 | /* growth */ 538 | if (!i->is_sink) 539 | { 540 | /* grow source tree */ 541 | for (a=i->first; a; a=a->next) 542 | if (a->r_cap) 543 | { 544 | j = a -> head; 545 | if (!j->parent) 546 | { 547 | j -> is_sink = 0; 548 | j -> parent = a -> sister; 549 | j -> TS = i -> TS; 550 | j -> DIST = i -> DIST + 1; 551 | set_active(j); 552 | add_to_changed_list(j); 553 | } 554 | else if (j->is_sink) break; 555 | else if (j->TS <= i->TS && 556 | j->DIST > i->DIST) 557 | { 558 | /* heuristic - trying to make the distance from j to the source shorter */ 559 | j -> parent = a -> sister; 560 | j -> TS = i -> TS; 561 | j -> DIST = i -> DIST + 1; 562 | } 563 | } 564 | } 565 | else 566 | { 567 | /* grow sink tree */ 568 | for (a=i->first; a; a=a->next) 569 | if (a->sister->r_cap) 570 | { 571 | j = a -> head; 572 | if (!j->parent) 573 | { 574 | j -> is_sink = 1; 
575 | j -> parent = a -> sister; 576 | j -> TS = i -> TS; 577 | j -> DIST = i -> DIST + 1; 578 | set_active(j); 579 | add_to_changed_list(j); 580 | } 581 | else if (!j->is_sink) { a = a -> sister; break; } 582 | else if (j->TS <= i->TS && 583 | j->DIST > i->DIST) 584 | { 585 | /* heuristic - trying to make the distance from j to the sink shorter */ 586 | j -> parent = a -> sister; 587 | j -> TS = i -> TS; 588 | j -> DIST = i -> DIST + 1; 589 | } 590 | } 591 | } 592 | 593 | TIME ++; 594 | 595 | if (a) 596 | { 597 | i -> next = i; /* set active flag */ 598 | current_node = i; 599 | 600 | /* augmentation */ 601 | augment(a); 602 | /* augmentation end */ 603 | 604 | /* adoption */ 605 | while ((np=orphan_first)) 606 | { 607 | np_next = np -> next; 608 | np -> next = NULL; 609 | 610 | while ((np=orphan_first)) 611 | { 612 | orphan_first = np -> next; 613 | i = np -> ptr; 614 | nodeptr_block -> Delete(np); 615 | if (!orphan_first) orphan_last = NULL; 616 | if (i->is_sink) process_sink_orphan(i); 617 | else process_source_orphan(i); 618 | } 619 | 620 | orphan_first = np_next; 621 | } 622 | /* adoption end */ 623 | } 624 | else current_node = NULL; 625 | } 626 | // test_consistency(); 627 | 628 | if (!reuse_trees || (maxflow_iteration % 64) == 0) 629 | { 630 | delete nodeptr_block; 631 | nodeptr_block = NULL; 632 | } 633 | 634 | maxflow_iteration ++; 635 | } 636 | 637 | /***********************************************************************/ 638 | 639 | 640 | template 641 | void QPBO::test_consistency(Node* current_node) 642 | { 643 | Node *i; 644 | Arc *a; 645 | int r; 646 | int num1 = 0, num2 = 0; 647 | 648 | // test whether all nodes i with i->next!=NULL are indeed in the queue 649 | for (i=nodes[0]; iis_removed) || (!IsNode0(i) && GetMate1(i)->is_removed)) 653 | { 654 | code_assert(i->first == NULL); 655 | continue; 656 | } 657 | 658 | if (i->next || i==current_node) num1 ++; 659 | } 660 | for (r=0; r<3; r++) 661 | { 662 | i = (r == 2) ? 
current_node : queue_first[r]; 663 | if (i) 664 | for ( ; ; i=i->next) 665 | { 666 | code_assert((IsNode0(i) && !i->is_removed) || (!IsNode0(i) && !GetMate1(i)->is_removed)); 667 | num2 ++; 668 | if (i->next == i) 669 | { 670 | if (r<2) code_assert(i == queue_last[r]); 671 | else code_assert(i == current_node); 672 | break; 673 | } 674 | } 675 | } 676 | code_assert(num1 == num2); 677 | 678 | for (i=nodes[0]; iis_removed) || (!IsNode0(i) && GetMate1(i)->is_removed)) continue; 682 | 683 | // test whether all edges in seach trees are non-saturated 684 | if (i->parent == NULL) {} 685 | else if (i->parent == QPBO_MAXFLOW_ORPHAN) {} 686 | else if (i->parent == QPBO_MAXFLOW_TERMINAL) 687 | { 688 | if (!i->is_sink) code_assert(i->tr_cap > 0); 689 | else code_assert(i->tr_cap < 0); 690 | } 691 | else 692 | { 693 | if (!i->is_sink) code_assert (i->parent->sister->r_cap > 0); 694 | else code_assert (i->parent->r_cap > 0); 695 | } 696 | // test whether passive nodes in search trees have neighbors in 697 | // a different tree through non-saturated edges 698 | if (i->parent && !i->next) 699 | { 700 | if (!i->is_sink) 701 | { 702 | code_assert(i->tr_cap >= 0); 703 | for (a=i->first; a; a=a->next) 704 | { 705 | if (a->r_cap > 0) code_assert(a->head->parent && !a->head->is_sink); 706 | } 707 | } 708 | else 709 | { 710 | code_assert(i->tr_cap <= 0); 711 | for (a=i->first; a; a=a->next) 712 | { 713 | if (a->sister->r_cap > 0) code_assert(a->head->parent && a->head->is_sink); 714 | } 715 | } 716 | } 717 | // test marking invariants 718 | if (i->parent && i->parent!=QPBO_MAXFLOW_ORPHAN && i->parent!=QPBO_MAXFLOW_TERMINAL) 719 | { 720 | code_assert(i->TS <= i->parent->head->TS); 721 | if (i->TS == i->parent->head->TS) code_assert(i->DIST > i->parent->head->DIST); 722 | } 723 | } 724 | } 725 | 726 | #include "instances.inc" 727 | -------------------------------------------------------------------------------- /gco/graph.h: 
-------------------------------------------------------------------------------- 1 | /* graph.h */ 2 | /* 3 | This software library implements the maxflow algorithm 4 | described in 5 | 6 | "An Experimental Comparison of Min-Cut/Max-Flow Algorithms for Energy Minimization in Vision." 7 | Yuri Boykov and Vladimir Kolmogorov. 8 | In IEEE Transactions on Pattern Analysis and Machine Intelligence (PAMI), 9 | September 2004 10 | 11 | This algorithm was developed by Yuri Boykov and Vladimir Kolmogorov 12 | at Siemens Corporate Research. To make it available for public use, 13 | it was later reimplemented by Vladimir Kolmogorov based on open publications. 14 | 15 | If you use this software for research purposes, you should cite 16 | the aforementioned paper in any resulting publication. 17 | 18 | ---------------------------------------------------------------------- 19 | 20 | REUSING TREES: 21 | 22 | Starting with version 3.0, there is a also an option of reusing search 23 | trees from one maxflow computation to the next, as described in 24 | 25 | "Efficiently Solving Dynamic Markov Random Fields Using Graph Cuts." 26 | Pushmeet Kohli and Philip H.S. Torr 27 | International Conference on Computer Vision (ICCV), 2005 28 | 29 | If you use this option, you should cite 30 | the aforementioned paper in any resulting publication. 31 | */ 32 | 33 | 34 | 35 | /* 36 | For description, license, example usage see README.TXT. 37 | */ 38 | 39 | #ifndef __GRAPH_H__ 40 | #define __GRAPH_H__ 41 | 42 | #include 43 | #include "block.h" 44 | 45 | #include 46 | // NOTE: in UNIX you need to use -DNDEBUG preprocessor option to supress assert's!!! 
47 | 48 | 49 | 50 | // captype: type of edge capacities (excluding t-links) 51 | // tcaptype: type of t-links (edges between nodes and terminals) 52 | // flowtype: type of total flow 53 | // 54 | // Current instantiations are in instances.inc 55 | template class Graph 56 | { 57 | public: 58 | typedef enum 59 | { 60 | SOURCE = 0, 61 | SINK = 1 62 | } termtype; // terminals 63 | typedef int node_id; 64 | 65 | ///////////////////////////////////////////////////////////////////////// 66 | // BASIC INTERFACE FUNCTIONS // 67 | // (should be enough for most applications) // 68 | ///////////////////////////////////////////////////////////////////////// 69 | 70 | // Constructor. 71 | // The first argument gives an estimate of the maximum number of nodes that can be added 72 | // to the graph, and the second argument is an estimate of the maximum number of edges. 73 | // The last (optional) argument is the pointer to the function which will be called 74 | // if an error occurs; an error message is passed to this function. 75 | // If this argument is omitted, exit(1) will be called. 76 | // 77 | // IMPORTANT: It is possible to add more nodes to the graph than node_num_max 78 | // (and node_num_max can be zero). However, if the count is exceeded, then 79 | // the internal memory is reallocated (increased by 50%) which is expensive. 80 | // Also, temporarily the amount of allocated memory would be more than twice than needed. 81 | // Similarly for edges. 82 | // If you wish to avoid this overhead, you can download version 2.2, where nodes and edges are stored in blocks. 83 | Graph(int node_num_max, int edge_num_max, void (*err_function)(const char *) = NULL); 84 | 85 | // Destructor 86 | ~Graph(); 87 | 88 | // Adds node(s) to the graph. By default, one node is added (num=1); then first call returns 0, second call returns 1, and so on. 89 | // If num>1, then several nodes are added, and node_id of the first one is returned. 
90 | // IMPORTANT: see note about the constructor 91 | node_id add_node(int num = 1); 92 | 93 | // Adds a bidirectional edge between 'i' and 'j' with the weights 'cap' and 'rev_cap'. 94 | // IMPORTANT: see note about the constructor 95 | void add_edge(node_id i, node_id j, captype cap, captype rev_cap); 96 | 97 | // Adds new edges 'SOURCE->i' and 'i->SINK' with corresponding weights. 98 | // Can be called multiple times for each node. 99 | // Weights can be negative. 100 | // NOTE: the number of such edges is not counted in edge_num_max. 101 | // No internal memory is allocated by this call. 102 | void add_tweights(node_id i, tcaptype cap_source, tcaptype cap_sink); 103 | 104 | 105 | // Computes the maxflow. Can be called several times. 106 | // FOR DESCRIPTION OF reuse_trees, SEE mark_node(). 107 | // FOR DESCRIPTION OF changed_list, SEE remove_from_changed_list(). 108 | flowtype maxflow(bool reuse_trees = false, Block* changed_list = NULL); 109 | 110 | // After the maxflow is computed, this function returns to which 111 | // segment the node 'i' belongs (Graph::SOURCE or Graph::SINK). 112 | // 113 | // Occasionally there may be several minimum cuts. If a node can be assigned 114 | // to both the source and the sink, then default_segm is returned. 115 | termtype what_segment(node_id i, termtype default_segm = SOURCE); 116 | 117 | 118 | 119 | ////////////////////////////////////////////// 120 | // ADVANCED INTERFACE FUNCTIONS // 121 | // (provide access to the graph) // 122 | ////////////////////////////////////////////// 123 | 124 | private: 125 | struct node; 126 | struct arc; 127 | 128 | public: 129 | 130 | //////////////////////////// 131 | // 1. Reallocating graph. // 132 | //////////////////////////// 133 | 134 | // Removes all nodes and edges. 135 | // After that functions add_node() and add_edge() must be called again. 136 | // 137 | // Advantage compared to deleting Graph and allocating it again: 138 | // no calls to delete/new (which could be quite slow). 
139 | // 140 | // If the graph structure stays the same, then an alternative 141 | // is to go through all nodes/edges and set new residual capacities 142 | // (see functions below). 143 | void reset(); 144 | 145 | //////////////////////////////////////////////////////////////////////////////// 146 | // 2. Functions for getting pointers to arcs and for reading graph structure. // 147 | // NOTE: adding new arcs may invalidate these pointers (if reallocation // 148 | // happens). So it's best not to add arcs while reading graph structure. // 149 | //////////////////////////////////////////////////////////////////////////////// 150 | 151 | // The following two functions return arcs in the same order that they 152 | // were added to the graph. NOTE: for each call add_edge(i,j,cap,cap_rev) 153 | // the first arc returned will be i->j, and the second j->i. 154 | // If there are no more arcs, then the function can still be called, but 155 | // the returned arc_id is undetermined. 156 | typedef arc* arc_id; 157 | arc_id get_first_arc(); 158 | arc_id get_next_arc(arc_id a); 159 | 160 | // other functions for reading graph structure 161 | int get_node_num() { return node_num; } 162 | int get_arc_num() { return (int)(arc_last - arcs); } 163 | void get_arc_ends(arc_id a, node_id& i, node_id& j); // returns i,j to that a = i->j 164 | 165 | /////////////////////////////////////////////////// 166 | // 3. Functions for reading residual capacities. // 167 | /////////////////////////////////////////////////// 168 | 169 | // returns residual capacity of SOURCE->i minus residual capacity of i->SINK 170 | tcaptype get_trcap(node_id i); 171 | // returns residual capacity of arc a 172 | captype get_rcap(arc* a); 173 | 174 | ///////////////////////////////////////////////////////////////// 175 | // 4. Functions for setting residual capacities. // 176 | // NOTE: If these functions are used, the value of the flow // 177 | // returned by maxflow() will not be valid! 
// 178 | ///////////////////////////////////////////////////////////////// 179 | 180 | void set_trcap(node_id i, tcaptype trcap); 181 | void set_rcap(arc* a, captype rcap); 182 | 183 | //////////////////////////////////////////////////////////////////// 184 | // 5. Functions related to reusing trees & list of changed nodes. // 185 | //////////////////////////////////////////////////////////////////// 186 | 187 | // If flag reuse_trees is true while calling maxflow(), then search trees 188 | // are reused from previous maxflow computation. 189 | // In this case before calling maxflow() the user must 190 | // specify which parts of the graph have changed by calling mark_node(): 191 | // add_tweights(i),set_trcap(i) => call mark_node(i) 192 | // add_edge(i,j),set_rcap(a) => call mark_node(i); mark_node(j) 193 | // 194 | // This option makes sense only if a small part of the graph is changed. 195 | // The initialization procedure goes only through marked nodes then. 196 | // 197 | // mark_node(i) can either be called before or after graph modification. 198 | // Can be called more than once per node, but calls after the first one 199 | // do not have any effect. 200 | // 201 | // NOTE: 202 | // - This option cannot be used in the first call to maxflow(). 203 | // - It is not necessary to call mark_node() if the change is ``not essential'', 204 | // i.e. sign(trcap) is preserved for a node and zero/nonzero status is preserved for an arc. 205 | // - To check that you marked all necessary nodes, you can call maxflow(false) after calling maxflow(true). 206 | // If everything is correct, the two calls must return the same value of flow. (Useful for debugging). 207 | void mark_node(node_id i); 208 | 209 | // If changed_list is not NULL while calling maxflow(), then the algorithm 210 | // keeps a list of nodes which could potentially have changed their segmentation label. 
	// Nodes which are not in the list are guaranteed to keep their old segmentation label (SOURCE or SINK).
	// Example usage:
	//
	//		typedef Graph<int,int,int> G;
	//		G* g = new G(nodeNum, edgeNum);
	//		Block<G::node_id>* changed_list = new Block<G::node_id>(128);
	//
	//		... // add nodes and edges
	//
	//		g->maxflow(); // first call should be without arguments
	//		for (int iter=0; iter<10; iter++)
	//		{
	//			... // change graph, call mark_node() accordingly
	//
	//			g->maxflow(true, changed_list);
	//			G::node_id* ptr;
	//			for (ptr=changed_list->ScanFirst(); ptr; ptr=changed_list->ScanNext())
	//			{
	//				G::node_id i = *ptr; assert(i>=0 && i<nodeNum);
	//				g->remove_from_changed_list(i);
	//				// do something with node i...
	//				if (g->what_segment(i) == G::SOURCE) { ... }
	//			}
	//			changed_list->Reset();
	//		}
	//		delete changed_list;
	//
	// NOTE:
	//  - If changed_list option is used, then reuse_trees must be used as well.
	//  - In the example above, the user may omit calls g->remove_from_changed_list(i) and changed_list->Reset() in a given iteration.
	//    Then during the next call to maxflow(true, changed_list) new nodes will be added to changed_list.
	//  - If the next call to maxflow() does not use option reuse_trees, then calling remove_from_changed_list()
	//    is not necessary. ("changed_list->Reset()" or "delete changed_list" should still be called, though).
244 | void remove_from_changed_list(node_id i) 245 | { 246 | assert(i>=0 && i* g0); 251 | 252 | 253 | 254 | 255 | ///////////////////////////////////////////////////////////////////////// 256 | ///////////////////////////////////////////////////////////////////////// 257 | ///////////////////////////////////////////////////////////////////////// 258 | 259 | private: 260 | // internal variables and functions 261 | 262 | struct node 263 | { 264 | arc *first; // first outcoming arc 265 | 266 | arc *parent; // node's parent 267 | node *next; // pointer to the next active node 268 | // (or to itself if it is the last node in the list) 269 | int TS; // timestamp showing when DIST was computed 270 | int DIST; // distance to the terminal 271 | int is_sink : 1; // flag showing whether the node is in the source or in the sink tree (if parent!=NULL) 272 | int is_marked : 1; // set by mark_node() 273 | int is_in_changed_list : 1; // set by maxflow if 274 | 275 | tcaptype tr_cap; // if tr_cap > 0 then tr_cap is residual capacity of the arc SOURCE->node 276 | // otherwise -tr_cap is residual capacity of the arc node->SINK 277 | 278 | }; 279 | 280 | struct arc 281 | { 282 | node *head; // node the arc points to 283 | arc *next; // next arc with the same originating node 284 | arc *sister; // reverse arc 285 | 286 | captype r_cap; // residual capacity 287 | }; 288 | 289 | struct nodeptr 290 | { 291 | node *ptr; 292 | nodeptr *next; 293 | }; 294 | static const int NODEPTR_BLOCK_SIZE = 128; 295 | 296 | node *nodes, *node_last, *node_max; // node_last = nodes+node_num, node_max = nodes+node_num_max; 297 | arc *arcs, *arc_last, *arc_max; // arc_last = arcs+2*edge_num, arc_max = arcs+2*edge_num_max; 298 | 299 | int node_num; 300 | 301 | DBlock *nodeptr_block; 302 | 303 | void (*error_function)(const char *); // this function is called if a error occurs, 304 | // with a corresponding error message 305 | // (or exit(1) is called if it's NULL) 306 | 307 | flowtype flow; // total flow 308 
| 309 | // reusing trees & list of changed pixels 310 | int maxflow_iteration; // counter 311 | Block *changed_list; 312 | 313 | ///////////////////////////////////////////////////////////////////////// 314 | 315 | node *queue_first[2], *queue_last[2]; // list of active nodes 316 | nodeptr *orphan_first, *orphan_last; // list of pointers to orphans 317 | int TIME; // monotonically increasing global counter 318 | 319 | ///////////////////////////////////////////////////////////////////////// 320 | 321 | void reallocate_nodes(int num); // num is the number of new nodes 322 | void reallocate_arcs(); 323 | 324 | // functions for processing active list 325 | void set_active(node *i); 326 | node *next_active(); 327 | 328 | // functions for processing orphans list 329 | void set_orphan_front(node* i); // add to the beginning of the list 330 | void set_orphan_rear(node* i); // add to the end of the list 331 | 332 | void add_to_changed_list(node* i); 333 | 334 | void maxflow_init(); // called if reuse_trees == false 335 | void maxflow_reuse_trees_init(); // called if reuse_trees == true 336 | void augment(arc *middle_arc); 337 | void process_source_orphan(node *i); 338 | void process_sink_orphan(node *i); 339 | 340 | void test_consistency(node* current_node=NULL); // debug function 341 | }; 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | /////////////////////////////////////// 354 | // Implementation - inline functions // 355 | /////////////////////////////////////// 356 | 357 | 358 | 359 | template 360 | inline typename Graph::node_id Graph::add_node(int num) 361 | { 362 | assert(num > 0); 363 | 364 | if (node_last + num > node_max) reallocate_nodes(num); 365 | 366 | if (num == 1) 367 | { 368 | node_last -> first = NULL; 369 | node_last -> tr_cap = 0; 370 | node_last -> is_marked = 0; 371 | node_last -> is_in_changed_list = 0; 372 | 373 | node_last ++; 374 | return node_num ++; 375 | } 376 | else 377 | { 378 | memset(node_last, 0, num*sizeof(node)); 
		/* Bulk case (num > 1): the new nodes were zero-initialized by the
		   memset above (first = NULL, tr_cap = 0, all flags = 0), so only the
		   counters need to advance. Returns the id of the FIRST new node;
		   the remaining ids follow consecutively. */
		node_id i = node_num;
		node_num += num;
		node_last += num;
		return i;
	}
}

/* Adds cap_source (SOURCE->i) and cap_sink (i->SINK) to node i's terminal
   capacities. The common part min(cap_source, cap_sink) would be saturated
   by any maxflow, so it is credited to 'flow' immediately and only the
   difference is stored in tr_cap (positive = residual towards SOURCE,
   negative = residual towards SINK). */
template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::add_tweights(node_id i, tcaptype cap_source, tcaptype cap_sink)
{
	assert(i >= 0 && i < node_num);

	tcaptype delta = nodes[i].tr_cap; /* fold in the previously stored terminal capacity */
	if (delta > 0) cap_source += delta;
	else           cap_sink   -= delta;
	flow += (cap_source < cap_sink) ? cap_source : cap_sink;
	nodes[i].tr_cap = cap_source - cap_sink;
}

/* Creates the arc pair _i->_j (capacity cap) and _j->_i (capacity rev_cap).
   Each arc is pushed onto the head of its tail node's singly-linked
   adjacency list ('first'/'next'), and the two arcs reference each other
   through 'sister'. */
template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::add_edge(node_id _i, node_id _j, captype cap, captype rev_cap)
{
	assert(_i >= 0 && _i < node_num);
	assert(_j >= 0 && _j < node_num);
	assert(_i != _j);
	assert(cap >= 0);
	//assert(rev_cap >= 0);

	if (arc_last == arc_max) reallocate_arcs();

	arc *a = arc_last ++;
	arc *a_rev = arc_last ++;

	node* i = nodes + _i;
	node* j = nodes + _j;

	a -> sister = a_rev;
	a_rev -> sister = a;
	a -> next = i -> first;
	i -> first = a;
	a_rev -> next = j -> first;
	j -> first = a_rev;
	a -> head = j;
	a_rev -> head = i;
	a -> r_cap = cap;
	a_rev -> r_cap = rev_cap;
}

/* Arcs are stored contiguously in insertion order, so iteration is just
   pointer arithmetic over the 'arcs' array. */
template <typename captype, typename tcaptype, typename flowtype> 
	inline typename Graph<captype, tcaptype, flowtype>::arc* Graph<captype, tcaptype, flowtype>::get_first_arc()
{
	return arcs;
}

template <typename captype, typename tcaptype, typename flowtype> 
	inline typename Graph<captype, tcaptype, flowtype>::arc* Graph<captype, tcaptype, flowtype>::get_next_arc(arc* a)
{
	return a + 1;
}

/* Recovers the endpoint ids of arc a: the tail is reached through the
   sister arc's head, the head directly. */
template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::get_arc_ends(arc* a, node_id& i, node_id& j)
{
	assert(a >= arcs && a < arc_last);
	i = (node_id) (a->sister->head - nodes);
	j = (node_id) (a->head - nodes);
}

/* Returns residual capacity of SOURCE->i minus residual capacity of i->SINK. */
template <typename captype, typename tcaptype, typename flowtype> 
	inline tcaptype Graph<captype, tcaptype, flowtype>::get_trcap(node_id i)
{
	assert(i >= 0 && i < node_num);
	return nodes[i].tr_cap;
}

/* Returns residual capacity of arc a. */
template <typename captype, typename tcaptype, typename flowtype> 
	inline captype Graph<captype, tcaptype, flowtype>::get_rcap(arc* a)
{
	assert(a >= arcs && a < arc_last);
	return a->r_cap;
}

/* NOTE: after set_trcap/set_rcap the value returned by maxflow() is no
   longer the true flow value (see the declaration above). */
template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::set_trcap(node_id i, tcaptype trcap)
{
	assert(i >= 0 && i < node_num);
	nodes[i].tr_cap = trcap;
}

template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::set_rcap(arc* a, captype rcap)
{
	assert(a >= arcs && a < arc_last);
	a->r_cap = rcap;
}

/* After maxflow(): a node with a parent belongs to the tree indicated by
   is_sink; a free node (parent == NULL) can go to either side of the cut,
   so the caller-chosen default is returned. */
template <typename captype, typename tcaptype, typename flowtype> 
	inline typename Graph<captype, tcaptype, flowtype>::termtype Graph<captype, tcaptype, flowtype>::what_segment(node_id i, termtype default_segm)
{
	if (nodes[i].parent)
	{
		return (nodes[i].is_sink) ? SINK : SOURCE;
	}
	else
	{
		return default_segm;
	}
}

/* Flags node _i as changed for the next maxflow(reuse_trees=true) call and
   queues it into the rear active queue (same list mechanics as set_active:
   i->next != NULL means "queued", the last element points to itself). */
template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::mark_node(node_id _i)
{
	node* i = nodes + _i;
	if (!i->next)
	{
		/* it's not in the list yet */
		if (queue_last[1]) queue_last[1] -> next = i;
		else               queue_first[1] = i;
		queue_last[1] = i;
		i -> next = i;
	}
	i->is_marked = 1;
}


#endif
--------------------------------------------------------------------------------
/gco/maxflow.cpp:
--------------------------------------------------------------------------------
/* maxflow.cpp */


#include <stdio.h>
#include "graph.h"


/*
	special constants for node->parent
*/
#define TERMINAL ( (arc *) 1 )		/* to terminal */
#define ORPHAN   ( (arc *) 2 )		/* orphan */


#define INFINITE_D ((int)(((unsigned)-1)/2))		/* infinite distance to the terminal */

/***********************************************************************/

/*
	Functions for processing active list.
	i->next points to the next node in the list
	(or to i, if i is the last node in the list).
	i->next is NULL iff i is not in the list.

	There are two queues. Active nodes are added
	to the end of the second queue and read from
	the front of the first queue. If the first queue
	is empty, it is replaced by the second queue
	(and the second queue becomes empty).
*/


/* Appends node i to the rear active queue unless it is already queued. */
template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::set_active(node *i)
{
	if (!i->next)
	{
		/* it's not in the list yet */
		if (queue_last[1]) queue_last[1] -> next = i;
		else               queue_first[1] = i;
		queue_last[1] = i;
		i -> next = i; /* self-pointer marks the end of the list */
	}
}

/*
	Returns the next active node.
	If it is connected to the sink, it stays in the list,
	otherwise it is removed from the list
*/
template <typename captype, typename tcaptype, typename flowtype> 
	inline typename Graph<captype, tcaptype, flowtype>::node* Graph<captype, tcaptype, flowtype>::next_active()
{
	node *i;

	while ( 1 )
	{
		if (!(i=queue_first[0]))
		{
			/* front queue exhausted: promote the rear queue */
			queue_first[0] = i = queue_first[1];
			queue_last[0]  = queue_last[1];
			queue_first[1] = NULL;
			queue_last[1]  = NULL;
			if (!i) return NULL; /* both queues empty */
		}

		/* remove it from the active list */
		if (i->next == i) queue_first[0] = queue_last[0] = NULL;
		else              queue_first[0] = i -> next;
		i -> next = NULL;

		/* a node in the list is active iff it has a parent */
		if (i->parent) return i;
	}
}

/***********************************************************************/

/* Marks i as ORPHAN and prepends it to the adoption list (orphans created
   during augment() are processed in path order, front first). */
template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::set_orphan_front(node *i)
{
	nodeptr *np;
	i -> parent = ORPHAN;
	np = nodeptr_block -> New();
	np -> ptr = i;
	np -> next = orphan_first;
	orphan_first = np;
}

/* Marks i as ORPHAN and appends it to the end of the adoption list. */
template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::set_orphan_rear(node *i)
{
	nodeptr *np;
	i -> parent = ORPHAN;
	np = nodeptr_block -> New();
	np -> ptr = i;
	if (orphan_last) orphan_last -> next = np;
	else             orphan_first        = np;
	orphan_last = np;
	np -> next = NULL;
}

/***********************************************************************/

/* Records node i in the user-supplied changed_list (if any) exactly once,
   using the is_in_changed_list flag to deduplicate. */
template <typename captype, typename tcaptype, typename flowtype> 
	inline void Graph<captype, tcaptype, flowtype>::add_to_changed_list(node *i)
{
	if (changed_list && !i->is_in_changed_list)
	{
		node_id* ptr = changed_list->New();
		*ptr = (node_id)(i - nodes);
		i->is_in_changed_list = true;
	}
}

/***********************************************************************/

/* Builds the initial search trees from scratch: every node with positive
   tr_cap becomes an active root of the source tree, negative tr_cap an
   active root of the sink tree, zero tr_cap stays free (parent == NULL). */
template <typename captype, typename tcaptype, typename flowtype> 
	void Graph<captype, tcaptype, flowtype>::maxflow_init()
{
	node *i;

	queue_first[0] = queue_last[0] = NULL;
	queue_first[1] = queue_last[1] = NULL;
	orphan_first = NULL;

	TIME = 0;

	for (i=nodes; i<node_last; i++)
	{
		i -> next = NULL;
		i -> is_marked = 0;
		i -> is_in_changed_list = 0;
		i -> TS = TIME;
		if (i->tr_cap > 0)
		{
			/* i is connected to the source */
			i -> is_sink = 0;
			i -> parent = TERMINAL;
			set_active(i);
			i -> DIST = 1;
		}
		else if (i->tr_cap < 0)
		{
			/* i is connected to the sink */
			i -> is_sink = 1;
			i -> parent = TERMINAL;
			set_active(i);
			i -> DIST = 1;
		}
		else
		{
			i -> parent = NULL;
		}
	}
}

/* Re-initializes the search trees after the user changed part of the graph,
   visiting only the nodes queued by mark_node() (they sit in queue_first[1]).
   Each marked node is re-rooted at its terminal according to the sign of
   tr_cap; neighbors whose parent pointer became invalid are orphaned and
   re-adopted below. */
template <typename captype, typename tcaptype, typename flowtype> 
	void Graph<captype, tcaptype, flowtype>::maxflow_reuse_trees_init()
{
	node* i;
	node* j;
	node* queue = queue_first[1];
	arc* a;
	nodeptr* np;

	queue_first[0] = queue_last[0] = NULL;
	queue_first[1] = queue_last[1] = NULL;
	orphan_first = orphan_last = NULL;

	TIME ++;

	while ((i=queue))
	{
		queue = i->next;
		if (queue == i) queue = NULL; /* self-pointer = end of list */
		i->next = NULL;
		i->is_marked = 0;
		set_active(i);

		if (i->tr_cap == 0)
		{
			/* node becomes free; if it was in a tree it must be re-adopted */
			if (i->parent) set_orphan_rear(i);
			continue;
		}

		if (i->tr_cap > 0)
		{
			/* node must be (or become) a source-tree root */
			if (!i->parent || i->is_sink)
			{
				i->is_sink = 0;
				for (a=i->first; a; a=a->next)
				{
					j = a->head;
					if (!j->is_marked)
					{
						/* j's parent arc now points into the wrong tree */
						if (j->parent == a->sister) set_orphan_rear(j);
						if (j->parent && j->is_sink && a->r_cap > 0) set_active(j);
					}
				}
				add_to_changed_list(i);
			}
		}
		else
		{
			/* node must be (or become) a sink-tree root */
			if (!i->parent || !i->is_sink)
			{
				i->is_sink = 1;
				for (a=i->first; a; a=a->next)
				{
					j = a->head;
					if (!j->is_marked)
					{
						if (j->parent == a->sister) set_orphan_rear(j);
						if (j->parent && !j->is_sink && a->sister->r_cap > 0) set_active(j);
					}
				}
				add_to_changed_list(i);
			}
		}
		i->parent = TERMINAL;
		i -> TS = TIME;
		i -> DIST = 1;
	}

	//test_consistency();

	/* adoption */
	while ((np=orphan_first))
	{
		orphan_first = np -> next;
		i = np -> ptr;
		nodeptr_block -> Delete(np);
		if (!orphan_first) orphan_last = NULL;
		if (i->is_sink) process_sink_orphan(i);
		else            process_source_orphan(i);
	}
	/* adoption end */

	//test_consistency();
}

/* Pushes the bottleneck capacity of the augmenting path
   source tree -> middle_arc -> sink tree. Phase 1 walks both tree halves
   to find the bottleneck; phase 2 updates residual capacities, turning
   every node whose parent arc saturates into an orphan. */
template <typename captype, typename tcaptype, typename flowtype> 
	void Graph<captype, tcaptype, flowtype>::augment(arc *middle_arc)
{
	node *i;
	arc *a;
	tcaptype bottleneck;


	/* 1. Finding bottleneck capacity */
	/* 1a - the source tree */
	bottleneck = middle_arc -> r_cap;
	for (i=middle_arc->sister->head; ; i=a->head)
	{
		a = i -> parent;
		if (a == TERMINAL) break;
		/* residual capacity towards the source is on the sister arc */
		if (bottleneck > a->sister->r_cap) bottleneck = a -> sister -> r_cap;
	}
	if (bottleneck > i->tr_cap) bottleneck = i -> tr_cap;
	/* 1b - the sink tree */
	for (i=middle_arc->head; ; i=a->head)
	{
		a = i -> parent;
		if (a == TERMINAL) break;
		if (bottleneck > a->r_cap) bottleneck = a -> r_cap;
	}
	if (bottleneck > - i->tr_cap) bottleneck = - i -> tr_cap;


	/* 2. Augmenting */
	/* 2a - the source tree */
	middle_arc -> sister -> r_cap += bottleneck;
	middle_arc -> r_cap -= bottleneck;
	for (i=middle_arc->sister->head; ; i=a->head)
	{
		a = i -> parent;
		if (a == TERMINAL) break;
		a -> r_cap += bottleneck;
		a -> sister -> r_cap -= bottleneck;
		if (!a->sister->r_cap)
		{
			set_orphan_front(i); // add i to the beginning of the adoption list
		}
	}
	i -> tr_cap -= bottleneck;
	if (!i->tr_cap)
	{
		set_orphan_front(i); // add i to the beginning of the adoption list
	}
	/* 2b - the sink tree */
	for (i=middle_arc->head; ; i=a->head)
	{
		a = i -> parent;
		if (a == TERMINAL) break;
		a -> sister -> r_cap += bottleneck;
		a -> r_cap -= bottleneck;
		if (!a->r_cap)
		{
			set_orphan_front(i); // add i to the beginning of the adoption list
		}
	}
	i -> tr_cap += bottleneck;
	if (!i->tr_cap)
	{
		set_orphan_front(i); // add i to the beginning of the adoption list
	}


	flow += bottleneck;
}

/***********************************************************************/

/* Tries to re-attach source-tree orphan i to a valid parent: the cheapest
   (shortest distance to the source) neighbor reachable through a
   non-saturated arc whose own origin traces back to the source. TS/DIST
   memoize distances within the current TIME stamp. If no parent exists,
   i becomes free and its neighbors are activated/orphaned as needed. */
template <typename captype, typename tcaptype, typename flowtype> 
	void Graph<captype, tcaptype, flowtype>::process_source_orphan(node *i)
{
	node *j;
	arc *a0, *a0_min = NULL, *a;
	int d, d_min = INFINITE_D;

	/* trying to find a new parent */
	for (a0=i->first; a0; a0=a0->next)
	if (a0->sister->r_cap)
	{
		j = a0 -> head;
		if (!j->is_sink && (a=j->parent))
		{
			/* checking the origin of j */
			d = 0;
			while ( 1 )
			{
				if (j->TS == TIME)
				{
					d += j -> DIST; /* distance already computed this round */
					break;
				}
				a = j -> parent;
				d ++;
				if (a==TERMINAL)
				{
					j -> TS = TIME;
					j -> DIST = 1;
					break;
				}
				if (a==ORPHAN) { d = INFINITE_D; break; } /* path dead-ends in an orphan */
				j = a -> head;
			}
			if (d<INFINITE_D) /* j originates from the source - done */
			{
				if (d<d_min)
				{
					a0_min = a0;
					d_min = d;
				}
				/* set marks along the path */
				for (j=a0->head; j->TS!=TIME; j=j->parent->head)
				{
					j -> TS = TIME;
					j -> DIST = d --;
				}
			}
		}
	}

	/* NOTE: intentional assignment (installs the new parent), not '==' */
	if (i->parent = a0_min)
	{
		i -> TS = TIME;
		i -> DIST = d_min + 1;
	}
	else
	{
		/* no parent is found */
		add_to_changed_list(i);

		/* process neighbors */
		for (a0=i->first; a0; a0=a0->next)
		{
			j = a0 -> head;
			if (!j->is_sink && (a=j->parent))
			{
				if (a0->sister->r_cap) set_active(j);
				if (a!=TERMINAL && a!=ORPHAN && a->head==i)
				{
					set_orphan_rear(j); // add j to the end of the adoption list
				}
			}
		}
	}
}

/* Mirror image of process_source_orphan for the sink tree: parent arcs run
   towards the sink, so residual capacity is tested on a0->r_cap directly. */
template <typename captype, typename tcaptype, typename flowtype> 
	void Graph<captype, tcaptype, flowtype>::process_sink_orphan(node *i)
{
	node *j;
	arc *a0, *a0_min = NULL, *a;
	int d, d_min = INFINITE_D;

	/* trying to find a new parent */
	for (a0=i->first; a0; a0=a0->next)
	if (a0->r_cap)
	{
		j = a0 -> head;
		if (j->is_sink && (a=j->parent))
		{
			/* checking the origin of j */
			d = 0;
			while ( 1 )
			{
				if (j->TS == TIME)
				{
					d += j -> DIST;
					break;
				}
				a = j -> parent;
				d ++;
				if (a==TERMINAL)
				{
					j -> TS = TIME;
					j -> DIST = 1;
					break;
				}
				if (a==ORPHAN) { d = INFINITE_D; break; }
				j = a -> head;
			}
			if (d<INFINITE_D) /* j originates from the sink - done */
			{
				if (d<d_min)
				{
					a0_min = a0;
					d_min = d;
				}
				/* set marks along the path */
				for (j=a0->head; j->TS!=TIME; j=j->parent->head)
				{
					j -> TS = TIME;
					j -> DIST = d --;
				}
			}
		}
	}

	/* NOTE: intentional assignment (installs the new parent), not '==' */
	if (i->parent = a0_min)
	{
		i -> TS = TIME;
		i -> DIST = d_min + 1;
	}
	else
	{
		/* no parent is found */
		add_to_changed_list(i);

		/* process neighbors */
		for (a0=i->first; a0; a0=a0->next)
		{
			j = a0 -> head;
			if (j->is_sink && (a=j->parent))
			{
				if (a0->r_cap) set_active(j);
				if (a!=TERMINAL && a!=ORPHAN && a->head==i)
				{
					set_orphan_rear(j); // add j to the end of the adoption list
				}
			}
		}
	}
}

/***********************************************************************/

/* Main BK maxflow loop: grow the source/sink trees from active nodes until
   they touch through a non-saturated arc, augment along the found path,
   then re-adopt the orphans the augmentation created. With reuse_trees the
   trees from the previous call are patched instead of rebuilt (first call
   must use reuse_trees == false). Returns the total flow. */
template <typename captype, typename tcaptype, typename flowtype> 
	flowtype Graph<captype, tcaptype, flowtype>::maxflow(bool reuse_trees, Block<node_id>* _changed_list)
{
	node *i, *j, *current_node = NULL;
	arc *a;
	nodeptr *np, *np_next;

	if (!nodeptr_block)
	{
		nodeptr_block = new DBlock<nodeptr>(NODEPTR_BLOCK_SIZE, error_function);
	}

	changed_list = _changed_list;
	if (maxflow_iteration == 0 && reuse_trees) { if (error_function) (*error_function)("reuse_trees cannot be used in the first call to maxflow()!"); exit(1); }
	if (changed_list && !reuse_trees) { if (error_function) (*error_function)("changed_list cannot be used without reuse_trees!"); exit(1); }

	if (reuse_trees) maxflow_reuse_trees_init();
	else             maxflow_init();

	// main loop
	while ( 1 )
	{
		// test_consistency(current_node);

		/* prefer to keep growing from the node used last iteration */
		if ((i=current_node))
		{
			i -> next = NULL; /* remove active flag */
			if (!i->parent) i = NULL; /* it became free meanwhile */
		}
		if (!i)
		{
			if (!(i = next_active())) break; /* no active nodes left - done */
		}

		/* growth */
		if (!i->is_sink)
		{
			/* grow source tree */
			for (a=i->first; a; a=a->next)
			if (a->r_cap)
			{
				j = a -> head;
				if (!j->parent)
				{
					/* claim free node j for the source tree */
					j -> is_sink = 0;
					j -> parent = a -> sister;
					j -> TS = i -> TS;
					j -> DIST = i -> DIST + 1;
					set_active(j);
					add_to_changed_list(j);
				}
				else if (j->is_sink) break; /* trees touched: a is the middle arc */
				else if (j->TS <= i->TS &&
				         j->DIST > i->DIST)
				{
					/* heuristic - trying to make the distance from j to the source shorter */
					j -> parent = a -> sister;
					j -> TS = i -> TS;
					j -> DIST = i -> DIST + 1;
				}
			}
		}
		else
		{
			/* grow sink tree */
			for (a=i->first; a; a=a->next)
			if (a->sister->r_cap)
			{
				j = a -> head;
				if (!j->parent)
				{
					j -> is_sink = 1;
					j -> parent = a -> sister;
					j -> TS = i -> TS;
					j -> DIST = i -> DIST + 1;
					set_active(j);
					add_to_changed_list(j);
				}
				/* trees touched: orient the middle arc source->sink */
				else if (!j->is_sink) { a = a -> sister; break; }
				else if (j->TS <= i->TS &&
				         j->DIST > i->DIST)
				{
					/* heuristic - trying to make the distance from j to the sink shorter */
					j -> parent = a -> sister;
					j -> TS = i -> TS;
					j -> DIST = i -> DIST + 1;
				}
			}
		}

		TIME ++;

		if (a) /* an augmenting path was found through middle arc a */
		{
			i -> next = i; /* set active flag */
			current_node = i;

			/* augmentation */
			augment(a);
			/* augmentation end */

			/* adoption */
			while ((np=orphan_first))
			{
				/* orphans created while adopting (rear of list) are processed
				   in a later pass; detach the current front entry first */
				np_next = np -> next;
				np -> next = NULL;

				while ((np=orphan_first))
				{
					orphan_first = np -> next;
					i = np -> ptr;
					nodeptr_block -> Delete(np);
					if (!orphan_first) orphan_last = NULL;
					if (i->is_sink) process_sink_orphan(i);
					else            process_source_orphan(i);
				}

				orphan_first = np_next;
			}
			/* adoption end */
		}
		else current_node = NULL; /* no path found from i */
	}
	// test_consistency();

	/* periodically release the nodeptr pool to bound memory when reusing trees */
	if (!reuse_trees || (maxflow_iteration % 64) == 0)
	{
		delete nodeptr_block;
		nodeptr_block = NULL;
	}

	maxflow_iteration ++;
	return flow;
}

/***********************************************************************/


/* Debug-only invariant checker: verifies the active queues, tree-arc
   saturation, passive-node neighborhoods, and TS/DIST monotonicity. */
template <typename captype, typename tcaptype, typename flowtype> 
	void Graph<captype, tcaptype, flowtype>::test_consistency(node* current_node)
{
	node *i;
	arc *a;
	int r;
	int num1 = 0, num2 = 0;

	// test whether all nodes i with i->next!=NULL are indeed in the queue
	for (i=nodes; i<node_last; i++)
	{
		if (i->next || i==current_node) num1 ++;
	}
	for (r=0; r<3; r++)
	{
		i = (r == 2) ? current_node : queue_first[r];
		if (i)
		for ( ; ; i=i->next)
		{
			num2 ++;
			if (i->next == i)
			{
				if (r<2) assert(i == queue_last[r]);
				else     assert(i == current_node);
				break;
			}
		}
	}
	assert(num1 == num2);

	for (i=nodes; i<node_last; i++)
	{
		// test whether all edges in search trees are non-saturated
		if (i->parent == NULL) {}
		else if (i->parent == ORPHAN) {}
		else if (i->parent == TERMINAL)
		{
			if (!i->is_sink) assert(i->tr_cap > 0);
			else             assert(i->tr_cap < 0);
		}
		else
		{
			if (!i->is_sink) assert (i->parent->sister->r_cap > 0);
			else             assert (i->parent->r_cap > 0);
		}
		// test whether passive nodes in search trees have neighbors in
		// a different tree through non-saturated edges
		if (i->parent && !i->next)
		{
			if (!i->is_sink)
			{
				assert(i->tr_cap >= 0);
				for (a=i->first; a; a=a->next)
				{
					if (a->r_cap > 0) assert(a->head->parent && !a->head->is_sink);
				}
			}
			else
			{
				assert(i->tr_cap <= 0);
				for (a=i->first; a; a=a->next)
				{
					if (a->sister->r_cap > 0) assert(a->head->parent && a->head->is_sink);
				}
			}
		}
		// test marking invariants
		if (i->parent && i->parent!=ORPHAN && i->parent!=TERMINAL)
		{
			assert(i->TS <= i->parent->head->TS);
			if (i->TS == i->parent->head->TS) assert(i->DIST > i->parent->head->DIST);
		}
	}
}

/* Deep-copies g0 into this graph (including the in-progress search-tree
   state). All internal pointers (first/parent/next/head/sister and the
   active queues) are rebased by byte offset from g0's arrays to ours;
   TERMINAL/ORPHAN sentinels are preserved as-is. Uses malloc/free to match
   the allocation scheme of the rest of the class. */
template <typename captype, typename tcaptype, typename flowtype> 
	void Graph<captype, tcaptype, flowtype>::Copy(Graph<captype, tcaptype, flowtype>* g0)
{
	node* i;
	arc* a;

	reset();

	/* grow our arrays if the source graph is larger */
	if (node_max < nodes + g0->node_num)
	{
		free(nodes);
		nodes = node_last = (node*) malloc(g0->node_num*sizeof(node));
		node_max = nodes + g0->node_num;
	}
	if (arc_max < arcs + (g0->arc_last - g0->arcs))
	{
		free(arcs);
		arcs = arc_last = (arc*) malloc((g0->arc_last - g0->arcs)*sizeof(arc));
		arc_max = arcs + (g0->arc_last - g0->arcs);
	}

	node_num = g0->node_num;
	node_last = nodes + node_num;
	memcpy(nodes, g0->nodes, node_num*sizeof(node));
	for (i=nodes; i<node_last; i++)
	{
		if (i->first) i->first = (arc*)((char*)arcs + (((char*)i->first) - ((char*)g0->arcs)));
		if (i->parent && i->parent!=TERMINAL && i->parent!=ORPHAN) i->parent = (arc*)((char*)arcs + (((char*)i->parent) - ((char*)g0->arcs)));
		if (i->next) i->next = (node*)((char*)nodes + (((char*)i->next) - ((char*)g0->nodes)));
	}

	arc_last = arcs + (g0->arc_last - g0->arcs);
	memcpy(arcs, g0->arcs, (g0->arc_last - g0->arcs)*sizeof(arc));
	for (a=arcs; a<arc_last; a++)
	{
		a->head = (node*)((char*)nodes + (((char*)a->head) - ((char*)g0->nodes)));
		if (a->next) a->next = (arc*)((char*)arcs + (((char*)a->next) - ((char*)g0->arcs)));
		a->sister = (arc*)((char*)arcs + (((char*)a->sister) - ((char*)g0->arcs)));
	}

	error_function = g0->error_function;
	flow = g0->flow;
	maxflow_iteration = g0->maxflow_iteration;

	queue_first[0] = (g0->queue_first[0]==NULL) ? NULL : (node*)((char*)nodes + (((char*)g0->queue_first[0]) - ((char*)g0->nodes)));
	queue_first[1] = (g0->queue_first[1]==NULL) ? NULL : (node*)((char*)nodes + (((char*)g0->queue_first[1]) - ((char*)g0->nodes)));
	queue_last[0] = (g0->queue_last[0]==NULL) ? NULL : (node*)((char*)nodes + (((char*)g0->queue_last[0]) - ((char*)g0->nodes)));
	queue_last[1] = (g0->queue_last[1]==NULL) ? NULL : (node*)((char*)nodes + (((char*)g0->queue_last[1]) - ((char*)g0->nodes)));
	TIME = g0->TIME;
}
--------------------------------------------------------------------------------