├── methods ├── __init__.py ├── SA.py ├── Hopfield.py └── GA.py ├── datas ├── __init__.py ├── ChinaCitys.txt └── citydata.py ├── tools ├── __init__.py └── plot.py ├── README.md └── TSP.py /methods/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datas/__init__.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 _*_ 2 | __author__ = 'CHEN Shen' -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 _*_ 2 | __author__ = 'CHEN Shen' -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TSP-Algorithms 2 | 3 | 利用模拟退火算法、遗传算法等解决旅行商问题。 4 | 5 | 因为用于解决TSP问题的算法有很多种,遂将其整合成一个代码框架。 6 | 7 | ## Directory structure 8 | 9 | ``` 10 | ├─ datas 数据集 11 | │ ├─ ChinaCitys.txt 中国34个城市的经纬度数据 12 | │ ├─ __init__.py 13 | │ └─ citydata.py 包含多个城市数据集 14 | ├─ methods 15 | │ ├─ GA.py 遗传算法 16 | │ ├─ SA.py 模拟退火算法 17 | │ ├─ Hopfield.py Hopfield神经网络 18 | │ └─ __init__.py 19 | ├─ tools 工具 20 | │ ├─ __init__.py 21 | │ └─ plot.py 作图工具 22 | ├─ README.md 23 | └─ TSP.py 旅行商问题 24 | ``` 25 | 26 | ## Usage 27 | 28 | ``` 29 | python TSP.py 30 | ``` 31 | -------------------------------------------------------------------------------- /datas/ChinaCitys.txt: -------------------------------------------------------------------------------- 1 | 北京 116.46 39.92 2 | 天津 117.2 39.13 3 | 上海 121.48 31.22 4 | 重庆 106.54 29.59 5 | 拉萨 91.11 29.97 6 | 乌鲁木齐 87.68 43.77 7 | 银川 106.27 38.47 8 | 呼和浩特 111.65 40.82 9 | 南宁 108.33 22.84 10 | 哈尔滨 126.63 45.75 11 | 长春 125.35 43.88 12 | 沈阳 123.38 41.8 13 | 石家庄 114.48 38.03 14 | 太原 112.53 37.87 15 | 西宁 101.74 36.56 16 | 济南 117 36.65 17 | 郑州 113.6 34.76 18 | 南京 118.78 32.04 19 | 合肥 117.27 31.86 20 | 杭州 120.19 30.26 21 | 福州 119.3 26.08 22 | 南昌 115.89 28.68 23 | 长沙 113 28.21 24 | 武汉 114.31 30.52 25 | 广州 113.23 23.16 26 | 台北 121.5 25.05 27 | 海口 110.35 20.02 28 | 兰州 103.73 36.03 29 | 西安 108.95 34.27 30 | 成都 104.06 30.67 31 | 贵阳 106.71 26.57 32 | 昆明 102.73 25.04 33 | 香港 114.1 22.2 34 | 澳门 113.33 22.13 -------------------------------------------------------------------------------- /tools/plot.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 _*_ 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def figure(): 8 | """ 9 | 打开绘图窗口 10 | :return: 11 | """ 12 | plt.figure() 13 | 14 | 15 | def plot_city(citys, path): 16 | """ 17 | 绘制将城市坐标以及城市之间的连线 18 | :param citys: 城市点数据 19 | :param path: 城市路径数据 20 | :return: 21 | """ 22 | paths = np.append(path, path[0]) 23 | plt.clf() 24 | plt.scatter(citys[:, 0], citys[:, 1], color='g') 25 | plt.plot(citys[paths, 0], citys[paths, 1], 'r') 26 | # plt.pause(0.001) 27 | plt.show() 28 | 29 | 30 | def plot_iter_curve(scores, title=''): 31 | """ 32 | 绘制每次迭代过程中算法表现的曲线 33 | :param scores: 每次迭代的算法表现 34 | :return: 35 | """ 36 | plt.plot(scores) 37 | plt.title(title) 38 | plt.xlabel('iter t') 39 | plt.ylabel('score d') 40 | plt.show() 41 | 42 | 43 | def draw_H_and_E(citys, H_path, energys): 44 | """ 45 | 可视化画出哈密顿回路和能量趋势 46 | :param citys: 47 | :param H_path: 48 | :param energys: 49 | :return: 50 | """ 51 | fig = plt.figure() 52 | # 绘制哈密顿回路 53 | ax1 = fig.add_subplot(121) 54 | plt.xlim(0, 7) 55 | plt.ylim(0, 7) 56 | for (from_, to_) in H_path: 57 | p1 = plt.Circle(citys[from_], 0.2, color='red') 58 | p2 = plt.Circle(citys[to_], 0.2, color='red') 59 | ax1.add_patch(p1) 60 | ax1.add_patch(p2) 61 | ax1.plot((citys[from_][0], citys[to_][0]), (citys[from_][1], citys[to_][1]), color='red') 62 | # ax1.annotate(to_, xy=citys[to_], xytext=(-8, -4), textcoords='offset points', fontsize=8) 63 | ax1.axis('equal') 64 | ax1.grid() 65 | # 绘制能量趋势图 66 | ax2 = fig.add_subplot(122) 67 | ax2.plot(np.arange(0, len(energys), 10), energys[::10], color='red') 68 | plt.show() -------------------------------------------------------------------------------- /datas/citydata.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 _*_ 2 | 3 | import numpy as np 4 | from numpy import random 5 | 6 | 7 | def city_random_generator(nums): 8 | """ 9 | 随机生成nums个城市数据 10 | :param nums: 城市个数 11 | :return: nums个城市数据 12 | """ 13 | citys = [] 14 | n = 0 15 | while n < nums: 16 | x = random.randn() 17 | y = random.randn() 18 | if [x, y] not in citys: 19 | citys.append([x, y]) 20 | n = n + 1 21 | return np.array(citys) 22 | 23 | 24 | def city10(): 25 | """ 26 | 10个城市数据,坐标数据都在[0, 1]范围内 27 | :return: 10个城市数据 28 | """ 29 | citys = np.array([[0.4000, 0.4439], [0.2439, 0.1463], [0.1707, 0.2293], [0.2293, 0.7610], [0.5171, 0.9414], 30 | [0.8732, 0.6536], [0.6878, 0.5219], [0.8488, 0.3609], [0.6683, 0.2536], [0.6195, 0.2634]]) 31 | return citys 32 | 33 | 34 | def city52(): 35 | """ 36 | 来源于网上的52个城市数据 37 | :return: 52个城市数据 38 | """ 39 | citys = np.array([[565.0, 575.0], [25.0, 185.0], [345.0, 750.0], [945.0, 685.0], [845.0, 655.0], 40 | [880.0, 660.0], [25.0, 230.0], [525.0, 1000.0], [580.0, 1175.0], [650.0, 1130.0], 41 | [1605.0, 620.0], [1220.0, 580.0], [1465.0, 200.0], [1530.0, 5.0], [845.0, 680.0], 42 | [725.0, 370.0], [145.0, 665.0], [415.0, 635.0], [510.0, 875.0], [560.0, 365.0], 43 | [300.0, 465.0], [520.0, 585.0], [480.0, 415.0], [835.0, 625.0], [975.0, 580.0], 44 | [1215.0, 245.0], [1320.0, 315.0], [1250.0, 400.0], [660.0, 180.0], [410.0, 250.0], 45 | [420.0, 555.0], [575.0, 665.0], [1150.0, 1160.0], [700.0, 580.0], [685.0, 595.0], 46 | [685.0, 610.0], [770.0, 610.0], [795.0, 645.0], [720.0, 635.0], [760.0, 650.0], 47 | [475.0, 960.0], [95.0, 260.0], [875.0, 920.0], [700.0, 500.0], [555.0, 815.0], 48 | [830.0, 485.0], [1170.0, 65.0], [830.0, 610.0], [605.0, 625.0], [595.0, 360.0], 49 | [1340.0, 725.0], [1740.0, 245.0]]) 50 | return citys 51 | 52 | 53 | def china_citys(): 54 | """ 55 | 中国34个城市的经纬度数据 56 | :return: 34个城市数据 57 | """ 58 | citys = [] 59 | f = open("./datas/ChinaCitys.txt", "r", encoding='UTF-8') 60 | while True: 61 | city = str(f.readline()) 62 | if not city: break 63 | city = city.replace("\n", "") 64 | city = city.split("\t") 65 | citys.append([float(city[1]), float(city[2])]) 66 | return np.array(citys) -------------------------------------------------------------------------------- /methods/SA.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 _*_ 2 | 3 | import numpy as np 4 | 5 | 6 | class SA(object): 7 | """ 8 | 模拟退火算法 9 | """ 10 | 11 | def __init__(self, n, T=280, L_len=100, alpha=0.92, energy_fun=lambda S:1): 12 | self.n = n 13 | self.T = T # 初始温度 14 | self.L = L_len * self.n # 每个温度下的迭代次数 15 | self.alpha = alpha # 温度下降缩减因子 16 | self.S = np.arange(self.n) # 初始状态 17 | self.energy = energy_fun 18 | 19 | def neighbors(self, S): 20 | """ 21 | 采用两点交换法从状态S的邻域中随机选择 22 | :param S: 当前状态 23 | :return: 随机选择的状态 24 | """ 25 | S_neibor = np.copy(S) 26 | u = np.random.randint(0, self.n) 27 | v = np.random.randint(0, self.n) 28 | while u == v: 29 | v = np.random.randint(0, self.n) 30 | S_neibor[u], S_neibor[v] = S_neibor[v], S_neibor[u] 31 | return S_neibor 32 | 33 | def neighbors2(self, S): 34 | """ 35 | 采用2交换法从状态S的邻域中随机选择 36 | :param S: 当前状态 37 | :return: 随机选择的状态 38 | """ 39 | S_neibor = np.copy(S) 40 | u = np.random.randint(0, self.n) 41 | v = np.random.randint(0, self.n) 42 | if u > v: 43 | u, v = v, u 44 | while u == v: 45 | v = np.random.randint(0, self.n) 46 | temp = S_neibor[u:v] 47 | S_neibor[u:v] = temp[::-1] 48 | return S_neibor 49 | 50 | def anneal(self): 51 | """ 52 | 一步退火过程 53 | :return: 54 | """ 55 | print('search on T:{}'.format(self.T)) 56 | for i in range(self.L): 57 | E_pre = self.energy(self.S) 58 | S_now = self.neighbors2(self.S) 59 | E_now = self.energy(S_now) 60 | if (E_now < E_pre) or (np.exp((E_pre - E_now) / self.T) >= np.random.rand()): 61 | self.S = S_now 62 | 63 | def search(self): 64 | """ 65 | 模拟退火搜索过程 66 | :return: 搜索结束后的解 67 | """ 68 | Ts = [] 69 | Es = [] 70 | while self.T >= 0.1: 71 | print('search on T:{}'.format(self.T)) 72 | for i in range(self.L): 73 | E_pre = self.energy(self.S) 74 | S_now = self.neighbors2(self.S) 75 | E_now = self.energy(S_now) 76 | if (E_now < E_pre) or (np.exp((E_pre - E_now) / self.T) >= np.random.rand()): 77 | self.S = S_now 78 | 79 | Ts.append(self.T) 80 | E_now = self.energy(self.S) 81 | Es.append(E_now) 82 | print(E_now) 83 | 84 | # 判断是否达到终止状态 85 | self.T = self.T * self.alpha 86 | print(self.S) 87 | print('finished\n') 88 | 89 | return Ts, Es 90 | 91 | -------------------------------------------------------------------------------- /TSP.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 _*_ 2 | 3 | import numpy as np 4 | from datas.citydata import china_citys 5 | from tools.plot import figure, plot_city, plot_iter_curve 6 | from methods.GA import GA 7 | from methods.SA import SA 8 | from methods.Hopfield import Hopfield 9 | 10 | 11 | class TSP(object): 12 | """ 13 | 旅行商问题 14 | """ 15 | 16 | def __init__(self, plot=False): 17 | self.citys = china_citys() 18 | self.n = len(self.citys) 19 | self.dists = self.init_dist() 20 | self.best_path = None 21 | self.plot = plot 22 | 23 | def init_dist(self): 24 | """ 25 | 根据城市数据计算每个城市间的距离 26 | :return: 城市的距离矩阵 27 | """ 28 | dists = np.zeros((self.n, self.n)) 29 | for i in range(self.n): 30 | for j in range(i, self.n): 31 | dists[i][j] = dists[j][i] = np.linalg.norm(self.citys[i] - self.citys[j]) 32 | return dists 33 | 34 | def calc_dist(self, path): 35 | """ 36 | 计算路径path下的城市总距离 37 | :param path: 完成TSP问题的一个城市路径 38 | :return: 城市总距离 39 | """ 40 | dist_sum = 0.0 41 | for i in range(self.n - 1): 42 | dist_sum = dist_sum + self.dists[path[i]][path[i + 1]] 43 | dist_sum = dist_sum + self.dists[path[self.n - 1]][path[0]] 44 | return dist_sum 45 | 46 | def search(self, method='SA'): 47 | """ 48 | 利用method解决TSP问题 49 | :param method: 采用的算法 50 | :return: 51 | """ 52 | 53 | if self.plot: 54 | figure() 55 | 56 | scores = [] 57 | 58 | # 模拟退火算法 59 | if method == 'SA': 60 | sa = SA(self.n, energy_fun=self.calc_dist) 61 | # sa.search() 62 | 63 | while sa.T >= 0.1: 64 | sa.anneal() 65 | 66 | score = sa.energy(sa.S) 67 | scores.append(score) 68 | print('search on T:{}'.format(sa.T)) 69 | print(score) 70 | 71 | if self.plot: 72 | # plot_city(self.citys, sa.S) 73 | pass 74 | 75 | # 温度衰减 76 | sa.T = sa.T * sa.alpha 77 | 78 | self.best_path = sa.S 79 | 80 | # 遗传算法 81 | elif method == 'GA': 82 | ga = GA(self.n, 100, 0.8, 0.05, lambda gene: 1.0 / self.calc_dist(gene)) 83 | # ga.evolution() 84 | 85 | while ga.generation < 3000: 86 | ga.generate_next() 87 | 88 | score = 1.0 / ga.best.score 89 | scores.append(score) 90 | print("generation: {}".format(ga.generation)) 91 | print(score) 92 | 93 | if (ga.generation % 30 == 0) and self.plot: 94 | # plot_city(self.citys, ga.best.gene) 95 | pass 96 | 97 | self.best_path = ga.best.gene 98 | for p in ga.population: 99 | print(1/p.score) 100 | 101 | # Hopfield神经网络 102 | elif method == 'Hopfield': 103 | network = Hopfield(self.n, 0.0009, 0.0001, self.calc_dist) 104 | self.best_path, scores = network.train(50000, self.dists) 105 | 106 | print("best path: {}".format(self.best_path)) 107 | if self.plot: 108 | plot_iter_curve(scores, 'China Citys') 109 | plot_city(self.citys, self.best_path) 110 | 111 | 112 | if __name__ == '__main__': 113 | tsp = TSP(True) 114 | tsp.search('Hopfield') 115 | -------------------------------------------------------------------------------- /methods/Hopfield.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Hopfield(object): 5 | """ 6 | 连续型——Hopfield神经网络求解TSP 7 | 1、初始化权值(A,D,U0) 8 | 2、计算N个城市的距离矩阵dxy 9 | 3、初始化神经网络的输入电压Uxi和输出电压Vxi 10 | 4、利用动力微分方程计算:dUxi/dt 11 | 5、由一阶欧拉方法更新计算:Uxi(t+1) = Uxi(t) + dUxi/dt * step 12 | 6、由非线性函数sigmoid更新计算:Vxi(t) = 0.5 * (1 + th(Uxi/U0)) 13 | 7、计算能量函数E 14 | 8、检查路径是否合法 15 | """ 16 | 17 | def __init__(self, n, u0, step, func=lambda x:1): 18 | self.N = n 19 | self.U0 = u0 20 | self.step = step 21 | self.dist_func = func 22 | self.A = n * n 23 | self.D = n / 2 24 | 25 | def calc_du(self, V, distance): 26 | """ 27 | 动态方程计算微分方程du 28 | :param distance: 距离矩阵 29 | :return: 导数du 30 | """ 31 | N = self.N 32 | a = np.sum(V, axis=0) - 1 33 | b = np.sum(V, axis=1) - 1 34 | t1 = np.zeros((N, N)) 35 | t2 = np.zeros((N, N)) 36 | for i in range(N): 37 | for j in range(N): 38 | t1[i, j] = a[j] 39 | for i in range(N): 40 | for j in range(N): 41 | t2[j, i] = b[j] 42 | c_1 = V[:, 1:N] 43 | c_0 = np.zeros((N, 1)) 44 | c_0[:, 0] = V[:, 0] 45 | c = np.concatenate((c_1, c_0), axis=1) 46 | c = np.dot(distance, c) 47 | return -self.A * (t1 + t2) - self.D * c 48 | 49 | def calc_U(self, U, du, step): 50 | """ 51 | 更新神经网络的输入电压U 52 | :param du: 电压导数 53 | :param step: 步长 54 | :return: 输入电压 55 | """ 56 | return U + du * step 57 | 58 | def calc_V(self, U, U0): 59 | """ 60 | # 更新神经网络的输出电压V 61 | :param U0: 初始电压 62 | :return: 输出电压 63 | """ 64 | return 1 / 2 * (1 + np.tanh(U / U0)) 65 | 66 | def calc_energy(self, V, distance): 67 | """ 68 | 计算当前网络的能量 69 | :param distance: 距离矩阵 70 | :return: 能量 71 | """ 72 | t1 = np.sum(np.power(np.sum(V, axis=0) - 1, 2)) 73 | t2 = np.sum(np.power(np.sum(V, axis=1) - 1, 2)) 74 | idx = [i for i in range(1, self.N)] 75 | idx = idx + [0] 76 | Vt = V[:, idx] 77 | t3 = distance * Vt 78 | t3 = np.sum(np.sum(np.multiply(V, t3))) 79 | e = 0.5 * (self.A * (t1 + t2) + self.D * t3) 80 | return e 81 | 82 | def check_path(self, V): 83 | """ 84 | 检查路径的正确性 85 | :return: 路径 86 | """ 87 | N = self.N 88 | route = [] 89 | for i in range(N): 90 | mm = np.max(V[:, i]) 91 | for j in range(N): 92 | if V[j, i] == mm: 93 | route += [j] 94 | break 95 | return route 96 | 97 | def train(self, num_iter, distance): 98 | """ 99 | 训练网络 100 | :param num_iter: 迭代次数 101 | :param distance: 距离矩阵 102 | :return: 最佳路径,网络能量 103 | """ 104 | U = 1 / 2 * self.U0 * np.log(self.N - 1) + (2 * (np.random.random((self.N, self.N))) - 1) 105 | V = self.calc_V(U, self.U0) 106 | energys = np.array([0.0 for x in range(num_iter)]) 107 | best_distance = np.inf 108 | best_route = [] 109 | 110 | for i in range(num_iter): 111 | du = self.calc_du(V, distance) 112 | U = self.calc_U(U, du, self.step) 113 | V = self.calc_V(U, self.U0) 114 | energys[i] = self.calc_energy(V, distance) 115 | route = self.check_path(V) 116 | if len(np.unique(route)) == self.N: 117 | route.append(route[0]) 118 | dis = self.dist_func(route) 119 | if dis < best_distance: 120 | best_distance = dis 121 | best_route = route 122 | print("iter {}: dist:{}, energy:{}".format(i, best_distance, energys[i])) 123 | print("route: {}".format(best_route)) 124 | # H_path = [] 125 | # [H_path.append((route[i], route[i + 1])) for i in range(len(route) - 1)] 126 | return best_route, energys 127 | -------------------------------------------------------------------------------- /methods/GA.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 _*_ 2 | 3 | import random 4 | 5 | 6 | class GA(object): 7 | """ 8 | 遗传算法 9 | """ 10 | 11 | def __init__(self, gene_length, population_nums=100, cross_prob=0.8, mutation_prob=.03, match_fun=lambda gene:1): 12 | self.gene_length = gene_length 13 | self.population_nums = population_nums 14 | self.cross_prob = cross_prob 15 | self.mutation_prob = mutation_prob 16 | self.match_fun = match_fun 17 | self.generation = 0 18 | self.total_score = 0.0 19 | self.best = None 20 | self.elites = None 21 | self.elites_num = 20 22 | 23 | self.init_population() 24 | 25 | def init_population(self): 26 | """ 27 | 种群初始化 28 | :return: 初始化种群 29 | """ 30 | self.population = [] 31 | for i in range(self.population_nums): 32 | gene = list(range(self.gene_length)) 33 | random.shuffle(gene) 34 | self.population.append(Life(gene)) 35 | self.evaluate() 36 | 37 | def evaluate(self): 38 | self.total_score = 0.0 39 | for p in self.population: 40 | p.score = self.match_fun(p.gene) 41 | self.total_score = self.total_score + p.score 42 | 43 | self.elites = sorted(self.population, key=lambda x:x.score, reverse=True) 44 | self.best = self.elites[0] 45 | 46 | def cross(self, parent1, parent2): 47 | """ 48 | 交叉运算 49 | :param parent1: 父代1 50 | :param parent2: 父代2 51 | :return: 子代 52 | """ 53 | i1 = random.randint(0, self.gene_length - 1) 54 | i2 = random.randint(i1, self.gene_length - 1) 55 | temp_gene = parent2.gene[i1:i2] 56 | next_gene = [] 57 | p1len = 0 58 | for g in parent1.gene: 59 | if p1len == i1: 60 | next_gene.extend(temp_gene) 61 | p1len += 1 62 | if g not in temp_gene: 63 | next_gene.append(g) 64 | p1len += 1 65 | return next_gene 66 | 67 | def mutation(self, gene, method='order'): 68 | """ 69 | 变异运算 70 | :param gene: 要进行变异的基因 71 | :return: 变异后的新基因 72 | """ 73 | if method == 'order': 74 | i1 = random.randint(0, self.gene_length - 1) 75 | i2 = random.randint(0, self.gene_length - 1) 76 | gene[i1], gene[i2] = gene[i2], gene[i1] 77 | elif method == 'shuffle': 78 | i1 = random.randint(0, self.gene_length - 2) 79 | i2 = random.randint(i1+1, self.gene_length - 1) 80 | temp_gene = gene[i1:i2] 81 | random.shuffle(temp_gene) 82 | gene[i1:i2] = temp_gene 83 | elif method == 'position': 84 | i1 = random.randint(0, self.gene_length - 1) 85 | i2 = random.randint(0, self.gene_length - 1) 86 | # TODO 完善基于位置的变异算子 87 | return gene 88 | 89 | def select(self): 90 | """ 91 | 选择算子 92 | 按照赌盘轮转法选择个体 93 | :return: 返回被选择的一个个体 94 | """ 95 | r = random.uniform(0, self.total_score) 96 | for life in self.population: 97 | r -= life.score 98 | if r <= 0: 99 | return life 100 | 101 | def generate_one(self): 102 | """ 103 | 对选择的一个个体 104 | 根据概率cross_prob进行交叉运算 105 | 根据概率mutation_prob进行变异运算 106 | :return: 下一代种群的一个个体 107 | """ 108 | parent1 = self.select() 109 | 110 | if random.random() < self.cross_prob: 111 | parent2 = self.select() 112 | gene = self.cross(parent1, parent2) 113 | else: 114 | gene = parent1.gene 115 | 116 | if random.random() < self.mutation_prob: 117 | gene = self.mutation(gene, method='order') 118 | 119 | return Life(gene) 120 | 121 | def generate_next(self): 122 | """ 123 | 产生下一代种群 124 | :return: 将当前种群更新为下一代种群 125 | """ 126 | next_population = [] 127 | next_population.extend(self.elites[0:self.elites_num]) 128 | 129 | while len(next_population) < self.population_nums-self.elites_num: 130 | next_population.append(self.generate_one()) 131 | self.population = next_population 132 | self.generation += 1 133 | self.evaluate() 134 | 135 | def finished(self): 136 | """ 137 | 判断遗传算法是否终止 138 | :return: True or False 139 | """ 140 | return self.generation < 3000 141 | 142 | def evolution(self): 143 | """ 144 | 遗传算法演化过程 145 | :return: 演化结束后的最优种群 146 | """ 147 | while self.finished(): 148 | self.generate_next() 149 | print("generation: {}".format(self.generation)) 150 | print(1.0 / self.best.score) 151 | return self.population 152 | 153 | 154 | class Life(object): 155 | """ 156 | 种群中的个体 157 | """ 158 | 159 | def __init__(self, mGene=None, mScore=-1): 160 | self.gene = mGene 161 | self.score = mScore 162 | --------------------------------------------------------------------------------