├── README.md ├── preliminary contest ├── base.cpp ├── main43.cpp ├── main63.cpp └── test.sh └── warmup ├── .vscode ├── launch.json ├── settings.json └── tasks.json └── hello.cpp /README.md: -------------------------------------------------------------------------------- 1 | # CodeCraft2020 2 | 3 | 华为2020软件精英挑战赛复赛代码开源(京津东北赛区1504b,复赛A榜 rank1, B榜无有效成绩) 4 | 5 | ### 代码清单 6 | - main63.cpp : 6+3 版本代码,线下`4.8x`, 线上`2.6x`。(线下为`1963w`数据,下同) 7 | - main43.cpp : 4+3 版本代码,线下`5.5x`, 线上`2.3x`。 8 | - test.sh : 批量测试脚本,可用于测试路径、答案等是否正确。用法 : 9 | - 添加执行权限:`chmod u+x test.sh` 10 | - 测试:`./test.sh main.cpp` 11 | - (可能提示没有权限创建文件,以 sudo 权限运行即可) 12 | - test_data_xxxx.txt : 数据集,其中xxxx为环的个数 13 | - answer_xxxx.txt : 与数据集对应的答案 14 | - log.txt : 测试日志文件,可用于排查错误 15 | 16 | ### 基本思路 17 | 18 | **1.建图**(耗时`130ms`) 19 | 20 | 由于vector数组速度慢、静态数组受节点出入度的限制影响很大、动态数组管理不方便且在内存中排布不够紧凑等原因,采用**前向星**作为图的数据结构具有很大的优势。前向星中的边在内存中紧密排布,所需空间为`n * sizeof(Edge)`。其中,n 为边的数量,也就是转账记录数。 21 | 22 | 其他`trick`: 23 | 24 | (1)不使用`unordered_map`来映射; 25 | 26 | (2)减少哈希表的访问,如标记该节点是否在哈希表中,以减少`if(!Map.count(key))`的消耗。 27 | 28 | **2.找环**(`6+3`耗时`4.0x`,`4+3`耗时`4.8x` ) 29 | 30 | 很多同学都是`6+3`或`4+3`两个思路,我的`6+3`在线下的`1963w`数据集上表现优于`4+3`,但在线上却稍逊于`4+3`,看起来似乎`6+3`更适合于线下的随机图,而`4+3`更适合于线上的随机图+菊花图(+完全图)。(也有可能是我最后两天转的`4+3`,还没有调教得很好) 31 | 32 | 由于大部分同学都是用的这两个思路,因此方法层面没什么好说的,这里有几个`trick`可能有一定的优化作用: 33 | 34 | (1)尽量减少不必要的内存访问,如在`dfs`的过程中访问`ID`数组来获取原始id。 35 | 36 | (2)能用`uint8/bool`数组绝不用`u32/int`数组,这可能是因为`u32/int`数组占用空间大,`cache`中存在的几率更小,从而导致更多的`cache miss`。 37 | 38 | (3)`dfs`过程中的`if-continue`的判断次序很重要,对于更有可能导致`continue`的分支应该放在更前面。 39 | 40 | (4)避免不必要的判断,例如`if(DIS[curr] < 3)`已经包含`if(curr != start)`,因此后者无须再判断一次。 41 | 42 | **3.转换输出**(耗时`580+ms`) 43 | 44 | 我的转换输出并不快,主要思路是先将多线程找到的环进行合并,然后进行多线程查表转换,最后以**mmap+多线程memcpy**的方式写入文件。 45 | 46 | 这里采取的建表方式为:以映射后的id为下标,将映射前的id转换为字符串后存入对应的位置。例如,`5439817`映射后为`42`,则`conv[42] == "5439817"`。这样可以使得转换过程中无须进行反映射,从而减少内存访问。 47 | 48 | 上述建表方式耗时`<1ms`,空间大小为`11 * n`,这里的`n`为有效节点数。 49 | 50 | ### 参赛感想 51 | 52 | 
本次比赛从热身赛到复赛结束耗时2个月有余,收获不少但遗憾更多。遗憾之处在于没能拿到更好的名次,更在于不能到深圳和各位共同进步的大佬学习。 53 | -------------------------------------------------------------------------------- /preliminary contest/base.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | 2020/3/31 by WavenZ 3 | */ 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | 17 | using namespace std; 18 | 19 | const char* data_file = "/data/test_data.txt"; 20 | const char* result_file = "/projects/student/result.txt"; 21 | 22 | struct Node{ 23 | uint32_t from; 24 | vector to; 25 | uint32_t visit; 26 | uint16_t degree; 27 | Node(uint32_t f) : from(f), visit(0), degree(0) {} 28 | }; 29 | 30 | vector trans; // 交易 31 | unordered_map Map; // 账号映射为连续值 32 | vector> Res; // 结果 33 | 34 | void sort_trans(vector& trans){ // 将交易按字典序排序 35 | sort(trans.begin(), trans.end(), [](const Node* a, const Node* b){ 36 | return a->from < b->from; 37 | }); 38 | for(int i = 0; i < trans.size(); ++i){ 39 | sort(trans[i]->to.begin(), trans[i]->to.end(), [](const Node* a, const Node* b){ 40 | return a->from < b->from; 41 | }); 42 | } 43 | } 44 | 45 | void read(){ // 读取数据  46 | struct stat statue; 47 | int fd = open(data_file, O_RDONLY); 48 | fstat(fd, &statue); 49 | int size = statue.st_size; 50 | char* start = (char*)mmap(NULL, statue.st_size, PROT_READ, MAP_PRIVATE, fd, 0); 51 | char* curr = start; 52 | uint32_t from, to, amount, number = 0;; 53 | uint8_t state = 0; 54 | char ch; 55 | while(1){ 56 | ch = *curr; 57 | if(ch == ','){ 58 | state ? 
to = number : from = number; 59 | state = 1 - state; 60 | number = 0; 61 | } 62 | else if(ch == '\n'){ 63 | amount = number; 64 | if(!Map.count(from)) Map[from] = new Node(from); 65 | if(!Map.count(to)) Map[to] = new Node(to); 66 | Map[from]->to.push_back(Map[to]); 67 | Map[to]->degree ++; 68 | if(curr - start + 1 >= size) break; 69 | state = number = 0; 70 | } 71 | else{ 72 | number = number * 10 + ch - '0'; 73 | } 74 | curr++; 75 | } 76 | for(const auto& m : Map){ 77 | trans.push_back(m.second); 78 | } 79 | sort_trans(trans); 80 | } 81 | 82 | void solve(){ 83 | // 7层dfs算法 84 | vector res; 85 | int cnt = 0; 86 | for(int i = 0; i < trans.size(); i ++){ 87 | Node* node = trans[i]; 88 | if(node->degree > 0){ 89 | node->visit = 1; 90 | res.push_back(node->from); 91 | for(Node* node1 : node->to){ 92 | if(node1->from > node->from){ 93 | node1->visit = 1; 94 | res.push_back(node1->from); 95 | for(Node* node2 : node1->to){ 96 | if(node2->from > node->from){ 97 | node2->visit = 1; 98 | res.push_back(node2->from); 99 | for(Node* node3 : node2->to){ 100 | if(node3->from == node->from){ 101 | Res.push_back(res); 102 | } 103 | else if(node3->from < node->from || node3->visit == 1) continue; 104 | else{ 105 | node3->visit = 1; 106 | res.push_back(node3->from); 107 | for(Node* node4 : node3->to){ 108 | if(node4->from == node->from){ 109 | Res.push_back(res); 110 | } 111 | else if(node4->from < node->from || node4->visit == 1) continue; 112 | else{ 113 | node4->visit = 1; 114 | res.push_back(node4->from); 115 | for(Node* node5 : node4->to){ 116 | if(node5->from == node->from){ 117 | Res.push_back(res); 118 | } 119 | else if(node5->from < node->from || node5->visit == 1) continue; 120 | else{ 121 | node5->visit = 1; 122 | res.push_back(node5->from); 123 | for(Node* node6 : node5->to){ 124 | if(node6->from == node->from){ 125 | Res.push_back(res); 126 | } 127 | else if(node6->from < node->from || node6->visit == 1) continue; 128 | else{ 129 | node6->visit = 1; 130 | 
res.push_back(node6->from); 131 | for(Node* node7 : node6->to){ 132 | if(node7->from == node->from){ 133 | Res.push_back(res); 134 | } 135 | } 136 | node6->visit = 0; 137 | res.pop_back(); 138 | } 139 | } 140 | node5->visit = 0; 141 | res.pop_back(); 142 | } 143 | } 144 | node4->visit = 0; 145 | res.pop_back(); 146 | } 147 | } 148 | node3->visit = 0; 149 | res.pop_back(); 150 | } 151 | } 152 | node2->visit = 0; 153 | res.pop_back(); 154 | } 155 | } 156 | node1->visit = 0; 157 | res.pop_back(); 158 | } 159 | } 160 | node->visit = 0; 161 | res.pop_back(); 162 | } 163 | } 164 | } 165 | 166 | void sort_res(){ 167 | // 按长度排序 168 | stable_sort(Res.begin(), Res.end(), 169 | [](const vector& a, const vector& b){ 170 | return a.size() < b.size(); 171 | }); 172 | } 173 | 174 | void save_res(){ 175 | // 保存结果 176 | FILE* fp = fopen(result_file, "w"); 177 | fprintf(fp, "%d\n", (int)Res.size()); 178 | for(int i = 0; i < Res.size(); i ++){ 179 | for(int j = 0; j < Res[i].size() - 1; ++j){ 180 | fprintf(fp, "%d,", Res[i][j]); 181 | } 182 | fprintf(fp, "%d\n", Res[i].back()); 183 | } 184 | } 185 | 186 | int main(int argc, char* argv[]){ 187 | read(); 188 | solve(); 189 | sort_res(); 190 | save_res(); 191 | return 0; 192 | } 193 | -------------------------------------------------------------------------------- /preliminary contest/main43.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include 19 | 20 | #ifdef DEBUG 21 | #include "timer.h" 22 | #endif 23 | 24 | using namespace std; 25 | 26 | const int nthread = 4; // 线程数 27 | 28 | const int maxs = 20000000; // 最大环数 29 | const int maxn = 2000000; // 最大节点数 30 | 31 | // 数据读取buffer 32 | uint32_t loadbuf[nthread][maxn][3]; 33 | uint32_t loadnum[nthread]; 34 | 35 | // 图数据结构(前向星) 36 | struct Edge{ 
37 | uint32_t id; 38 | uint32_t money; 39 | }; 40 | 41 | Edge From[maxn]; // 正图 42 | Edge To[maxn]; // 反图 43 | uint32_t phead[maxn]; // 各id起始位置 44 | uint32_t plen[maxn]; // 各id转账数 45 | uint32_t qhead[maxn]; 46 | uint32_t qlen[maxn]; 47 | 48 | // 简易的固定长度vector,效率和数组相近 49 | class Vector{ 50 | public: 51 | Vector(){ _data = new uint32_t[maxn]; } 52 | void push_back(uint32_t t){ 53 | _data[_size++] = t; 54 | } 55 | int size(){ return _size; } 56 | void clear(){ _size = 0; } 57 | uint32_t* begin(){ return _data; } 58 | uint32_t* end(){ return _data + _size; } 59 | uint32_t& operator[](int index){ return _data[index]; } 60 | private: 61 | uint32_t* _data; 62 | int _size; 63 | }; 64 | Vector ID; // 真实ID表 65 | 66 | // 简易的固定桶数hash_table,采用开放寻址法解决冲突 67 | class Hash_map{ 68 | public: 69 | Hash_map(){ 70 | map.resize(m); 71 | } 72 | void insert(uint32_t k){ 73 | // 如果不存在则插入,否则直接返回,作用等同: 74 | // if(!Map.count(key)) Map.insert(key); 75 | int i = 0; 76 | int hash_val = 0; 77 | while (i < m) { 78 | hash_val = hash(k, i); 79 | if (map[hash_val].val == -1){ 80 | map[hash_val].key = k; 81 | map[hash_val].val = ID.size(); 82 | map[hash_val].len = 1; 83 | ID.push_back(k); 84 | hashes.push_back(hash_val); 85 | cnt++; 86 | break; 87 | } 88 | if(map[hash_val].key == k){ 89 | map[hash_val].len++; 90 | break; 91 | }else i++; 92 | } 93 | } 94 | int search(uint32_t k){ 95 | // 搜索 96 | int i = 0, hash_val = 0; 97 | while (i < m) { 98 | hash_val = hash(k, i); 99 | if(map[hash_val].val == -1) break; 100 | if(map[hash_val].key == k) return map[hash_val].val; 101 | else i++; 102 | } 103 | return -1; 104 | } 105 | void sort_hash(){ 106 | // 将hash值排序,确保id映射前后相对大小不变化 107 | sort(hashes.begin(), hashes.end(), [&](int m, int m1){ 108 | return map[m].key < map[m1].key; 109 | }); 110 | for(int i = 0; i < hashes.size(); ++i){ 111 | map[hashes[i]].val = i; 112 | } 113 | } 114 | int size(){ 115 | return cnt; 116 | } 117 | private: 118 | // 常数取质数,且 m1 < m 119 | const int m = 2222281; 120 | const int m1 = 
2205167; 121 | int cnt = 0; 122 | // 数据结构 123 | struct data{ 124 | uint32_t len = 0; 125 | uint32_t key; 126 | int val = -1; 127 | }; 128 | vector map; 129 | vector hashes; 130 | uint32_t hash(uint32_t k, int i){ 131 | // 哈希函数 132 | return k % m + i; // 一次哈希 133 | // return (k % m + i * (m1 - k % m1)) % m; // 双重哈希 134 | } 135 | }Map; 136 | 137 | struct Record{ 138 | uint32_t m, m1, n, n1; 139 | }; 140 | struct Route{ 141 | void push_back(uint32_t index, uint32_t money, uint32_t money1 , uint32_t node, uint32_t node1){ 142 | uint8_t& sz = size[index]; 143 | if(sz < 16){ 144 | route[index][sz].m = money; 145 | route[index][sz].m1 = money1; 146 | route[index][sz].n = node; 147 | route[index][sz].n1 = node1; 148 | }else if(sz == 16){ 149 | memcpy(route1[index], route[index], 16 * sizeof(Record)); 150 | route1[index][sz].m = money; 151 | route1[index][sz].m1 = money1; 152 | route1[index][sz].n = node; 153 | route1[index][sz].n1 = node1; 154 | }else{ 155 | route1[index][sz].m = money; 156 | route1[index][sz].m1 = money1; 157 | route1[index][sz].n = node; 158 | route1[index][sz].n1 = node1; 159 | } 160 | sz++; 161 | } 162 | Record route[maxn][16]; 163 | Record route1[maxn][128]; 164 | uint8_t size[maxn]; 165 | bool sorted[maxn]; 166 | }; 167 | 168 | inline bool check(const uint32_t& m, const uint32_t& m1){ 169 | return (m1 * 5UL >= m) && (m * 3UL >= m1); 170 | } 171 | 172 | uint32_t* loops[nthread][5]; // 存3-7环 173 | uint32_t loop_size[nthread]; // 存每个线程的总环数(不必要) 174 | uint32_t loops_size[maxn][5]; // 存每个id的3-7环个数 175 | 176 | uint8_t which_thread[maxn]; // 记录该id被哪个线程取到 177 | uint8_t DIS[nthread][maxn]; // dfs辅助数组,用于记录逆向的最近距离 178 | 179 | atomic curNode; // 原子计数,用于dfs负载均衡 180 | void* find_thread(void* arg){ 181 | /* 182 | Thread: search loops by dfs method. 
183 | */ 184 | #ifdef DEBUG 185 | // Timer t; 186 | #endif 187 | int tid = *(int*)arg; 188 | Route* mroute = new Route; 189 | memset(mroute->size, 0, (ID.size() + 1) * sizeof(uint8_t)); 190 | memset(mroute->sorted, false, (ID.size() + 1) * sizeof(bool)); 191 | uint32_t* loop3 = loops[tid][0]; 192 | uint32_t* loop4 = loops[tid][1]; 193 | uint32_t* loop5 = loops[tid][2]; 194 | uint32_t* loop6 = loops[tid][3]; 195 | uint32_t* loop7 = loops[tid][4]; 196 | int cnt; 197 | uint32_t from; 198 | uint32_t cnt3, cnt4, cnt5, cnt6, cnt7; 199 | uint32_t node, node1, node2, node3, node4, node5, node6; 200 | uint32_t money, money1, money2, money3, money4, money5, money6; 201 | Vector reach; 202 | while(1){ 203 | node = curNode++; 204 | if(node >= ID.size()) break; 205 | which_thread[node] = tid; 206 | cnt = 0; 207 | if(plen[node] == 0 || qlen[node] == 0){ 208 | continue; 209 | } 210 | cnt3 = cnt4 = cnt5 = cnt6 = cnt7 = 0; 211 | reach.clear(); 212 | // 逆向三层记录路径 213 | for(int j = qhead[node]; j < qhead[node + 1]; ++j){ 214 | const Edge& from1 = From[j]; 215 | node1 = from1.id; 216 | if(node1 <= node) break; 217 | money1 = from1.money; 218 | for(int k = qhead[node1]; k < qhead[node1 + 1]; ++k){ 219 | const Edge& from2 = From[k]; 220 | node2 = from2.id; 221 | if(node2 <= node) break; 222 | money2 = from2.money; 223 | if(!check(money2, money1) || node2 == node) continue; 224 | for(int l = qhead[node2]; l < qhead[node2 + 1]; ++l){ 225 | const Edge& from3 = From[l]; 226 | node3 = from3.id; 227 | if(node3 <= node) break; 228 | money3 = from3.money; 229 | if(!check(money3, money2) || node3 == node || node3 == node1) continue; 230 | mroute->push_back(node3, money3, money1, node2, node1); 231 | reach.push_back(node3); 232 | } 233 | } 234 | } 235 | // 正向1层: 236 | // 正向2层:5环 237 | // 正向3层:3、6环 238 | // 正向4层:4、7环 239 | int j = phead[node]; 240 | for(; j < phead[node + 1]; ++j) if(To[j].id > node) break; 241 | for(; j < phead[node + 1]; ++j){ 242 | const Edge& to1 = To[j]; 243 | node1 = to1.id; 
244 | money1 = to1.money; 245 | int k = phead[node1]; 246 | for(; k < phead[node1 + 1]; ++k) if(To[k].id > node) break; 247 | for(; k < phead[node1 + 1]; ++k){ 248 | const Edge& to2 = To[k]; 249 | node2 = to2.id; 250 | money2 = to2.money; 251 | if(!check(money1, money2)) continue; 252 | if(mroute->size[node2]){ // 五环 253 | if(mroute->size[node2] <= 16){ 254 | if(!mroute->sorted[node2] && mroute->size[node2] > 1){ 255 | sort(mroute->route[node2], mroute->route[node2] + mroute->size[node2], [](const Record& lhs, const Record& rhs){ 256 | if(lhs.n == rhs.n) return lhs.n1 < rhs.n1; 257 | return lhs.n < rhs.n; 258 | }); 259 | mroute->sorted[node2] = true; 260 | } 261 | for(int x = 0; x < mroute->size[node2]; ++x){ 262 | money3 = mroute->route[node2][x].m; 263 | money = mroute->route[node2][x].m1; 264 | if(check(money2, money3) && check(money, money1)){ 265 | node3 = mroute->route[node2][x].n; 266 | node4 = mroute->route[node2][x].n1; 267 | if(node3 == node1) continue; 268 | if(node4 == node1 || node4 == node2) continue; 269 | 270 | loop5[0] = node;loop5[1] = node1;loop5[2] = node2; 271 | loop5[3] = node3;loop5[4] = node4; 272 | loop5 += 5; 273 | cnt5++; 274 | } 275 | } 276 | }else{ 277 | if(!mroute->sorted[node2]){ 278 | sort(mroute->route1[node2], mroute->route1[node2] + mroute->size[node2], [](const Record& lhs, const Record& rhs){ 279 | if(lhs.n == rhs.n) return lhs.n1 < rhs.n1; 280 | return lhs.n < rhs.n; 281 | }); 282 | mroute->sorted[node2] = true; 283 | } 284 | for(int x = 0; x < mroute->size[node2]; ++x){ 285 | money3 = mroute->route1[node2][x].m; 286 | money = mroute->route1[node2][x].m1; 287 | if(check(money2, money3) && check(money, money1)){ 288 | node3 = mroute->route1[node2][x].n; 289 | node4 = mroute->route1[node2][x].n1; 290 | if(node3 == node1) continue; 291 | if(node4 == node1 || node4 == node2) continue; 292 | 293 | loop5[0] = node;loop5[1] = node1;loop5[2] = node2; 294 | loop5[3] = node3;loop5[4] = node4; 295 | loop5 += 5; 296 | cnt5++; 297 | } 298 | 
} 299 | } 300 | 301 | } 302 | int l = phead[node2]; 303 | for(; l < phead[node2 + 1]; ++l) if(To[l].id >= node) break; 304 | for(; l < phead[node2 + 1]; ++l){ 305 | const Edge& to3 = To[l]; 306 | node3 = to3.id; 307 | money3 = to3.money; 308 | if(!check(money2, money3)) continue; // || node3 == node1) continue; 309 | if(node3 == node){ 310 | if(check(money3, money1)){ // 三环 311 | loop3[0] = node;loop3[1] = node1;loop3[2] = node2; 312 | loop3 += 3; 313 | cnt3++; 314 | } 315 | continue; 316 | } 317 | if(node3 == node1) continue; 318 | if(mroute->size[node3]){ 319 | if(mroute->size[node3] <= 16){ 320 | if(!mroute->sorted[node3] && mroute->size[node3] > 1){ 321 | sort(mroute->route[node3], mroute->route[node3] + mroute->size[node3], [](const Record& lhs, const Record& rhs){ 322 | if(lhs.n == rhs.n) return lhs.n1 < rhs.n1; 323 | return lhs.n < rhs.n; 324 | }); 325 | mroute->sorted[node3] = true; 326 | } 327 | for(int x = 0; x < mroute->size[node3]; ++x){ 328 | money4 = mroute->route[node3][x].m; 329 | money = mroute->route[node3][x].m1; 330 | if(check(money3, money4) && check(money, money1)){ 331 | node4 = mroute->route[node3][x].n; 332 | node5 = mroute->route[node3][x].n1; 333 | if(node4 == node1 || node5 == node1 || node4 == node2 || node5 == node2 || node5 == node3) continue; 334 | loop6[0] = node;loop6[1] = node1;loop6[2] = node2; 335 | loop6[3] = node3;loop6[4] = node4;loop6[5] = node5; 336 | loop6 += 6; 337 | cnt6++; 338 | } 339 | } 340 | }else{ 341 | if(!mroute->sorted[node3]){ 342 | sort(mroute->route1[node3], mroute->route1[node3] + mroute->size[node3], [](const Record& lhs, const Record& rhs){ 343 | if(lhs.n == rhs.n) return lhs.n1 < rhs.n1; 344 | return lhs.n < rhs.n; 345 | }); 346 | mroute->sorted[node3] = true; 347 | } 348 | for(int x = 0; x < mroute->size[node3]; ++x){ 349 | money4 = mroute->route1[node3][x].m; 350 | money = mroute->route1[node3][x].m1; 351 | if(check(money3, money4) && check(money, money1)){ 352 | node4 = mroute->route1[node3][x].n; 353 | 
node5 = mroute->route1[node3][x].n1; 354 | if(node4 == node1 || node5 == node1 || node4 == node2 || node5 == node2 || node5 == node3) continue; 355 | loop6[0] = node;loop6[1] = node1;loop6[2] = node2; 356 | loop6[3] = node3;loop6[4] = node4;loop6[5] = node5; 357 | loop6 += 6; 358 | cnt6++; 359 | } 360 | } 361 | } 362 | 363 | } 364 | 365 | int m = phead[node3]; 366 | for(; m < phead[node3 + 1]; ++m) if(To[m].id >= node) break; 367 | for(; m < phead[node3 + 1]; ++m){ 368 | const Edge& to4 = To[m]; 369 | node4 = to4.id; 370 | money4 = to4.money; 371 | if(!check(money3, money4)) continue; 372 | if(node4 == node){ 373 | if(check(money4, money1)){ // 四环 374 | loop4[0] = node;loop4[1] = node1; 375 | loop4[2] = node2;loop4[3] = node3; 376 | loop4 += 4; 377 | cnt4++; 378 | } 379 | continue; 380 | } 381 | if(node4 == node1 || node4 == node2) continue; 382 | if(mroute->size[node4]){ // 七环 383 | if(mroute->size[node4] <= 16){ 384 | if(!mroute->sorted[node4] && mroute->size[node4] > 1){ 385 | sort(mroute->route[node4], mroute->route[node4] + mroute->size[node4], [](const Record& lhs, const Record& rhs){ 386 | if(lhs.n == rhs.n) return lhs.n1 < rhs.n1; 387 | return lhs.n < rhs.n; 388 | }); 389 | mroute->sorted[node4] = true; 390 | } 391 | for(int x = 0; x < mroute->size[node4]; ++x){ 392 | money5 = mroute->route[node4][x].m; 393 | money = mroute->route[node4][x].m1; 394 | if(check(money4, money5) && check(money, money1)){ 395 | node5 = mroute->route[node4][x].n; 396 | node6 = mroute->route[node4][x].n1; 397 | if(node5 == node1 || node5 == node2 || node5 == node3) continue; 398 | if(node6 == node1 || node6 == node2 || node6 == node3 || node6 == node4) continue; 399 | loop7[0] = node;loop7[1] = node1;loop7[2] = node2;loop7[3] = node3; 400 | loop7[4] = node4;loop7[5] = node5;loop7[6] = node6; 401 | loop7 += 7; 402 | cnt7++; 403 | } 404 | } 405 | }else{ 406 | if(!mroute->sorted[node4]){ 407 | sort(mroute->route1[node4], mroute->route1[node4] + mroute->size[node4], [](const Record& 
lhs, const Record& rhs){ 408 | if(lhs.n == rhs.n) return lhs.n1 < rhs.n1; 409 | return lhs.n < rhs.n; 410 | }); 411 | mroute->sorted[node4] = true; 412 | } 413 | for(int x = 0; x < mroute->size[node4]; ++x){ 414 | money5 = mroute->route1[node4][x].m; 415 | money = mroute->route1[node4][x].m1; 416 | if(check(money4, money5) && check(money, money1)){ 417 | node5 = mroute->route1[node4][x].n; 418 | node6 = mroute->route1[node4][x].n1; 419 | if(node5 == node1 || node5 == node2 || node5 == node3) continue; 420 | if(node6 == node1 || node6 == node2 || node6 == node3 || node6 == node4) continue; 421 | loop7[0] = node;loop7[1] = node1;loop7[2] = node2;loop7[3] = node3; 422 | loop7[4] = node4;loop7[5] = node5;loop7[6] = node6; 423 | loop7 += 7; 424 | cnt7++; 425 | } 426 | } 427 | } 428 | 429 | } 430 | 431 | } 432 | } 433 | } 434 | } 435 | 436 | for(int j = 0; j < reach.size(); ++j){ 437 | mroute->size[reach[j]] = 0; 438 | mroute->sorted[reach[j]] = false; 439 | } 440 | loop_size[tid] = (loop3 - loops[tid][0]) / 3 + 441 | (loop4 - loops[tid][1]) / 4 + 442 | (loop5 - loops[tid][2]) / 5 + 443 | (loop6 - loops[tid][3]) / 6 + 444 | (loop7 - loops[tid][4]) / 7; 445 | loops_size[node][0] = cnt3; 446 | loops_size[node][1] = cnt4; 447 | loops_size[node][2] = cnt5; 448 | loops_size[node][3] = cnt6; 449 | loops_size[node][4] = cnt7; 450 | 451 | } 452 | *loop3 = *loop4 = *loop5 = *loop6 = *loop7 = 0xffffffff; 453 | } 454 | 455 | void find_loops(){ 456 | #ifdef DEBUG 457 | cout << __func__ << endl; 458 | Timer t; 459 | #endif 460 | // 开内存用于存环 461 | for(int i = 0; i < nthread; ++i){ 462 | for(int j = 0; j < 5; ++j){ 463 | loops[i][j] = new uint32_t[maxs / 2 * (j + 3)]; 464 | } 465 | } 466 | // 多线程找环 467 | pthread_t threads[nthread]; 468 | int tid[nthread]; 469 | for(int i = 0; i < nthread; ++i){ 470 | tid[i] = i; 471 | pthread_create(&threads[i], NULL, find_thread, (void*)&tid[i]); 472 | } 473 | for(int i = 0; i < nthread; ++i) 474 | pthread_join(threads[i], NULL); 475 | #ifdef DEBUG 476 
| int loops_total = 0; 477 | for(int i = 0; i < nthread; ++i){ 478 | loops_total += (loop_size[i]); 479 | } 480 | cout << "loops: " << loops_total << endl; 481 | #endif 482 | } 483 | 484 | uint32_t* Loops[5]; // 合并的结果 485 | size_t accu[8]; // 3-7环的累计和 486 | void merge_loops(){ 487 | /* 488 | 合并各线程找到的环 489 | */ 490 | #ifdef DEBUG 491 | cout << __func__ << endl; 492 | Timer t; 493 | #endif 494 | // 开内存 495 | uint32_t* lptr[5]; 496 | for(int i = 0; i < 5; ++i){ 497 | Loops[i] = new uint32_t[maxs * (i + 3)]; 498 | lptr[i] = Loops[i]; 499 | } 500 | // 辅助指针 501 | uint32_t* ptr[nthread][5]; 502 | for(int i = 0; i < nthread; ++i){ 503 | for(int j = 0; j < 5; ++j){ 504 | ptr[i][j] = loops[i][j]; 505 | } 506 | } 507 | // 合并 508 | uint32_t len; 509 | for(int i = 0; i < 5; ++i){ 510 | for(int j = 0; j < ID.size(); ++j){ 511 | uint32_t* &curr = ptr[which_thread[j]][i]; // 从各个线程取出结果进行合并 512 | if(loops_size[j][i]){ 513 | len = loops_size[j][i] * (i + 3); 514 | memcpy(lptr[i], curr, len * sizeof(uint32_t)); 515 | lptr[i] += len; 516 | curr += len; 517 | } 518 | } 519 | } 520 | // 计算累计和 521 | for(int i = 3; i < 8; ++i){ 522 | accu[i] = accu[i - 1] + (lptr[i - 3] - Loops[i - 3]) / i; 523 | } 524 | } 525 | 526 | inline void convert(uint32_t temp, char* ptr){ 527 | // 转换:uint32 -> char[] 528 | // 仅在多线程建转换表时使用 529 | if(temp < 10){ 530 | ptr[0] = 1; 531 | ptr[1] = temp % 10 + '0'; 532 | }else if(temp < 100){ 533 | ptr[0] = 2; 534 | ptr[1] = temp / 10 % 10 + '0'; 535 | ptr[2] = temp % 10 + '0'; 536 | }else if(temp < 1000){ 537 | ptr[0] = 3; 538 | ptr[1] = temp / 100 % 10 + '0'; 539 | ptr[2] = temp / 10 % 10 + '0'; 540 | ptr[3] = temp % 10 + '0'; 541 | }else if(temp < 10000){ 542 | ptr[0] = 4; 543 | ptr[1] = temp / 1000 % 10 + '0'; 544 | ptr[2] = temp / 100 % 10 + '0'; 545 | ptr[3] = temp / 10 % 10 + '0'; 546 | ptr[4] = temp % 10 + '0'; 547 | }else if(temp < 100000){ 548 | ptr[0] = 5; 549 | ptr[1] = temp / 10000 % 10 + '0'; 550 | ptr[2] = temp / 1000 % 10 + '0'; 551 | ptr[3] = temp / 100 
% 10 + '0'; 552 | ptr[4] = temp / 10 % 10 + '0'; 553 | ptr[5] = temp % 10 + '0'; 554 | }else if(temp < 1000000){ 555 | ptr[0] = 6; 556 | ptr[1] = temp / 100000 % 10 + '0'; 557 | ptr[2] = temp / 10000 % 10 + '0'; 558 | ptr[3] = temp / 1000 % 10 + '0'; 559 | ptr[4] = temp / 100 % 10 + '0'; 560 | ptr[5] = temp / 10 % 10 + '0'; 561 | ptr[6] = temp % 10 + '0'; 562 | }else if(temp < 10000000){ 563 | ptr[0] = 7; 564 | ptr[1] = temp / 1000000 % 10 + '0'; 565 | ptr[2] = temp / 100000 % 10 + '0'; 566 | ptr[3] = temp / 10000 % 10 + '0'; 567 | ptr[4] = temp / 1000 % 10 + '0'; 568 | ptr[5] = temp / 100 % 10 + '0'; 569 | ptr[6] = temp / 10 % 10 + '0'; 570 | ptr[7] = temp % 10 + '0'; 571 | }else if(temp < 100000000){ 572 | ptr[0] = 8; 573 | ptr[1] = temp / 10000000 % 10 + '0'; 574 | ptr[2] = temp / 1000000 % 10 + '0'; 575 | ptr[3] = temp / 100000 % 10 + '0'; 576 | ptr[4] = temp / 10000 % 10 + '0'; 577 | ptr[5] = temp / 1000 % 10 + '0'; 578 | ptr[6] = temp / 100 % 10 + '0'; 579 | ptr[7] = temp / 10 % 10 + '0'; 580 | ptr[8] = temp % 10 + '0'; 581 | }else if(temp < 1000000000){ 582 | ptr[0] = 9; 583 | ptr[1] = temp / 100000000 % 10 + '0'; 584 | ptr[2] = temp / 10000000 % 10 + '0'; 585 | ptr[3] = temp / 1000000 % 10 + '0'; 586 | ptr[4] = temp / 100000 % 10 + '0'; 587 | ptr[5] = temp / 10000 % 10 + '0'; 588 | ptr[6] = temp / 1000 % 10 + '0'; 589 | ptr[7] = temp / 100 % 10 + '0'; 590 | ptr[8] = temp / 10 % 10 + '0'; 591 | ptr[9] = temp % 10 + '0'; 592 | }else{ 593 | ptr[0] = 10; 594 | ptr[1] = temp / 1000000000 % 10 + '0'; 595 | ptr[2] = temp / 100000000 % 10 + '0'; 596 | ptr[3] = temp / 10000000 % 10 + '0'; 597 | ptr[4] = temp / 1000000 % 10 + '0'; 598 | ptr[5] = temp / 100000 % 10 + '0'; 599 | ptr[6] = temp / 10000 % 10 + '0'; 600 | ptr[7] = temp / 1000 % 10 + '0'; 601 | ptr[8] = temp / 100 % 10 + '0'; 602 | ptr[9] = temp / 10 % 10 + '0'; 603 | ptr[10] = temp % 10 + '0'; 604 | } 605 | } 606 | 607 | 608 | char conv[maxn][11]; // 转换表 609 | void* conv_thread(void* args){ 610 | // 
将映射后的id作为下标,真实id转换为字符串后保存在对应位置. 611 | // 例如: Map[5142132] = 42,则 conv[42] = "5142132"; 612 | // 上述操作可以减少一次逆映射导致的访存 613 | int tid = *(int*)args; 614 | for(int i = tid; i < ID.size(); i += nthread){ 615 | convert(ID[i], conv[i]); 616 | } 617 | } 618 | void convert_init(){ 619 | pthread_t threads[nthread]; 620 | int tid[nthread]; 621 | for(int i = 0; i < nthread; ++i){ 622 | tid[i] = i; 623 | pthread_create(&threads[i], NULL, conv_thread, (void*)&tid[i]); 624 | } 625 | for(int i = 0; i < nthread; ++i) 626 | pthread_join(threads[i], NULL); 627 | } 628 | 629 | 630 | char* buf[nthread]; 631 | int len[nthread]; 632 | void* itoa(void* arg){ 633 | // 将结果转换为字符串 634 | 635 | #ifdef DEBUG 636 | Timer t; 637 | // cout << __func__ << endl; 638 | #endif 639 | int tid = *(int*)arg; 640 | // 任务分割,基本均衡 641 | int start = tid * (accu[7] / nthread); 642 | int end = (tid + 1) * (accu[7] / nthread); 643 | if(tid == nthread - 1) end = accu[7]; 644 | 645 | // 查表转换 646 | int temp = 0, offset; 647 | char* ptr = buf[tid]; 648 | for(int i = start; i < end; ++i){ 649 | for(int j = 3; j < 8; ++j){ // 某些线程可能跨长度,典型为线程0跨越3-7环 650 | if(i < accu[j]){ 651 | offset = (i - accu[j - 1]) * j; 652 | for(int k = 0; k < j; ++k){ 653 | temp = Loops[j - 3][offset + k]; 654 | char& sz = conv[temp][0]; 655 | memcpy(ptr, conv[temp] + 1, sz); 656 | ptr += sz; 657 | *ptr++ = ','; 658 | } 659 | *(ptr - 1) = '\n'; 660 | break; 661 | } 662 | } 663 | } 664 | len[tid] = ptr - buf[tid]; 665 | } 666 | 667 | // 保存结果:mmap + 多线程memcpy 668 | // 线下看起来快,线上和fwrite差不多 669 | char* write_pos; 670 | void* save(void* arg){ 671 | int tid = *(int*)arg; 672 | int start = 0; 673 | if(tid){ 674 | for(int i = 0; i < tid; ++i) start += len[i]; 675 | } 676 | memcpy(write_pos + start, buf[tid], len[tid]); 677 | } 678 | 679 | const char* test_data_file = "/data/test_data.txt"; 680 | const char* result_file = "/projects/student/result.txt"; 681 | 682 | void save_loops(){ 683 | // 保存结果 684 | #ifdef DEBUG 685 | Timer t; 686 | cout << __func__ << 
endl; 687 | #endif 688 | // 开内存 689 | for(int i = 0; i < nthread; ++i){ 690 | buf[i] = new char[1024 * 1024 * 512]; 691 | } 692 | // 构造转换表 693 | convert_init(); 694 | // 结果转换为字符串 695 | pthread_t threads[nthread]; 696 | int tid[nthread]; 697 | for(int i = 0; i < nthread; ++i){ 698 | tid[i] = i; 699 | pthread_create(&threads[i], NULL, itoa, (void*)&tid[i]); 700 | } 701 | 702 | for(int i = 0; i < nthread; ++i) 703 | pthread_join(threads[i], NULL); 704 | // 环数转换为字符串 705 | char* size_buf = new char[12]; 706 | char* ptr = size_buf; 707 | uint32_t size = accu[7]; 708 | convert(size, ptr); 709 | ptr += (*size_buf++ + 1); 710 | *ptr++ = '\n'; 711 | // mmap写 712 | int total_len = (ptr - size_buf); 713 | for(int i = 0; i < nthread; ++i) total_len += len[i]; 714 | int fd = open(result_file, O_RDWR | O_CREAT , 0666); 715 | lseek(fd, total_len - 1, SEEK_SET); 716 | int ret = write(fd, "\0", 1); 717 | char* ptr_ans = (char*)mmap(NULL, total_len, PROT_WRITE, MAP_SHARED, fd, 0); 718 | close(fd); 719 | // 写环数 720 | memcpy(ptr_ans, size_buf, ptr - size_buf); 721 | write_pos = ptr_ans + (ptr - size_buf); 722 | // 多线程写环 723 | for(int i = 0; i < nthread; ++i){ 724 | pthread_create(&threads[i], NULL, save, (void*)&tid[i]); 725 | } 726 | for(int i = 0; i < nthread; ++i) 727 | pthread_join(threads[i], NULL); 728 | } 729 | 730 | void* sort_thread(void* arg){ 731 | // 对每个节点的邻接表排序:正向正序(保证结果为字典序),反向反序(用于反向dfs提前break退出) 732 | int tid = *(int*)arg; 733 | for(int i = tid; i < ID.size(); i += nthread){ 734 | sort(To + phead[i], To + phead[i] + plen[i], [](const Edge& m, const Edge& m1){ 735 | return m.id < m1.id; 736 | }); 737 | sort(From + qhead[i], From + qhead[i] + qlen[i], [](const Edge& m, const Edge& m1){ 738 | return m.id > m1.id; 739 | }); 740 | } 741 | } 742 | 743 | void* build_thread(void* arg){ 744 | // 构造前向星 745 | // 1.统计每个节点邻接点个数 746 | // 2.计算每个节点开始位置 747 | // 3.遍历所有节点构造前向星 748 | int tid = *(int*)arg; 749 | uint32_t from, to, money; 750 | if(tid == 0){ // 正图 751 | int* curlen = new 
int[ID.size() + 1](); 752 | for(int k = 0; k < nthread; ++k){ 753 | for(int i = 0; i < loadnum[k]; ++i){ 754 | if(loadbuf[k][i][1] != 0xffffffff){ 755 | from = loadbuf[k][i][0]; 756 | plen[from]++; 757 | } 758 | } 759 | } 760 | phead[0] = 0; 761 | for(int i = 1; i <= ID.size(); ++i){ 762 | phead[i] = phead[i - 1] + plen[i - 1]; 763 | } 764 | for(int k = 0; k < nthread; ++k){ 765 | for(int i = 0; i < loadnum[k]; ++i){ 766 | if(loadbuf[k][i][1] != 0xffffffff){ 767 | to = loadbuf[k][i][1]; 768 | from = loadbuf[k][i][0]; 769 | To[phead[from] + curlen[from]].id = to; 770 | To[phead[from] + curlen[from]++].money = loadbuf[k][i][2]; 771 | } 772 | 773 | } 774 | } 775 | 776 | }else{ // 反图 777 | int* curlen = new int[ID.size() + 1](); 778 | for(int k = 0; k < nthread; ++k){ 779 | for(int i = 0; i < loadnum[k]; ++i){ 780 | if(loadbuf[k][i][1] != 0xffffffff){ 781 | to = loadbuf[k][i][1]; 782 | qlen[to]++; 783 | } 784 | } 785 | } 786 | qhead[0] = 0; 787 | for(int i = 1; i <= ID.size(); ++i){ 788 | qhead[i] = qhead[i - 1] + qlen[i - 1]; 789 | } 790 | for(int k = 0; k < nthread; ++k){ 791 | for(int i = 0; i < loadnum[k]; ++i){ 792 | if(loadbuf[k][i][1] != 0xffffffff){ 793 | to = loadbuf[k][i][1]; 794 | from = loadbuf[k][i][0]; 795 | From[qhead[to] + curlen[to]].id = from; 796 | From[qhead[to] + curlen[to]++].money = loadbuf[k][i][2]; 797 | } 798 | } 799 | } 800 | } 801 | } 802 | 803 | char* file; 804 | int file_size; 805 | void* load_thread(void* args){ 806 | // 多线程读图 807 | int tid = *(int*)args; 808 | int size = file_size / nthread; 809 | if(tid == nthread - 1) size = file_size - (nthread - 1) * size; 810 | 811 | char* start = file + tid * (file_size / nthread); 812 | char* curr = start; 813 | 814 | // 确保两个线程不读到分割位置的同一行 815 | if(tid != 0 && *(curr - 1) != '\n') while(*curr++ != '\n'); 816 | 817 | uint32_t from, to , money, temp = 0; 818 | uint8_t state = 0; 819 | char ch; 820 | while(1){ 821 | ch = *curr; 822 | if(ch == ','){ 823 | state ? 
to = temp : from = temp; 824 | state = 1 - state; 825 | temp = 0; 826 | } 827 | else if(ch == '\r' || ch == '\n'){ 828 | loadbuf[tid][loadnum[tid]][0] = from; 829 | loadbuf[tid][loadnum[tid]][1] = to; 830 | loadbuf[tid][loadnum[tid]][2] = temp; 831 | loadnum[tid]++; 832 | if(ch == '\r') curr++; 833 | if(curr - start + 1 >= size) break; 834 | state = temp = 0; 835 | } 836 | else{ 837 | temp = temp * 10 + ch - '0'; 838 | } 839 | curr++; 840 | } 841 | } 842 | 843 | void* clear_thread(void* args){ 844 | // 删除记录中出度为0的节点 845 | // 并将记录id映射为{ 0 1 2 ... n-1 } 846 | int tid = *(int*)args; 847 | uint32_t from, to; 848 | for(int i = 0; i < loadnum[tid]; ++i){ 849 | to = Map.search(loadbuf[tid][i][1]); 850 | if(to == -1){ 851 | loadbuf[tid][i][1] = 0xffffffff; // 标记这一行转账数据不要了 852 | }else{ 853 | from = Map.search(loadbuf[tid][i][0]); 854 | loadbuf[tid][i][0] = from; 855 | loadbuf[tid][i][1] = to; 856 | } 857 | } 858 | } 859 | 860 | void* shit_thread(void* args){ 861 | // 哈希排序,确保映射前后相对大小不变 862 | int tid = *(int*)args; 863 | uint32_t from , to; 864 | if(tid == 0){ 865 | Map.sort_hash(); 866 | }else{ 867 | sort(ID.begin(), ID.end()); 868 | } 869 | } 870 | 871 | void load_data(){ 872 | /* 873 | @brief: load data from file "/data/test_data.txt" 874 | @method: mmap 875 | */ 876 | #ifdef DEBUG 877 | cout << __func__ << endl; 878 | Timer t; 879 | #endif 880 | // mmap 881 | struct stat statue; 882 | int fd = open(test_data_file, O_RDONLY); 883 | fstat(fd, &statue); 884 | file_size = statue.st_size; 885 | file = (char*)mmap(NULL, statue.st_size, PROT_READ, MAP_PRIVATE, fd, 0); 886 | close(fd); 887 | // 1.多线程读数据到loadbuf中 888 | pthread_t threads[nthread]; 889 | int tid[nthread]; 890 | for(int i = 0; i < nthread; ++i){ 891 | tid[i] = i; 892 | pthread_create(&threads[i], NULL, load_thread, (void*)&tid[i]); 893 | } 894 | for(int i = 0; i < nthread; ++i) 895 | pthread_join(threads[i], NULL); 896 | 897 | // 2.哈希(仅映射 {u, v, w} 中的 u,因为不在 U 中出现的 id 肯定不成环) 898 | for(int k = 0; k < nthread; ++k){ 899 
| for(int i = 0; i < loadnum[k]; ++i){ 900 | Map.insert(loadbuf[k][i][0]); 901 | } 902 | } 903 | // 3.哈希排序,确保映射前后相对大小不变 904 | for(int i = 0; i < 2; ++i){ 905 | tid[i] = i; 906 | pthread_create(&threads[i], NULL, shit_thread, (void*)&tid[i]); 907 | } 908 | for(int i = 0; i < 2; ++i) 909 | pthread_join(threads[i], NULL); 910 | 911 | // 4.将不用的记录删除(使得后续不用再查询哈系表) 912 | for(int i = 0; i < nthread; ++i){ 913 | tid[i] = i; 914 | pthread_create(&threads[i], NULL, clear_thread, (void*)&tid[i]); 915 | } 916 | for(int i = 0; i < nthread; ++i) 917 | pthread_join(threads[i], NULL); 918 | 919 | // 5.构造前向星 920 | for(int i = 0; i < 2; ++i){ 921 | tid[i] = i; 922 | pthread_create(&threads[i], NULL, build_thread, (void*)&tid[i]); 923 | } 924 | for(int i = 0; i < 2; ++i) 925 | pthread_join(threads[i], NULL); 926 | 927 | // 6.前向星排序 928 | for(int i = 0; i < nthread; ++i){ 929 | tid[i] = i; 930 | pthread_create(&threads[i], NULL, sort_thread, (void*)&tid[i]); 931 | } 932 | for(int i = 0; i < nthread; ++i) 933 | pthread_join(threads[i], NULL); 934 | } 935 | 936 | int main(int argc, char** argv){ 937 | #ifdef DEBUG 938 | Timer t; 939 | #endif 940 | load_data(); 941 | find_loops(); 942 | merge_loops(); 943 | save_loops(); 944 | #ifdef DEBUG 945 | cout << __func__ << endl; 946 | #endif 947 | return 0; 948 | } 949 | -------------------------------------------------------------------------------- /preliminary contest/main63.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include 19 | 20 | #ifdef DEBUG 21 | #include "timer.h" 22 | #endif 23 | 24 | using namespace std; 25 | 26 | const int nthread = 4; // 线程数 27 | 28 | const int maxs = 20000000; // 最大环数 29 | const int maxn = 2000000; // 最大节点数 30 | 31 | // 数据读取buffer 32 | uint32_t 
loadbuf[nthread][maxn][3]; 33 | uint32_t loadnum[nthread]; 34 | 35 | // 图数据结构(前向星) 36 | struct Edge{ 37 | uint32_t id; 38 | uint32_t money; 39 | }; 40 | 41 | Edge From[maxn]; // 正图 42 | Edge To[maxn]; // 反图 43 | uint32_t phead[maxn]; // 各id起始位置 44 | uint32_t plen[maxn]; // 各id转账数 45 | uint32_t qhead[maxn]; 46 | uint32_t qlen[maxn]; 47 | 48 | // 简易的固定长度vector,效率和数组相近 49 | class Vector{ 50 | public: 51 | Vector(){ _data = new uint32_t[maxn]; } 52 | void push_back(uint32_t t){ 53 | _data[_size++] = t; 54 | } 55 | int size(){ return _size; } 56 | void clear(){ _size = 0; } 57 | uint32_t* begin(){ return _data; } 58 | uint32_t* end(){ return _data + _size; } 59 | uint32_t& operator[](int index){ return _data[index]; } 60 | private: 61 | uint32_t* _data; 62 | int _size; 63 | }; 64 | Vector ID; // 真实ID表 65 | 66 | // 简易的固定桶数hash_table,采用开放寻址法解决冲突 67 | class Hash_map{ 68 | public: 69 | Hash_map(){ 70 | map.resize(m); 71 | } 72 | void insert(uint32_t k){ 73 | // 如果不存在则插入,否则直接返回,作用等同: 74 | // if(!Map.count(key)) Map.insert(key); 75 | int i = 0; 76 | int hash_val = 0; 77 | while (i < m) { 78 | hash_val = hash(k, i); 79 | if (map[hash_val].val == -1){ 80 | map[hash_val].key = k; 81 | map[hash_val].val = ID.size(); 82 | map[hash_val].len = 1; 83 | ID.push_back(k); 84 | hashes.push_back(hash_val); 85 | cnt++; 86 | break; 87 | } 88 | if(map[hash_val].key == k){ 89 | map[hash_val].len++; 90 | break; 91 | }else i++; 92 | } 93 | } 94 | int search(uint32_t k){ 95 | // 搜索 96 | int i = 0, hash_val = 0; 97 | while (i < m) { 98 | hash_val = hash(k, i); 99 | if(map[hash_val].val == -1) break; 100 | if(map[hash_val].key == k) return map[hash_val].val; 101 | else i++; 102 | } 103 | return -1; 104 | } 105 | void sort_hash(){ 106 | // 将hash值排序,确保id映射前后相对大小不变化 107 | sort(hashes.begin(), hashes.end(), [&](int a, int b){ 108 | return map[a].key < map[b].key; 109 | }); 110 | for(int i = 0; i < hashes.size(); ++i){ 111 | map[hashes[i]].val = i; 112 | } 113 | } 114 | int size(){ 115 | return cnt; 
116 | } 117 | private: 118 | // 常数取质数,且 m1 < m 119 | const int m = 2222281; 120 | const int m1 = 2205167; 121 | int cnt = 0; 122 | // 数据结构 123 | struct data{ 124 | uint32_t len = 0; 125 | uint32_t key; 126 | int val = -1; 127 | }; 128 | vector map; 129 | vector hashes; 130 | uint32_t hash(uint32_t k, int i){ 131 | // 哈希函数 132 | return k % m + i; // 一次哈希 133 | // return (k % m + i * (m1 - k % m1)) % m; // 双重哈希 134 | } 135 | }Map; 136 | 137 | inline bool check(const uint32_t& a, const uint32_t& b){ 138 | return (b * 5UL >= a) && (a * 3UL >= b); 139 | } 140 | 141 | 142 | uint32_t* loops[nthread][5]; // 存3-7环 143 | uint32_t loop_size[nthread]; // 存每个线程的总环数(不必要) 144 | uint32_t loops_size[maxn][5]; // 存每个id的3-7环个数 145 | 146 | uint8_t which_thread[maxn]; // 记录该id被哪个线程取到 147 | uint8_t DIS[nthread][maxn]; // dfs辅助数组,用于记录逆向的最近距离 148 | 149 | atomic curNode; // 原子计数,用于dfs负载均衡 150 | void* find_thread(void* arg){ 151 | /* 152 | 方法:6+3 153 | */ 154 | #ifdef DEBUG 155 | Timer t; 156 | #endif 157 | int tid = *(int*)arg; 158 | uint32_t* Money = new uint32_t[ID.size() + 1]; // 记录逆向1层的金额,用于减少第7层搜索 159 | memset(DIS[tid], 8, (ID.size() + 1) * sizeof(uint8_t)); 160 | Vector reach; // 记录逆向到达的节点,用于清空DIS数组 161 | uint32_t* loop3 = loops[tid][0]; 162 | uint32_t* loop4 = loops[tid][1]; 163 | uint32_t* loop5 = loops[tid][2]; 164 | uint32_t* loop6 = loops[tid][3]; 165 | uint32_t* loop7 = loops[tid][4]; 166 | uint32_t from; 167 | uint32_t cnt3, cnt4, cnt5, cnt6, cnt7; 168 | uint32_t node, node1, node2, node3, node4, node5, node6; 169 | uint32_t money, money1, money2, money3, money4, money5, money6; 170 | while(1){ 171 | // 获取任务 172 | node = curNode++; 173 | if(node >= ID.size()) break; 174 | which_thread[node] = tid; 175 | // 跳过出/入度为0的节点 176 | if(plen[node] == 0 || qlen[node] == 0){ 177 | continue; 178 | } 179 | 180 | reach.clear(); 181 | cnt3 = cnt4 = cnt5 = cnt6 = cnt7 = 0; 182 | // 逆向3层dfs,记录到起始点的最近距离 183 | for(int j = qhead[node]; j < qhead[node + 1]; ++j){ 184 | const Edge& from1 = From[j]; 
185 | node1 = from1.id; 186 | if(node1 < node) break; 187 | money1 = from1.money; 188 | Money[node1] = money1; 189 | DIS[tid][node1] = 1; 190 | reach.push_back(node1); 191 | for(int k = qhead[node1]; k < qhead[node1 + 1]; ++k){ 192 | const Edge& from2 = From[k]; 193 | node2 = from2.id; 194 | money2 = from2.money; 195 | if(node2 < node) break; 196 | if(!check(money2, money1) || node2 == node) continue; 197 | else if(2 < DIS[tid][node2]){ 198 | DIS[tid][node2] = 2; 199 | reach.push_back(node2); 200 | } 201 | for(int l = qhead[node2]; l < qhead[node2 + 1]; ++l){ 202 | const Edge& from3 = From[l]; 203 | node3 = from3.id; 204 | money3 = from3.money; 205 | if(node3 < node) break; 206 | else if(!check(money3, money2) ||node3 == node || node3 == node1) continue; 207 | if(3 < DIS[tid][node3]){ 208 | DIS[tid][node3] = 3; 209 | reach.push_back(node3); 210 | } 211 | } 212 | } 213 | } 214 | DIS[tid][node] = 0; 215 | // 正向6层dfs,循环中的判断顺序比较重要 216 | int j = phead[node]; 217 | for(; j < phead[node + 1]; ++j) if(To[j].id > node) break; 218 | for(; j < phead[node + 1]; ++j){ 219 | const Edge& to1 = To[j]; 220 | node1 = to1.id; 221 | money1 = to1.money; 222 | int k = phead[node1]; 223 | for(; k < phead[node1 + 1]; ++k) if(To[k].id > node) break; 224 | for(; k < phead[node1 + 1]; ++k){ 225 | const Edge& to2 = To[k]; 226 | node2 = to2.id; 227 | money2 = to2.money; 228 | if(!check(money1, money2)) continue; 229 | int l = phead[node2]; 230 | for(; l < phead[node2 + 1]; ++l) if(To[l].id >= node) break; 231 | for(; l < phead[node2 + 1]; ++l){ 232 | const Edge& to3 = To[l]; 233 | node3 = to3.id; 234 | money3 = to3.money; 235 | if(!check(money2, money3) || node3 == node1) continue; 236 | else if(node3 == node){ 237 | if(check(money3, money1)){ 238 | loop3[0] = node;loop3[1] = node1;loop3[2] = node2; 239 | loop3 += 3; 240 | cnt3++; 241 | } 242 | continue; 243 | } 244 | for(int m = phead[node3]; m < phead[node3 + 1]; ++m){ 245 | const Edge& to4 = To[m]; 246 | node4 = to4.id; 247 | money4 = 
to4.money; 248 | if(DIS[tid][node4] > 3 || !check(money3, money4)) continue; 249 | else if(node4 == node){ 250 | if(check(money4, money1)){ 251 | loop4[0] = node;loop4[1] = node1; 252 | loop4[2] = node2;loop4[3] = node3; 253 | loop4 += 4; 254 | cnt4++; 255 | } 256 | continue; 257 | }else if(node4 == node1 || node4 == node2) continue; 258 | for(int n = phead[node4]; n < phead[node4 + 1]; ++n){ 259 | const Edge& to5 = To[n]; 260 | node5 = to5.id; 261 | money5 = to5.money; 262 | if(DIS[tid][node5] > 2 || !check(money4, money5)) continue; 263 | else if(node5 == node){ 264 | if(check(money5, money1)){ 265 | loop5[0] = node;loop5[1] = node1;loop5[2] = node2; 266 | loop5[3] = node3;loop5[4] = node4; 267 | loop5 += 5; 268 | cnt5++; 269 | } 270 | continue; 271 | }else if(node5 == node1 || node5 == node2 || node5 == node3) continue; 272 | for(int o = phead[node5]; o < phead[node5 + 1]; ++o){ 273 | const Edge& to6 = To[o]; 274 | node6 = to6.id; 275 | money6 = to6.money; 276 | if(DIS[tid][node6] > 1 || !check(money5, money6)) continue; 277 | else if(DIS[tid][node6] == 1){ 278 | if(node6 == node1 || node6 == node2 || node6 == node3 || node6 == node4) continue; 279 | money = Money[node6]; 280 | if(check(money6, money) && check(money, money1)){ 281 | loop7[0] = node;loop7[1] = node1;loop7[2] = node2;loop7[3] = node3; 282 | loop7[4] = node4;loop7[5] = node5;loop7[6] = node6; 283 | loop7 += 7; 284 | cnt7++; 285 | } 286 | }else if(node6 == node){ 287 | if(check(money6, money1)){ 288 | loop6[0] = node;loop6[1] = node1;loop6[2] = node2; 289 | loop6[3] = node3;loop6[4] = node4;loop6[5] = node5; 290 | loop6 += 6; 291 | cnt6++; 292 | } 293 | } 294 | } 295 | } 296 | } 297 | } 298 | } 299 | } 300 | 301 | DIS[tid][node] = 8; 302 | for(int j = 0; j < reach.size(); ++j) DIS[tid][reach[j]] = 8; // 清空DIS数组 303 | // 统计 304 | loop_size[tid] = (loop3 - loops[tid][0]) / 3 + 305 | (loop4 - loops[tid][1]) / 4 + 306 | (loop5 - loops[tid][2]) / 5 + 307 | (loop6 - loops[tid][3]) / 6 + 308 | (loop7 - 
loops[tid][4]) / 7; 309 | loops_size[node][0] = cnt3; 310 | loops_size[node][1] = cnt4; 311 | loops_size[node][2] = cnt5; 312 | loops_size[node][3] = cnt6; 313 | loops_size[node][4] = cnt7; 314 | 315 | } 316 | *loop3 = *loop4 = *loop5 = *loop6 = *loop7 = 0xffffffff; 317 | } 318 | 319 | void find_loops(){ 320 | #ifdef DEBUG 321 | cout << __func__ << endl; 322 | Timer t; 323 | #endif 324 | // 开内存用于存环 325 | for(int i = 0; i < nthread; ++i){ 326 | for(int j = 0; j < 5; ++j){ 327 | loops[i][j] = new uint32_t[maxs / 2 * (j + 3)]; 328 | } 329 | } 330 | // 多线程找环 331 | pthread_t threads[nthread]; 332 | int tid[nthread]; 333 | for(int i = 0; i < nthread; ++i){ 334 | tid[i] = i; 335 | pthread_create(&threads[i], NULL, find_thread, (void*)&tid[i]); 336 | } 337 | for(int i = 0; i < nthread; ++i) 338 | pthread_join(threads[i], NULL); 339 | #ifdef DEBUG 340 | int loops_total = 0; 341 | for(int i = 0; i < nthread; ++i){ 342 | loops_total += (loop_size[i]); 343 | } 344 | cout << "loops: " << loops_total << endl; 345 | #endif 346 | } 347 | 348 | uint32_t* Loops[5]; // 合并的结果 349 | size_t accu[8]; // 3-7环的累计和 350 | void merge_loops(){ 351 | /* 352 | 合并各线程找到的环 353 | */ 354 | #ifdef DEBUG 355 | cout << __func__ << endl; 356 | Timer t; 357 | #endif 358 | // 开内存 359 | uint32_t* lptr[5]; 360 | for(int i = 0; i < 5; ++i){ 361 | Loops[i] = new uint32_t[maxs * (i + 3)]; 362 | lptr[i] = Loops[i]; 363 | } 364 | // 辅助指针 365 | uint32_t* ptr[nthread][5]; 366 | for(int i = 0; i < nthread; ++i){ 367 | for(int j = 0; j < 5; ++j){ 368 | ptr[i][j] = loops[i][j]; 369 | } 370 | } 371 | // 合并 372 | uint32_t len; 373 | for(int i = 0; i < 5; ++i){ 374 | for(int j = 0; j < ID.size(); ++j){ 375 | uint32_t* &curr = ptr[which_thread[j]][i]; // 从各个线程取出结果进行合并 376 | if(loops_size[j][i]){ 377 | len = loops_size[j][i] * (i + 3); 378 | memcpy(lptr[i], curr, len * sizeof(uint32_t)); 379 | lptr[i] += len; 380 | curr += len; 381 | } 382 | } 383 | } 384 | // 计算累计和 385 | for(int i = 3; i < 8; ++i){ 386 | accu[i] = accu[i 
inline void convert(uint32_t temp, char* ptr){
    // Write temp as a length-prefixed decimal string:
    //   ptr[0]        number of digits (1..10), stored as a raw byte value
    //   ptr[1..len]   the digits, most significant first; no terminator
    // ptr must have room for 11 bytes (1 length byte + at most 10 digits).
    // Nothing beyond ptr[len] is written.
    //
    // Replaces the hand-unrolled ten-way if/else ladder with one loop that
    // produces the same bytes for every uint32_t value.
    uint8_t digits = 1;
    for(uint32_t rest = temp; rest >= 10; rest /= 10){
        ++digits;
    }
    ptr[0] = static_cast<char>(digits);
    // Fill from the least significant digit backwards.
    for(int pos = digits; pos >= 1; --pos){
        ptr[pos] = static_cast<char>('0' + temp % 10);
        temp /= 10;
    }
}
accu[j - 1]) * j; 516 | for(int k = 0; k < j; ++k){ 517 | temp = Loops[j - 3][offset + k]; 518 | char& sz = conv[temp][0]; 519 | memcpy(ptr, conv[temp] + 1, sz); 520 | ptr += sz; 521 | *ptr++ = ','; 522 | } 523 | *(ptr - 1) = '\n'; 524 | break; 525 | } 526 | } 527 | } 528 | len[tid] = ptr - buf[tid]; 529 | } 530 | 531 | // 保存结果:mmap + 多线程memcpy 532 | // 线下看起来快,线上和fwrite差不多 533 | char* write_pos; 534 | void* save(void* arg){ 535 | int tid = *(int*)arg; 536 | int start = 0; 537 | if(tid){ 538 | for(int i = 0; i < tid; ++i) start += len[i]; 539 | } 540 | memcpy(write_pos + start, buf[tid], len[tid]); 541 | } 542 | 543 | const char* test_data_file = "/data/test_data.txt"; 544 | const char* result_file = "/projects/student/result.txt"; 545 | 546 | void save_loops(){ 547 | // 保存结果 548 | #ifdef DEBUG 549 | Timer t; 550 | cout << __func__ << endl; 551 | #endif 552 | // 开内存 553 | for(int i = 0; i < nthread; ++i){ 554 | buf[i] = new char[1024 * 1024 * 512]; 555 | } 556 | // 构造转换表 557 | convert_init(); 558 | // 结果转换为字符串 559 | pthread_t threads[nthread]; 560 | int tid[nthread]; 561 | for(int i = 0; i < nthread; ++i){ 562 | tid[i] = i; 563 | pthread_create(&threads[i], NULL, itoa, (void*)&tid[i]); 564 | } 565 | 566 | for(int i = 0; i < nthread; ++i) 567 | pthread_join(threads[i], NULL); 568 | // 环数转换为字符串 569 | char* size_buf = new char[12]; 570 | char* ptr = size_buf; 571 | uint32_t size = accu[7]; 572 | convert(size, ptr); 573 | ptr += (*size_buf++ + 1); 574 | *ptr++ = '\n'; 575 | // mmap写 576 | int total_len = (ptr - size_buf); 577 | for(int i = 0; i < nthread; ++i) total_len += len[i]; 578 | int fd = open(result_file, O_RDWR | O_CREAT , 0666); 579 | lseek(fd, total_len - 1, SEEK_SET); 580 | int ret = write(fd, "\0", 1); 581 | char* ptr_ans = (char*)mmap(NULL, total_len, PROT_WRITE, MAP_SHARED, fd, 0); 582 | close(fd); 583 | // 写环数 584 | memcpy(ptr_ans, size_buf, ptr - size_buf); 585 | write_pos = ptr_ans + (ptr - size_buf); 586 | // 多线程写环 587 | for(int i = 0; i < nthread; 
++i){ 588 | pthread_create(&threads[i], NULL, save, (void*)&tid[i]); 589 | } 590 | for(int i = 0; i < nthread; ++i) 591 | pthread_join(threads[i], NULL); 592 | } 593 | 594 | void* sort_thread(void* arg){ 595 | // 对每个节点的邻接表排序:正向正序(保证结果为字典序),反向反序(用于反向dfs提前break退出) 596 | int tid = *(int*)arg; 597 | for(int i = tid; i < ID.size(); i += nthread){ 598 | sort(To + phead[i], To + phead[i] + plen[i], [](const Edge& a, const Edge& b){ 599 | return a.id < b.id; 600 | }); 601 | sort(From + qhead[i], From + qhead[i] + qlen[i], [](const Edge& a, const Edge& b){ 602 | return a.id > b.id; 603 | }); 604 | } 605 | } 606 | 607 | void* build_thread(void* arg){ 608 | // 构造前向星 609 | // 1.统计每个节点邻接点个数 610 | // 2.计算每个节点开始位置 611 | // 3.遍历所有节点构造前向星 612 | int tid = *(int*)arg; 613 | uint32_t from, to, money; 614 | if(tid == 0){ // 正图 615 | int* curlen = new int[ID.size() + 1](); 616 | for(int k = 0; k < nthread; ++k){ 617 | for(int i = 0; i < loadnum[k]; ++i){ 618 | if(loadbuf[k][i][1] != 0xffffffff){ 619 | from = loadbuf[k][i][0]; 620 | plen[from]++; 621 | } 622 | } 623 | } 624 | phead[0] = 0; 625 | for(int i = 1; i <= ID.size(); ++i){ 626 | phead[i] = phead[i - 1] + plen[i - 1]; 627 | } 628 | for(int k = 0; k < nthread; ++k){ 629 | for(int i = 0; i < loadnum[k]; ++i){ 630 | if(loadbuf[k][i][1] != 0xffffffff){ 631 | to = loadbuf[k][i][1]; 632 | from = loadbuf[k][i][0]; 633 | To[phead[from] + curlen[from]].id = to; 634 | To[phead[from] + curlen[from]++].money = loadbuf[k][i][2]; 635 | } 636 | 637 | } 638 | } 639 | 640 | }else{ // 反图 641 | int* curlen = new int[ID.size() + 1](); 642 | for(int k = 0; k < nthread; ++k){ 643 | for(int i = 0; i < loadnum[k]; ++i){ 644 | if(loadbuf[k][i][1] != 0xffffffff){ 645 | to = loadbuf[k][i][1]; 646 | qlen[to]++; 647 | } 648 | } 649 | } 650 | qhead[0] = 0; 651 | for(int i = 1; i <= ID.size(); ++i){ 652 | qhead[i] = qhead[i - 1] + qlen[i - 1]; 653 | } 654 | for(int k = 0; k < nthread; ++k){ 655 | for(int i = 0; i < loadnum[k]; ++i){ 656 | if(loadbuf[k][i][1] 
!= 0xffffffff){ 657 | to = loadbuf[k][i][1]; 658 | from = loadbuf[k][i][0]; 659 | From[qhead[to] + curlen[to]].id = from; 660 | From[qhead[to] + curlen[to]++].money = loadbuf[k][i][2]; 661 | } 662 | } 663 | } 664 | } 665 | } 666 | 667 | char* file; 668 | int file_size; 669 | void* load_thread(void* args){ 670 | // 多线程读图 671 | int tid = *(int*)args; 672 | int size = file_size / nthread; 673 | if(tid == nthread - 1) size = file_size - (nthread - 1) * size; 674 | 675 | char* start = file + tid * (file_size / nthread); 676 | char* curr = start; 677 | 678 | // 确保两个线程不读到分割位置的同一行 679 | if(tid != 0 && *(curr - 1) != '\n') while(*curr++ != '\n'); 680 | 681 | uint32_t from, to , money, temp = 0; 682 | uint8_t state = 0; 683 | char ch; 684 | while(1){ 685 | ch = *curr; 686 | if(ch == ','){ 687 | state ? to = temp : from = temp; 688 | state = 1 - state; 689 | temp = 0; 690 | } 691 | else if(ch == '\r' || ch == '\n'){ 692 | loadbuf[tid][loadnum[tid]][0] = from; 693 | loadbuf[tid][loadnum[tid]][1] = to; 694 | loadbuf[tid][loadnum[tid]][2] = temp; 695 | loadnum[tid]++; 696 | if(ch == '\r') curr++; 697 | if(curr - start + 1 >= size) break; 698 | state = temp = 0; 699 | } 700 | else{ 701 | temp = temp * 10 + ch - '0'; 702 | } 703 | curr++; 704 | } 705 | } 706 | 707 | void* clear_thread(void* args){ 708 | // 删除记录中出度为0的节点 709 | // 并将记录id映射为{ 0 1 2 ... 
n-1 } 710 | int tid = *(int*)args; 711 | uint32_t from, to; 712 | for(int i = 0; i < loadnum[tid]; ++i){ 713 | to = Map.search(loadbuf[tid][i][1]); 714 | if(to == -1){ 715 | loadbuf[tid][i][1] = 0xffffffff; // 标记这一行转账数据不要了 716 | }else{ 717 | from = Map.search(loadbuf[tid][i][0]); 718 | loadbuf[tid][i][0] = from; 719 | loadbuf[tid][i][1] = to; 720 | } 721 | } 722 | } 723 | 724 | void* shit_thread(void* args){ 725 | // 哈希排序,确保映射前后相对大小不变 726 | int tid = *(int*)args; 727 | uint32_t from , to; 728 | if(tid == 0){ 729 | Map.sort_hash(); 730 | }else{ 731 | sort(ID.begin(), ID.end()); 732 | } 733 | } 734 | 735 | void load_data(){ 736 | /* 737 | @brief: load data from file "/data/test_data.txt" 738 | @method: mmap 739 | */ 740 | #ifdef DEBUG 741 | cout << __func__ << endl; 742 | Timer t; 743 | #endif 744 | // mmap 745 | struct stat statue; 746 | int fd = open(test_data_file, O_RDONLY); 747 | fstat(fd, &statue); 748 | file_size = statue.st_size; 749 | file = (char*)mmap(NULL, statue.st_size, PROT_READ, MAP_PRIVATE, fd, 0); 750 | close(fd); 751 | // 1.多线程读数据到loadbuf中 752 | pthread_t threads[nthread]; 753 | int tid[nthread]; 754 | for(int i = 0; i < nthread; ++i){ 755 | tid[i] = i; 756 | pthread_create(&threads[i], NULL, load_thread, (void*)&tid[i]); 757 | } 758 | for(int i = 0; i < nthread; ++i) 759 | pthread_join(threads[i], NULL); 760 | 761 | // 2.哈希(仅映射 {u, v, w} 中的 u,因为不在 U 中出现的 id 肯定不成环) 762 | for(int k = 0; k < nthread; ++k){ 763 | for(int i = 0; i < loadnum[k]; ++i){ 764 | Map.insert(loadbuf[k][i][0]); 765 | } 766 | } 767 | // 3.哈希排序,确保映射前后相对大小不变 768 | for(int i = 0; i < 2; ++i){ 769 | tid[i] = i; 770 | pthread_create(&threads[i], NULL, shit_thread, (void*)&tid[i]); 771 | } 772 | for(int i = 0; i < 2; ++i) 773 | pthread_join(threads[i], NULL); 774 | 775 | // 4.将不用的记录删除(使得后续不用再查询哈系表) 776 | for(int i = 0; i < nthread; ++i){ 777 | tid[i] = i; 778 | pthread_create(&threads[i], NULL, clear_thread, (void*)&tid[i]); 779 | } 780 | for(int i = 0; i < nthread; ++i) 781 | 
#!/bin/bash
# Build a solution and check it against every bundled data set.
#
# Usage: ./test.sh <source.cpp> [DDEBUG]
#   <source.cpp>  solution to compile; the binary is written to ./test
#   DDEBUG        optional second argument; also compile with -DDEBUG
#
# For each data set the input is copied to /data/test_data.txt, ./test is run,
# and /projects/student/result.txt is compared with the expected answer.
# Differences are appended to log.txt.

test_file=(
    "test_data_43.txt"
    "test_data_9153.txt"
    "test_data_697518.txt"
    "test_data_19630345.txt"
)

answer_file=(
    "result_43.txt"
    "result_9153.txt"
    "result_697518.txt"
    "result_19630345.txt"
)

extra_flags=()
if [ "$2" = "DDEBUG" ]
then
    extra_flags+=(-DDEBUG)
fi

# Stop on a failed build instead of testing a stale ./test binary.
if ! g++ -O3 "$1" -o test -lpthread "${extra_flags[@]}"
then
    echo "Compilation failed" >&2
    exit 1
fi

# -p: create missing parents and do not complain about existing directories.
mkdir -p /data /projects/student

echo "------------------"

date > log.txt
cnt=0
total_cnt=0

for ((i = 0; i < ${#test_file[@]}; i++))
do
    echo "${test_file[$i]}"
    cp "${test_file[$i]}" /data/test_data.txt
    rm -f /projects/student/result.txt
    time ./test
    if [ -f "/projects/student/result.txt" ]
    then
        # Use diff's exit status directly: 0 means identical, anything else
        # (a difference, or an error such as a missing answer file) is a
        # failure. The previous line-count comparison on log.txt treated diff
        # errors as a pass and broke once an earlier run produced no file.
        if diff /projects/student/result.txt "${answer_file[$i]}" >> log.txt
        then
            echo "Pass!"
            cnt=$((cnt + 1))
        else
            echo "Fail!"
        fi
        rm /projects/student/result.txt
    else
        echo "File not generated !"
    fi
    total_cnt=$((total_cnt + 1))
    echo "------------------"
done

echo "$cnt/$total_cnt Passed" | tee -a ./log.txt
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "g++.exe build and debug active file", 9 | "type": "cppdbg", 10 | "request": "launch", 11 | "program": "${fileDirname}\\${fileBasenameNoExtension}.exe", 12 | "args": [], 13 | "stopAtEntry": false, 14 | "cwd": "${workspaceFolder}", 15 | "environment": [], 16 | "externalConsole": false, 17 | "MIMode": "gdb", 18 | "miDebuggerPath": "C:\\mingw64\\bin\\gdb.exe", 19 | "setupCommands": [ 20 | { 21 | "description": "为 gdb 启用整齐打印", 22 | "text": "-enable-pretty-printing", 23 | "ignoreFailures": true 24 | } 25 | ], 26 | "preLaunchTask": "g++.exe build active file" 27 | } 28 | ] 29 | } -------------------------------------------------------------------------------- /warmup/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.associations": { 3 | "unordered_map": "cpp", 4 | "thread": "cpp", 5 | "mutex": "cpp", 6 | "iostream": "cpp", 7 | "xutility": "cpp", 8 | "ostream": "cpp", 9 | "vector": "cpp", 10 | "algorithm": "cpp", 11 | "cctype": "cpp", 12 | "chrono": "cpp", 13 | "cmath": "cpp", 14 | "concepts": "cpp", 15 | "cstddef": "cpp", 16 | "cstdint": "cpp", 17 | "cstdio": "cpp", 18 | "cstdlib": "cpp", 19 | "cstring": "cpp", 20 | "ctime": "cpp", 21 | "cwchar": "cpp", 22 | "exception": "cpp", 23 | "fstream": "cpp", 24 | "initializer_list": "cpp", 25 | "iomanip": "cpp", 26 | "ios": "cpp", 27 | "iosfwd": "cpp", 28 | "istream": "cpp", 29 | "limits": "cpp", 30 | "list": "cpp", 31 | "memory": "cpp", 32 | "new": "cpp", 33 | "ratio": "cpp", 34 | "sstream": "cpp", 35 | "stdexcept": "cpp", 36 | "streambuf": "cpp", 37 | "string": "cpp", 38 | "system_error": "cpp", 39 | "tuple": "cpp", 40 | "type_traits": "cpp", 41 | "typeinfo": "cpp", 42 | "utility": "cpp", 43 | "xfacet": "cpp", 44 | "xhash": "cpp", 45 | "xiosbase": "cpp", 46 | "xlocale": "cpp", 47 | "xlocinfo": "cpp", 48 | 
"xlocmon": "cpp", 49 | "xlocnum": "cpp", 50 | "xloctime": "cpp", 51 | "xmemory": "cpp", 52 | "xstddef": "cpp", 53 | "xstring": "cpp", 54 | "xtr1common": "cpp", 55 | "random": "cpp" 56 | } 57 | } -------------------------------------------------------------------------------- /warmup/.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "tasks": [ 3 | { 4 | "type": "shell", 5 | "label": "g++.exe build active file", 6 | "command": "C:\\mingw64\\bin\\g++.exe", 7 | "args": [ 8 | "-g", 9 | "${file}", 10 | "-o", 11 | "${fileDirname}\\${fileBasenameNoExtension}.exe" 12 | ], 13 | "options": { 14 | "cwd": "C:\\mingw64\\bin" 15 | } 16 | } 17 | ], 18 | "version": "2.0.0" 19 | } -------------------------------------------------------------------------------- /warmup/hello.cpp: --------------------------------------------------------------------------------
/*
    Nearest-class-mean binary classifier, parallelised over `cpus` forked workers.

    Training:   every worker parses max_datasize / cpus rows of the training file and
                sums the first COLS features per label; the last worker to finish turns
                the shared sums into per-class means.
    Prediction: every worker classifies its slice of the test file by comparing the
                squared distance of a sample to both means.
    Output:     workers write their results to the result file in ascending id order.
*/
#include <arm_neon.h>
#include <fcntl.h>
#include <pthread.h>
#include <signal.h>
#include <sys/ipc.h>
#include <sys/mman.h>
#include <sys/shm.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

#include <chrono>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>


using namespace std;


// dataset capacity
// COLS is consumed four floats at a time by the NEON loops below.
const int COLS = 176;
const int ROWS = 5000;
const int max_datasize = 5000;
static_assert(COLS % 4 == 0, "the NEON loops process four floats per step");

// File geometry the parsers rely on.
const int VALUES_PER_ROW = 1000;     // values per row; only the first COLS are used
const int VALUE_BYTES = 6;           // width of one non-negative value incl. separator
const int NEG_VALUE_BYTES = 7;       // same, with a leading '-'
const int TRAIN_ROW_STRIDE = 6500;   // per-row byte estimate used to spread workers over the train file
const int TEST_ROWS = 20000;         // rows of the test file that get classified
const int TEST_ROW_BYTES = VALUES_PER_ROW * VALUE_BYTES;

float gain = 1.011 ;
int testBytes = 0;              // total size of the test file in bytes

#define cpus 4

// State shared by all worker processes (lives in a SysV shared memory segment).
struct Share{
    int cnt;                    // training: workers finished so far; afterwards: id of the worker allowed to write
    int num0;                   // training rows labelled 0
    int num1;                   // training rows labelled 1
    float mean0[COLS];          // per-feature sum, then mean, of class 0
    float mean1[COLS];          // per-feature sum, then mean, of class 1
    int ready;                  // becomes 1 once mean0/mean1 hold the final means
    pthread_mutex_t lock;       // process-shared mutex; must live inside the shared segment to work across fork()
};
Share* share;

const char* trainFile = "/data/train_data.txt";
const char* testFile = "/data/test_data.txt";
const char* resultFile = "/projects/student/result.txt";

char* train_buf, *test_buf;

std::chrono::system_clock::time_point start;

// ---------------------------------------------------------------------------
// Small helpers
// ---------------------------------------------------------------------------

// Prints "<subject>: <strerror(errno)>" and terminates the calling process.
static void fail_errno(const char* subject){
    perror(subject);
    exit(1);
}

// Prints "<subject>: <problem>" and terminates the calling process.
static void fail(const char* subject, const char* problem){
    fprintf(stderr, "%s: %s\n", subject, problem);
    exit(1);
}

// Maps a digit character to a fixed float contribution; every other byte maps to 0.
// Indexing by unsigned char keeps any input byte inside the table.
struct DigitTable{
    float value[256];
    explicit DigitTable(const float (&digits)[10]) : value() {
        for(int d = 0; d < 10; ++d) value['0' + d] = digits[d];
    }
    float operator[](char c) const { return value[static_cast<unsigned char>(c)]; }
};
static const float TENTH_VALUES[10] =
    {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f};
static const float HUNDREDTH_VALUES[10] =
    {0.0f, 0.01f, 0.02f, 0.03f, 0.04f, 0.05f, 0.06f, 0.07f, 0.08f, 0.09f};
static const DigitTable TENTHS(TENTH_VALUES);
static const DigitTable HUNDREDTHS(HUNDREDTH_VALUES);

// Reads an unsigned value from p[0] (integer digit), p[2] (tenths) and p[3] (hundredths).
// p[1] and any further digits are not looked at.
static inline float parse_fixed(const char* p){
    return (p[0] - '0') + TENTHS[p[2]] + HUNDREDTHS[p[3]];
}

// Maps `path` read-only and returns the mapping; its length is stored in *bytes.
// Terminates the process when the file cannot be opened or mapped.
static char* map_readonly(const char* path, size_t* bytes){
    const int fd = open(path, O_RDONLY);
    if(fd < 0) fail_errno(path);
    struct stat info;
    if(fstat(fd, &info) != 0) fail_errno(path);
    if(info.st_size <= 0) fail(path, "file is empty");
    void* mem = mmap(nullptr, info.st_size, PROT_READ, MAP_SHARED, fd, 0);
    if(mem == MAP_FAILED) fail_errno(path);
    close(fd);                  // the mapping stays valid after the descriptor is closed
    *bytes = static_cast<size_t>(info.st_size);
    return static_cast<char*>(mem);
}

// Multiplies the COLS floats at `values` by `factor`.
static void scale(float* values, float factor){
    const float32x4_t factor_vec = vdupq_n_f32(factor);
    for(int i = 0; i < COLS; i += 4){
        vst1q_f32(values + i, vmulq_f32(vld1q_f32(values + i), factor_vec));
    }
}

// Debug timestamps, compiled in only with -DDEBUG.
// NOTE(review): the time unit of the original trace output is unknown; microseconds are used here.
static void trace(const char* what){
#ifdef DEBUG
    cout << what << chrono::duration_cast<chrono::microseconds>(chrono::system_clock::now() - start).count() << endl;
#else
    (void)what;
#endif
}

static void trace(int id, const char* what){
#ifdef DEBUG
    cout << "process " << id << what << chrono::duration_cast<chrono::microseconds>(chrono::system_clock::now() - start).count() << endl;
#else
    (void)id;
    (void)what;
#endif
}

// Busy-waits for roughly `us` microseconds.
void delay(int us){
    const chrono::steady_clock::time_point deadline =
        chrono::steady_clock::now() + chrono::microseconds(us);
    while(chrono::steady_clock::now() <= deadline){}
}

// Spins until the shared integer at `field` equals `expected`.
// The value is read under the shared lock so every poll observes the latest write.
static void wait_until(const int* field, int expected){
    for(;;){
        pthread_mutex_lock(&share->lock);
        const int value = *field;
        pthread_mutex_unlock(&share->lock);
        if(value == expected) return;
        delay(1);
    }
}

// ---------------------------------------------------------------------------
// Training
// ---------------------------------------------------------------------------

// sum[i] += data[i] for i in [0, COLS).
void Merge_sum(float* sum, float* data){
    for(int i = 0; i < COLS; i += 4){
        vst1q_f32(sum + i, vaddq_f32(vld1q_f32(sum + i), vld1q_f32(data + i)));
    }
}

// Worker `id` parses max_datasize / cpus training rows and folds its per-class
// feature sums into the shared block. The last worker to arrive converts the sums
// into means, resets share->cnt to 0 and raises share->ready.
void loadTrainData(int id){
    size_t bytes = 0;
    train_buf = map_readonly(trainFile, &bytes);
    const char* const end = train_buf + bytes;

    const size_t offset = static_cast<size_t>(id * max_datasize / cpus) * TRAIN_ROW_STRIDE;
    if(offset >= bytes) fail(trainFile, "file is too small for this worker's slice");

    // Start at the first full row after the offset. This also applies to id 0,
    // so the very first row of the file is never used.
    const char* curr = train_buf + offset;
    while(curr < end && *curr++ != '\n'){}

    float sum0[COLS] = {0.0f};
    float sum1[COLS] = {0.0f};
    float data[COLS];
    int num0_cnt = 0, num1_cnt = 0;

    const ptrdiff_t skip = (VALUES_PER_ROW - COLS) * VALUE_BYTES;
    for(int row = 0; row < max_datasize / cpus; ++row){
        if(end - curr < COLS * NEG_VALUE_BYTES) fail(trainFile, "truncated row");

        for(int col = 0; col < COLS; ++col){
            const bool negative = (*curr == '-');
            const float magnitude = parse_fixed(negative ? curr + 1 : curr);
            data[col] = negative ? -magnitude : magnitude;
            // The pointer is left on the last parsed value; the jump below starts from there.
            if(col + 1 < COLS) curr += negative ? NEG_VALUE_BYTES : VALUE_BYTES;
        }

        // Jump over most of the unused columns, then scan to the end of the row.
        if(end - curr <= skip) fail(trainFile, "truncated row");
        curr += skip;
        while(curr < end && *curr != '\n') ++curr;
        if(curr == end) fail(trainFile, "last row is not newline-terminated");

        // The label is the character right before the newline.
        if(*(curr - 1) == '1'){
            num1_cnt ++;
            Merge_sum(sum1, data);
        }else{
            num0_cnt ++;
            Merge_sum(sum0, data);
        }
        ++curr;
    }

    pthread_mutex_lock(&share->lock);
    share->num0 += num0_cnt;
    share->num1 += num1_cnt;
    Merge_sum(share->mean0, sum0);
    Merge_sum(share->mean1, sum1);
    if(++share->cnt == cpus){   // the last worker computes the means
        scale(share->mean0, 1.0f / share->num0);
        scale(share->mean1, 1.0f / share->num1);
        share->cnt = 0;         // from here on cnt is the write-turn counter
        share->ready = 1;
    }
    pthread_mutex_unlock(&share->lock);
}

// ---------------------------------------------------------------------------
// Prediction
// ---------------------------------------------------------------------------

// Returns 1 when `data` (COLS floats) is closer to the class-1 mean than to the
// class-0 mean, with the class-0 squared distance weighted by `gain`; otherwise 0.
int predict_one(float* data){
    float32x4_t err0_vec = vdupq_n_f32(0.0f);
    float32x4_t err1_vec = vdupq_n_f32(0.0f);
    for(int i = 0; i < COLS; i += 4){
        const float32x4_t sample = vld1q_f32(data + i);
        const float32x4_t diff0 = vsubq_f32(sample, vld1q_f32(share->mean0 + i));
        const float32x4_t diff1 = vsubq_f32(sample, vld1q_f32(share->mean1 + i));
        err0_vec = vmlaq_f32(err0_vec, diff0, diff0);
        err1_vec = vmlaq_f32(err1_vec, diff1, diff1);
    }
    // Horizontal add of the four lanes.
    const float32x2_t r0 = vadd_f32(vget_high_f32(err0_vec), vget_low_f32(err0_vec));
    const float err0 = vget_lane_f32(vpadd_f32(r0, r0), 0);
    const float32x2_t r1 = vadd_f32(vget_high_f32(err1_vec), vget_low_f32(err1_vec));
    const float err1 = vget_lane_f32(vpadd_f32(r1, r1), 0);
    return err0 * gain > err1;
}

char* write_buf;
int test_size;

// Worker `id` classifies rows [id * TEST_ROWS / cpus, (id + 1) * TEST_ROWS / cpus) of
// the test file. Results go to write_buf as one '0'/'1' character per row, each
// followed by '\n'; test_size receives the number of rows.
// Rows are addressed as fixed TEST_ROW_BYTES-wide records.
void loadTestDataAndPredict(int id){
    const int size = TEST_ROWS / cpus;

    size_t bytes = 0;
    test_buf = map_readonly(testFile, &bytes);

    const size_t offset = static_cast<size_t>(id) * size * TEST_ROW_BYTES;
    // Last byte touched: hundredths digit of the last used column of the last row.
    const size_t needed = offset + static_cast<size_t>(size - 1) * TEST_ROW_BYTES
                        + static_cast<size_t>(COLS - 1) * VALUE_BYTES + 4;
    if(bytes < needed) fail(testFile, "file is too small for this worker's slice");

    // Bias the decision towards the majority class of the training data.
    if(share->num0 > share->num1) gain = 1.0 / gain;

    write_buf = static_cast<char*>(malloc(2 * size));
    if(write_buf == nullptr) fail("malloc", "out of memory");
    memset(write_buf, '\n', 2 * size);

    float data[COLS];
    const char* const base = test_buf + offset;
    for(int i = 0; i < size; ++i){
        const char* const row = base + static_cast<size_t>(i) * TEST_ROW_BYTES;
        for(int j = 0; j < COLS; j ++){
            data[j] = parse_fixed(row + j * VALUE_BYTES);
        }
        write_buf[2 * i] = predict_one(data) ? '1' : '0';
    }
    test_size = size;
}

// Compares the result file with "answer.txt" and prints the accuracy followed by
// the number of mismatches where the answer is 1 and where it is not.
void accuracy(){
    FILE* fp1 = fopen("answer.txt", "r");
    FILE* fp2 = fopen(resultFile, "r");
    if(fp1 == nullptr || fp2 == nullptr){
        cout << "open failed." << endl;
        if(fp1 != nullptr) fclose(fp1);
        if(fp2 != nullptr) fclose(fp2);
        return;
    }
    int a, b, cnt = 0, total = 0, cntp = 0, cntn = 0;
    // Stop on end of file and on unparsable input alike.
    while(fscanf(fp1, "%d", &a) == 1 && fscanf(fp2, "%d", &b) == 1){
        if(a == b) cnt++;
        else if(a == 1) cntp++;
        else cntn++;
        total++;
    }
    cout << setprecision(4);
    cout << "Acc: " << cnt * 1.0f / total << " " << cntp << " " << cntn << endl;
    fclose(fp1);
    fclose(fp2);
}

// ---------------------------------------------------------------------------
// Process orchestration
// ---------------------------------------------------------------------------

// Body of one worker process: train, wait for the means, predict, then write the
// results when it is this worker's turn. Never returns.
static void run_worker(int id){
    trace(id, "train start: ");
    loadTrainData(id);
    trace(id, "train finish: ");

    wait_until(&share->ready, 1);

    trace(id, "predict start: ");
    loadTestDataAndPredict(id);
    trace(id, "predict finish: ");

    // Workers write in ascending id order: worker 0 creates the file, the others append.
    wait_until(&share->cnt, id);

    trace(id, "write start: ");
    FILE* fp = fopen(resultFile, id == 0 ? "w" : "a");
    if(fp == nullptr) fail_errno(resultFile);
    const size_t out_bytes = 2 * static_cast<size_t>(test_size);
    if(fwrite(write_buf, 1, out_bytes, fp) != out_bytes) fail_errno(resultFile);
    free(write_buf);
    if(fclose(fp) != 0) fail_errno(resultFile);
    trace(id, "write finish: ");

    pthread_mutex_lock(&share->lock);
    share->cnt ++;              // hand the turn to the next worker
    pthread_mutex_unlock(&share->lock);

    exit(0);
}

// Sends SIGKILL to every worker that has not been reaped yet.
static void kill_workers(const pid_t* pids, int count){
    for(int i = 0; i < count; ++i){
        if(pids[i] > 0) kill(pids[i], SIGKILL);
    }
}

int main(){
    start = chrono::system_clock::now();

    // 1. Create the shared block. A private segment cannot collide with a stale one,
    //    and marking it for removal right away means it disappears once the last
    //    process detaches, even after a crash.
    const int shmid = shmget(IPC_PRIVATE, sizeof(Share), 0600 | IPC_CREAT);
    if(shmid < 0) fail_errno("shmget");
    void* segment = shmat(shmid, nullptr, 0);
    if(segment == reinterpret_cast<void*>(-1)){
        perror("shmat");
        shmctl(shmid, IPC_RMID, nullptr);
        return 1;
    }
    shmctl(shmid, IPC_RMID, nullptr);
    share = static_cast<Share*>(segment);
    memset(share, 0, sizeof(Share));

    // 2. Create the process-shared mutex inside the shared block.
    pthread_mutexattr_t mutexattr;
    if(pthread_mutexattr_init(&mutexattr) != 0
       || pthread_mutexattr_setpshared(&mutexattr, PTHREAD_PROCESS_SHARED) != 0
       || pthread_mutex_init(&share->lock, &mutexattr) != 0){
        fail("pthread_mutex_init", "cannot create a process-shared mutex");
    }
    pthread_mutexattr_destroy(&mutexattr);

    // 3. Start the workers.
    trace("fork start: ");
    pid_t pids[cpus];
    int forked = 0;
    bool failed = false;
    for(int id = 0; id < cpus; ++id){
        const pid_t pid = fork();
        if(pid == 0) run_worker(id);    // child process; does not return
        if(pid < 0){
            perror("fork");
            failed = true;
            break;
        }
        pids[forked++] = pid;
    }
    trace("fork finish: ");
    if(failed) kill_workers(pids, forked);

    // 4. Reap the workers. If one fails the others would wait for it forever,
    //    so the remaining ones are terminated.
    for(int reaped = 0; reaped < forked; ++reaped){
        int status = 0;
        const pid_t done = wait(&status);
        if(done < 0) break;
        for(int i = 0; i < forked; ++i){
            if(pids[i] == done) pids[i] = -1;
        }
        const bool ok = WIFEXITED(status) && WEXITSTATUS(status) == 0;
        if(!ok && !failed){
            failed = true;
            kill_workers(pids, forked);
        }
    }
    // accuracy();

    // 5. Release the shared block.
    pthread_mutex_destroy(&share->lock);
    shmdt(share);
    return failed ? 1 : 0;
}
--------------------------------------------------------------------------------