├── .gitignore ├── 1_Handwritten_Digit_Recognition ├── data │ ├── 5678.png │ ├── 8.png │ ├── demo.png │ ├── demo1.png │ └── demo2.png ├── digit_recognition.ipynb ├── digit_recognition_multi.ipynb ├── readme.md └── train.ipynb ├── 2_Cat_Dog_Image_Classification ├── cat_dog_classification.ipynb ├── data │ ├── cat.jpg │ ├── demo.png │ ├── dog.jpg │ ├── no_normalize.png │ ├── normalize.png │ ├── train_4.png │ ├── train_5.png │ └── train_6.png ├── dataset_process.ipynb ├── readme.md └── train.ipynb ├── 3_Transformer_Sentiment_Classification ├── Chinese_Movie_review_Text_Classification.ipynb ├── data │ └── demo.png ├── readme.md └── train_transformer.ipynb ├── 4_GAN_Image_Generator ├── GAN_MINIST_Generator.ipynb ├── data │ ├── demo.png │ ├── epoch_1.png │ ├── epoch_10.png │ ├── epoch_15.png │ ├── epoch_20.png │ ├── epoch_25.png │ ├── epoch_30.png │ └── epoch_5.png ├── digit_generate.ipynb └── readme.md ├── 5_RL_Snake ├── RL_Snake.ipynb ├── data │ └── AI_Snake.gif └── readme.md ├── 6_Adversarial_Attack ├── data │ ├── accuracy_vs_epsilon_plot.png │ ├── attack_cat.jpg │ ├── attack_dog.jpg │ ├── cat.jpg │ ├── demo.jpg │ ├── dog.jpg │ └── fgsm_examples_grid.png ├── readme.md ├── train.ipynb ├── white_box_PGD.ipynb └── white_box_attack-FGSM.ipynb ├── 7_RL_Pacman ├── RLAgents.py ├── data │ ├── example_image.png │ └── loadAgent.png ├── game.py ├── ghostAgents.py ├── graphicsDisplay.py ├── graphicsUtils.py ├── keyboardAgents.py ├── layout.py ├── layouts │ └── mediumClassic.lay ├── pacman.py ├── readme.md ├── textDisplay.py └── util.py └── readme.md /.gitignore: -------------------------------------------------------------------------------- 1 | # 忽略datasets中的数据集 2 | /datasets/* 3 | 4 | # 忽略.ipynb_checkpoints文件夹 5 | *.ipynb_checkpoints 6 | 7 | # 忽略__pycache__文件夹 8 | **/__pycache__ 9 | 10 | # 忽略模型权重文件 11 | /models/* -------------------------------------------------------------------------------- /1_Handwritten_Digit_Recognition/data/5678.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/1_Handwritten_Digit_Recognition/data/5678.png -------------------------------------------------------------------------------- /1_Handwritten_Digit_Recognition/data/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/1_Handwritten_Digit_Recognition/data/8.png -------------------------------------------------------------------------------- /1_Handwritten_Digit_Recognition/data/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/1_Handwritten_Digit_Recognition/data/demo.png -------------------------------------------------------------------------------- /1_Handwritten_Digit_Recognition/data/demo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/1_Handwritten_Digit_Recognition/data/demo1.png -------------------------------------------------------------------------------- /1_Handwritten_Digit_Recognition/data/demo2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/1_Handwritten_Digit_Recognition/data/demo2.png -------------------------------------------------------------------------------- /1_Handwritten_Digit_Recognition/readme.md: -------------------------------------------------------------------------------- 1 | ### 项目结构 📂 2 | - **train**: 数据集下载与处理,模型训练、评估及结果可视化,保存训练好的模型。 3 | - **digit_recognition**: 加载本地保存的模型,对本地手写体数字图片进行单个数字的预测。 4 | - **digit_recognition_multi**: 在 `digit_recognition` 的基础上扩展,对包含多个数字的图片,使用 `opencv` 分割图片后逐一识别并合并结果(效果可能较弱,可以尝试进行优化)。 5 | - **data**: 进行推理测试的图片 6 | 7 | 8 | ### 进阶 9 | [K-12 手写体(HME100K)数据集](https://ai.100tal.com/dataset) 10 | 利用这里的数据集实现 图片-->对应的markdown公式 -------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/data/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/2_Cat_Dog_Image_Classification/data/cat.jpg -------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/data/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/2_Cat_Dog_Image_Classification/data/demo.png -------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/data/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/2_Cat_Dog_Image_Classification/data/dog.jpg -------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/data/no_normalize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/2_Cat_Dog_Image_Classification/data/no_normalize.png -------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/data/normalize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/2_Cat_Dog_Image_Classification/data/normalize.png -------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/data/train_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/2_Cat_Dog_Image_Classification/data/train_4.png -------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/data/train_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/2_Cat_Dog_Image_Classification/data/train_5.png 
-------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/data/train_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/2_Cat_Dog_Image_Classification/data/train_6.png -------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/dataset_process.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "1438d967-9e38-4a54-8c2b-2d28ff7ad305", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os,shutil\n", 11 | "import random" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "e8d876c7-85f5-4ac8-8c2f-be92210cbf57", 17 | "metadata": {}, 18 | "source": [ 19 | "[kaggle 猫狗数据集下载](https://www.microsoft.com/en-us/download/details.aspx?id=54765) \n", 20 | "\n", 21 | "一共有共25000张图片" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "06586bf4-c2bf-4dd9-b1ff-bb1e7130a646", 27 | "metadata": {}, 28 | "source": [ 29 | "### 通用数据集分割流程\n", 30 | "\n", 31 | "原始数据集结构:\n", 32 | " - PetImages\n", 33 | " - Cat\n", 34 | " - Dog\n", 35 | "\n", 36 | "处理之后的数据集结构:\n", 37 | "- PetImages\n", 38 | " - train\n", 39 | " - Cat\n", 40 | " - Dog\n", 41 | " - test\n", 42 | " - Cat\n", 43 | " - Dog\n", 44 | "\n", 45 | "代码具有普适性,可直接用于处理类似结构的数据集,只需要修改`root_dir`, 并根据需要修改`test_rate`" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "id": "c6821185-1711-4967-860b-5ff16ca715dc", 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "数据集处理完成\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "root_dir=r\"../datasets/PetImages\"\n", 64 | "categories = os.listdir(root_dir)\n", 65 | "\n", 66 | "train_dir = os.path.join(root_dir, 'train')\n", 67 | "os.makedirs(train_dir, exist_ok=True)\n", 68 | "test_dir = os.path.join(root_dir, 'test')\n", 69 | "os.makedirs(test_dir, exist_ok=True)\n", 70 | "\n", 71 | "test_rate=0.1 #训练集和测试集的比例为9:1 22500 : 2500, 可以自己指定比例\n", 72 | "\n", 73 | "for category in categories:\n", 74 | " src_dir = os.path.join(root_dir, category)\n", 75 | " filenames = os.listdir(src_dir)\n", 76 | " test_num = int(len(filenames) * test_rate)\n", 77 | " test_filenames = random.sample(filenames, test_num)\n", 78 | " \n", 79 | " # 移动测试图片\n", 80 | " test_category_dir = os.path.join(test_dir, category)\n", 81 | " os.makedirs(test_category_dir, exist_ok=True)\n", 82 | " for test_filename in test_filenames:\n", 83 | " src_path = os.path.join(src_dir, test_filename)\n", 84 | " tgt_path = os.path.join(test_category_dir, test_filename)\n", 85 | " shutil.move(src_path, tgt_path)\n", 86 | " \n", 87 | " # 移动训练集图片(src_dir中剩下的图片)\n", 88 | " train_category_dir = os.path.join(train_dir, category)\n", 89 | " os.makedirs(train_category_dir, exist_ok=True)\n", 90 | " for train_filename in os.listdir(src_dir):\n", 91 | " src_path = os.path.join(src_dir, train_filename)\n", 92 | " tgt_path = os.path.join(train_category_dir, train_filename)\n", 93 | " shutil.move(src_path, tgt_path)\n", 94 | "\n", 95 | " # 删除原始目录\n", 96 | " os.rmdir(src_dir)\n", 97 | " \n", 98 | "print(\"数据集处理完成\")" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "id": "84bb2b5f-3651-418c-b390-d40e2ff28a2f", 105 | "metadata": {}, 106 
| "outputs": [], 107 | "source": [] 108 | } 109 | ], 110 | "metadata": { 111 | "kernelspec": { 112 | "display_name": "Python 3 (ipykernel)", 113 | "language": "python", 114 | "name": "python3" 115 | }, 116 | "language_info": { 117 | "codemirror_mode": { 118 | "name": "ipython", 119 | "version": 3 120 | }, 121 | "file_extension": ".py", 122 | "mimetype": "text/x-python", 123 | "name": "python", 124 | "nbconvert_exporter": "python", 125 | "pygments_lexer": "ipython3", 126 | "version": "3.8.20" 127 | } 128 | }, 129 | "nbformat": 4, 130 | "nbformat_minor": 5 131 | } 132 | -------------------------------------------------------------------------------- /2_Cat_Dog_Image_Classification/readme.md: -------------------------------------------------------------------------------- 1 | ### 项目结构 📂 2 | - **dataset_process**: 对数据集进行处理,划分为训练集和测试集(通用代码,可以迁移到其他的图片数据集)。 3 | - **train**: 因为训练时间可能比较久,加上了进度条显示,可以直观感受到模型的训练进度。 4 | - **cat_dog_classification**: 加载本地保存的模型,对本地猫狗图片进行分类预测。 5 | - **data**: 包含本地上传用于预测的猫狗图片,以及不同参数下模型训练的结果。 6 | 7 | 下载[Kaggle 猫狗数据集](https://www.microsoft.com/en-us/download/details.aspx?id=54765)到本地,然后运行·=`dataset_process`代码划分数据集。 8 | *原始数据中存在无法读取的图片(`Cat/666.png`、`Dog/11702.png`), 直接删除或替换即可。* 9 | 10 | ### 训练迭代过程(部分) 11 | 12 | *之所以是部分,是因为作者在这个项目快做完的时候才想到要记录一下过程... 于是就设计了几个不同的配置加以对照* 13 | 14 | - 1. lr=1e-3, num_epochs = 20 : Train_acc:90.5%, Train_loss:0.228, Test_acc:86.7%,Test_loss:0.322 明显出现了过拟合 15 | - 2. lr=1e-4, num_epochs = 20 : Train_acc:81.4%, Train_loss:0.409, Test_acc:83.4%,Test_loss:0.383 学习率太小,训练不充分, 导致欠拟合 16 | - 3. lr=1e-3, 使用StepLR学习率衰减策略,权重衰减: Train_acc:90.3%, Train_loss:0.239, Test_acc:88.0%,Test_loss:0.304 相比于1, 学习率衰减和权重衰减确实能减轻过拟合, 但是过拟合仍然存在 17 | - 4. lr=1e-3, num_epochs = 20, 增加水平翻转、旋转、颜色抖动的数据增强操作:Train_acc:87.9%, Train_loss:0.291, Test_acc:87.6%,Test_loss:0.288 完美解决了过拟合的问题!果然数据集才是根本。不过准确率还是差了一些 18 | - 5. lr=1e-3, num_epochs = 20, 新增加一个卷积层`conv4`:Train_acc:94.9%, Train_loss:0.124, Test_acc:91.9%,Test_loss:0.199 只增加了一个卷积层就大幅提升了模型性能!!!看来模型架构是关键。存在过拟合 19 | - 6. lr=1e-3, num_epochs = 20, 数据增强,使用StepLR学习率衰减策略,权重衰减,新增加一个卷积层`conv4`: Train_acc:93.6%, Train_loss:0.166, Test_acc:92.5%,Test_loss:0.185 终极模型,可惜效果比预想的要差一些 20 | 21 | 22 | 23 | --- 24 | 25 | ### 数据标准化对模型的影响 26 | 27 | 一开始参考别人的教程, 我对数据使用了标准化操作 28 | ```python 29 | transforms.Normalize( # 对张量进行标准化处理, 提升模型训练的稳定性 30 | mean=[0.485, 0.456, 0.406], # 使用 ImageNet 数据集的均值 31 | std=[0.229, 0.224, 0.225] # 使用 ImageNet 数据集的标准差 32 | ) 33 | ``` 34 | 但是当我把处理后的图片输出后大为震惊, 发现图片被处理得严重失真, 连肉眼都很难分辨。 35 | 36 | 标准化前: 37 | ![no_normalize](./data/no_normalize.png) 38 | 39 | 标准化后: 40 | ![normalize](./data/normalize.png) 41 | 42 | 我觉得这肯定会损害模型的性能, 于是就请教了GPT老师, 它告诉我: 43 | 44 | --- 45 | 46 | ### **为什么标准化会导致图片“失真”?** 47 | `transforms.Normalize` 是针对模型的输入进行标准化处理的操作,它使用给定的均值和标准差将像素值归一化为零均值和单位标准差。标准化后的像素值可能包含负值,因此直接展示这些数据时,肉眼会看到“失真”效果,但实际上这对模型的训练有利。 48 | 49 | - **标准化的目的**: 50 | - 减少特征值范围差异,帮助模型更快地收敛。 51 | - 提高模型对数据分布的鲁棒性,特别是在使用预训练模型(如 ResNet)时,保持输入数据与预训练的分布一致。 52 | 53 | - **为什么图像会失真**: 54 | 标准化是针对张量进行数学操作,最终会改变图像像素的视觉表现,但它不会影响模型接收的特征信息。 55 | 56 | ### **是否可以直接移除标准化?** 57 | **结论:** 58 | - 如果使用预训练模型:不建议移除标准化,推荐调整或逆向还原以适应展示需求。 59 | - 如果从头训练模型:可以移除,但可能需要更长的训练时间来弥补标准化带来的优势。 60 | 61 | --- 62 | 看到这里我心里放心了不少, 但是还需要用实验来验证它的说法 63 | 64 | 65 | 未标准化(lr=1e-3, num_epochs = 5):Train_acc:82.8%, Train_loss:0.383, Test_acc:82.0%,Test_loss:0.387 66 | 67 | 标准化:Epoch: 5, Train_acc:83.0%, Train_loss:0.383, Test_acc:82.7%,Test_loss:0.385 68 | 69 | 说明确实会有用? 
具体影响还有待深入探究。 70 | 71 | 不过至少不会像我最开始想的那样损害模型性能,在上面的迭代过程中均默认使用了(ImageNet)标准化操作。 72 | 73 | **猜想: 将将均值和标准差调整为基于当前数据集的统计值效果应该能达到更好的效果, 有待进一步验证** -------------------------------------------------------------------------------- /3_Transformer_Sentiment_Classification/Chinese_Movie_review_Text_Classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "1fd97fa5-35c5-4318-94f8-2be0e04b8ed8", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "import math\n", 13 | "from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup\n" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "f51feb63-a6d4-409a-be42-0fb7df600667", 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "data": { 24 | "text/plain": [ 25 | "device(type='cuda')" 26 | ] 27 | }, 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "output_type": "execute_result" 31 | } 32 | ], 33 | "source": [ 34 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 35 | "device" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "43677da2-2850-420f-8270-1bf46db981ad", 41 | "metadata": {}, 42 | "source": [ 43 | "### 加载训练好的本地模型" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "id": "ada9af59-0e5a-4e55-8cae-f8cbff7b112f", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stderr", 54 | "output_type": "stream", 55 | "text": [ 56 | "C:\\Users\\PC\\AppData\\Local\\Temp\\ipykernel_27116\\44386423.py:140: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", 57 | " model.load_state_dict(torch.load(model_path))\n" 58 | ] 59 | }, 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "TransformerClassifier(\n", 64 | " (token_embedding): Embedding(21128, 256)\n", 65 | " (positional_embedding): Embedding(128, 256)\n", 66 | " (transformer_encoder): TransformerEncoder(\n", 67 | " (layers): ModuleList(\n", 68 | " (0-3): 4 x TransformerEncoderLayer(\n", 69 | " (self_attn): MultiheadAttention(\n", 70 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n", 71 | " )\n", 72 | " (linear1): Linear(in_features=256, out_features=512, bias=True)\n", 73 | " (dropout): Dropout(p=0.1, inplace=False)\n", 74 | " (linear2): Linear(in_features=512, out_features=256, bias=True)\n", 75 | " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 76 | " (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 77 | " (dropout1): Dropout(p=0.1, inplace=False)\n", 78 | " (dropout2): Dropout(p=0.1, inplace=False)\n", 79 | " )\n", 80 | " )\n", 81 | " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", 82 | " )\n", 83 | " (dropout): Dropout(p=0.1, inplace=False)\n", 84 | " (fc_out): Linear(in_features=256, out_features=2, bias=True)\n", 85 | ")" 86 | ] 87 | }, 88 | "execution_count": 3, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "class TransformerClassifier(nn.Module):\n", 95 | " def __init__(self, vocab_size, embed_dim, num_heads, num_encoder_layers, ff_dim, num_classes,\n", 96 | " max_len=512, dropout_rate=0.1):\n", 97 | " \"\"\"\n", 98 | " 初始化 Transformer 分类器\n", 99 | " \n", 100 | " Args:\n", 101 | " vocab_size (int): 词汇表大小(tokenizer.vocab_size)。\n", 102 | " embed_dim (int): 词嵌入和 Transformer 的维度(d_model)。\n", 103 | " num_heads (int): 多头注意力机制的头数,必须能整除 embed_dim。\n", 104 | " num_encoder_layers (int): Transformer Encoder 的层数。\n", 105 | " ff_dim (int): 前馈网络中间层的维度(通常为 embed_dim 的 2-4 倍)。\n", 106 | " num_classes (int): 分类任务的类别数(2 表示正/负)。\n", 107 | " max_len (int): 最大序列长度,用于位置嵌入。\n", 108 | " dropout_rate (float): Dropout 比率,用于正则化。\n", 109 | " \"\"\"\n", 110 | " super().__init__()\n", 111 | " self.embed_dim = embed_dim\n", 112 | " # 词嵌入层,将 token ID 映射到 embed_dim 维向量\n", 113 | " self.token_embedding = nn.Embedding(vocab_size, embed_dim)\n", 114 | " # 可学习的位置嵌入,为每个位置生成 embed_dim 维向量\n", 115 | " self.positional_embedding = nn.Embedding(max_len, embed_dim)\n", 116 | "\n", 117 | " # 定义单个 Transformer Encoder 层\n", 118 | " encoder_layer = nn.TransformerEncoderLayer(\n", 119 | " d_model=embed_dim, # 模型维度\n", 120 | " nhead=num_heads, # 注意力头数\n", 121 | " dim_feedforward=ff_dim, # 前馈网络中间层维度\n", 122 | " dropout=dropout_rate, # Dropout 比率\n", 123 | " batch_first=True, # 输入/输出形状为 (batch, seq, feature),适配常见数据格式\n", 124 | " activation='gelu' # 使用 GELU 激活函数,相比 ReLU 更平滑,有助于梯度流动\n", 125 | " )\n", 126 | "\n", 127 | " # 堆叠多个 Transformer Encoder 层\n", 128 | " self.transformer_encoder = nn.TransformerEncoder(\n", 129 | " encoder_layer,\n", 130 | " num_layers=num_encoder_layers,\n", 131 | " norm=nn.LayerNorm(embed_dim) # 显式添加 LayerNorm,规范化输出\n", 132 | " )\n", 133 | "\n", 134 | " self.dropout = nn.Dropout(dropout_rate)\n", 135 | " \n", 136 | " # 分类头:将 [CLS] token 的输出(embed_dim 维)映射到 num_classes 维\n", 137 | " self.fc_out = nn.Linear(embed_dim, num_classes)\n", 138 | "\n", 139 | " self.max_len = max_len # 存储最大序列长度,供位置编码使用\n", 140 | "\n", 141 | " self._init_weights()\n", 142 | "\n", 143 | " def 
_init_weights(self):\n", 144 | " \"\"\"\n", 145 | " 初始化模型权重,使用 Xavier Uniform 初始化,适合 Transformer 模型。\n", 146 | " 避免初始权重过大或过小,加速收敛。\n", 147 | " \"\"\"\n", 148 | " for p in self.parameters():\n", 149 | " if p.dim() > 1: # 仅对二维以上参数(如线性层、嵌入层)应用\n", 150 | " nn.init.xavier_uniform_(p)\n", 151 | " # 对嵌入层可额外应用正态初始化\n", 152 | " elif p.dim() == 2 and 'embedding' in p.name:\n", 153 | " nn.init.normal_(p, mean=0.0, std=0.02)\n", 154 | "\n", 155 | " def forward(self, input_ids, attention_mask):\n", 156 | " \"\"\"\n", 157 | " 前向传播,处理输入序列并输出分类 logits。\n", 158 | "\n", 159 | " Args:\n", 160 | " input_ids (torch.Tensor): 形状 (batch_size, seq_len),词的 ID。\n", 161 | " attention_mask (torch.Tensor): 形状 (batch_size, seq_len),1 表示有效 token,0 表示 padding。\n", 162 | "\n", 163 | " Returns:\n", 164 | " torch.Tensor: 形状 (batch_size, num_classes),分类 logits。\n", 165 | " \"\"\"\n", 166 | " seq_len = input_ids.size(1) # 获取序列长度\n", 167 | "\n", 168 | " # 1. 词嵌入\n", 169 | " token_embeds = self.token_embedding(input_ids) # (batch_size, seq_len, embed_dim)\n", 170 | " token_embeds = token_embeds * math.sqrt(self.embed_dim) # 缩放嵌入,稳定训练\n", 171 | "\n", 172 | " # 2. 位置编码\n", 173 | " # 生成位置索引:(batch_size, seq_len),每个样本重复 0 到 seq_len-1\n", 174 | " positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0).repeat(input_ids.size(0), 1)\n", 175 | " position_embeds = self.positional_embedding(positions) # (batch_size, seq_len, embed_dim)\n", 176 | " \n", 177 | " # 词嵌入与位置嵌入相加\n", 178 | " x = token_embeds + position_embeds\n", 179 | " x = self.dropout(x) # 在嵌入后应用 Dropout,增强鲁棒性\n", 180 | "\n", 181 | " # Transformer Encoder需要 src_key_padding_mask\n", 182 | " # attention_mask: 1是token, 0是padding.\n", 183 | " # src_key_padding_mask: True表示该位置是padding, 需要被mask掉.\n", 184 | " src_key_padding_mask = (attention_mask == 0) # (batch_size, seq_len)\n", 185 | "\n", 186 | " # 3. Transformer Encoder\n", 187 | " # 输入形状: (batch_size, seq_len, embed_dim)\n", 188 | " encoder_output = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)\n", 189 | " # encoder_output shape: (batch_size, seq_len, embed_dim)\n", 190 | "\n", 191 | " # 4. 
分类\n", 192 | " # 通常使用第一个token ([CLS] token)的输出来进行分类\n", 193 | " cls_output = encoder_output[:, 0, :] # (batch_size, embed_dim)\n", 194 | " # 或者,可以对所有token的输出进行平均池化或最大池化\n", 195 | " # cls_output = encoder_output.mean(dim=1) # 平均池化\n", 196 | "\n", 197 | " cls_output = self.dropout(cls_output)\n", 198 | " logits = self.fc_out(cls_output) # (batch_size, num_classes)\n", 199 | "\n", 200 | " return logits\n", 201 | "\n", 202 | "\n", 203 | "# 加载Bert 的分词器\n", 204 | "tokenizer_path = '../models/3_Transformer_Sentiment_Classification/bert-base-chinese'\n", 205 | "tokenizer = BertTokenizer.from_pretrained(tokenizer_path)\n", 206 | " \n", 207 | "# 定义模型超参数\n", 208 | "VOCAB_SIZE = tokenizer.vocab_size # 从之前加载的 BERT 分词器获取\n", 209 | "EMBED_DIM = 256 # 嵌入维度,较小以减少计算量(BERT 常用 768)\n", 210 | "NUM_HEADS = 8 # 多头注意力头数,需满足 embed_dim % num_heads == 0\n", 211 | "NUM_ENCODER_LAYERS = 4 # Encoder 层数,平衡性能与计算成本\n", 212 | "FF_DIM = 512 # 前馈网络中间层维度,通常为 embed_dim 的 2-4 倍\n", 213 | "NUM_CLASSES = 2 # 分类任务的类别数(正/负情感)\n", 214 | "DROPOUT_RATE = 0.1 # Dropout 比率,防止过拟合\n", 215 | "MAX_LEN = 128\n", 216 | "\n", 217 | "\n", 218 | "model = TransformerClassifier(\n", 219 | " vocab_size=VOCAB_SIZE,\n", 220 | " embed_dim=EMBED_DIM,\n", 221 | " num_heads=NUM_HEADS,\n", 222 | " num_encoder_layers=NUM_ENCODER_LAYERS,\n", 223 | " ff_dim=FF_DIM,\n", 224 | " num_classes=NUM_CLASSES,\n", 225 | " max_len=MAX_LEN, # 从之前的配置中获取\n", 226 | " dropout_rate=DROPOUT_RATE\n", 227 | ")\n", 228 | "model = model.to(device)\n", 229 | "\n", 230 | "model_path = '../models/3_Transformer_Sentiment_Classification/model_weights.pth'\n", 231 | "\n", 232 | "# 加载模型参数\n", 233 | "model.load_state_dict(torch.load(model_path))\n", 234 | "\n", 235 | "# 将模型设置为评估模式\n", 236 | "model.eval()" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "id": "20cec762-f39c-42d0-b401-b0c7c030b5e4", 242 | "metadata": {}, 243 | "source": [ 244 | "## 模型推理" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 4, 250 | "id": "c8a318aa-aeb7-49c8-b4f7-34b62a872256", 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "def predict_sentiment(text, max_len=MAX_LEN):\n", 255 | " # 文本预处理\n", 256 | " encoding = tokenizer.encode_plus(\n", 257 | " text,\n", 258 | " add_special_tokens=True, # 添加 [CLS] 和 [SEP]\n", 259 | " max_length=max_len,\n", 260 | " padding='max_length', # 填充到 max_len\n", 261 | " truncation=True, # 截断超长文本\n", 262 | " return_tensors='pt', # 返回 PyTorch 张量\n", 263 | " return_attention_mask=True\n", 264 | " )\n", 265 | " input_ids = encoding['input_ids'].to(device) # (1, max_len)\n", 266 | " attention_mask = encoding['attention_mask'].to(device) # (1, max_len)\n", 267 | " # 模型推理\n", 268 | " with torch.no_grad(): # 禁用梯度计算,节省内存\n", 269 | " logits = model(input_ids, attention_mask) # (1, num_classes)\n", 270 | " pred = torch.argmax(logits, dim=-1).item() # 预测类别 (0 或 1)\n", 271 | "\n", 272 | " # 转换标签\n", 273 | " return '正面' if pred == 1 else '负面'" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 5, 279 | "id": "1fa98604-c3a9-47d3-a8be-7d11bcafb8e7", 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "Review : 虽没有第一部“我命由我不由天”的惊艳金句,但更多了些“怎能不知道这世间的规则,由谁所定?”的结构性思考,无量仙翁的“个体失范代替制度失范”真是最佳切口。狠狠期待第三部!\n", 287 | "sentiment : 正面\n", 288 | "Review : 我觉得中国这些人拍点电影,啥时候变成这种短视频短剧形式的切片合集了?一点点深度也没有了?太快餐了\n", 289 | "sentiment : 负面\n" 290 | ] 291 | }, 292 | { 293 | "name": "stderr", 294 | "output_type": "stream", 295 | "text": [ 296 | 
"D:\\develop\\anaconda\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\transformer.py:409: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\NestedTensorImpl.cpp:180.)\n", 297 | " output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)\n", 298 | "D:\\develop\\anaconda\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\modules\\transformer.py:720: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:555.)\n", 299 | " return torch._transformer_encoder_layer_fwd(\n" 300 | ] 301 | } 302 | ], 303 | "source": [ 304 | "pos_review = \"虽没有第一部“我命由我不由天”的惊艳金句,但更多了些“怎能不知道这世间的规则,由谁所定?”的结构性思考,无量仙翁的“个体失范代替制度失范”真是最佳切口。狠狠期待第三部!\"\n", 305 | "neg_review = \"我觉得中国这些人拍点电影,啥时候变成这种短视频短剧形式的切片合集了?一点点深度也没有了?太快餐了\"\n", 306 | "print(f'Review : {pos_review}\\nsentiment : {predict_sentiment(pos_review)}')\n", 307 | "print(f'Review : {neg_review}\\nsentiment : {predict_sentiment(neg_review)}')" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "id": "592d8098-f44e-4d6d-8a46-c79300734aac", 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [] 317 | } 318 | ], 319 | "metadata": { 320 | "kernelspec": { 321 | "display_name": "Python 3 (ipykernel)", 322 | "language": "python", 323 | "name": "python3" 324 | }, 325 | "language_info": { 326 | "codemirror_mode": { 327 | "name": "ipython", 328 | "version": 3 329 | }, 330 | "file_extension": ".py", 331 | "mimetype": "text/x-python", 332 | "name": "python", 333 | "nbconvert_exporter": "python", 334 | "pygments_lexer": "ipython3", 335 | "version": "3.8.20" 336 | } 337 | }, 338 | "nbformat": 4, 339 | "nbformat_minor": 5 340 | } 341 | -------------------------------------------------------------------------------- /3_Transformer_Sentiment_Classification/data/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/3_Transformer_Sentiment_Classification/data/demo.png -------------------------------------------------------------------------------- /3_Transformer_Sentiment_Classification/readme.md: -------------------------------------------------------------------------------- 1 | 参考文章: 2 | [文本分类原理与实践](https://aibydoing.com/notebooks/chapter09-04-lab-principles-and-practices-of-text-classification) 3 | 4 | [预训练模型分类](https://juejin.cn/post/7327721919426494474#heading-3) 5 | 6 | [NLP之文本分类:「Tf-Idf、Word2Vec和BERT」三种模型比较](https://www.leiphone.com/category/yanxishe/tbzazc3cjams815p.html) 7 | 8 | 本项目所使用的数据集是作者从豆瓣爬取的,一共有50w条正负样本均衡的数据。 9 | 10 | ### 数据集 11 | - 来源:豆瓣电影评论,包含用户评分(1-5 星)和评论文本。 12 | - 规模:508,110 条记录,包含正面(label=1)和负面(label=0)情感标签。 13 | - 样本示例: 14 | - 正面:“虽然不长 但是真的不错” (4 星) 15 | - 负面:“本該是三個騙子互相拆台的一出好戲,變成了兩個傻子在一個騙子的瘋狂擺布下…” (2 星) 16 | 17 | 18 | 一开始参考上面的文章使用了 word2vec 和 Bert 两种方式,效果都不好。最后仅保留了使用 Transformer 19 | 20 | 21 | 22 | 需要手动下载 [bert-base-chinese](https://huggingface.co/google-bert/bert-base-chinese/tree/main) 的分词器文件,并将其放置在指定目录`(../models/3_Chinese_Movie_review_Text_Classification/bert-base-chinese)`。 23 | bert-base-chinese 分词器需要以下文件: 24 | - vocab.txt 25 | - config.json 26 | - tokenizer_config.json -------------------------------------------------------------------------------- 
/4_GAN_Image_Generator/data/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/4_GAN_Image_Generator/data/demo.png -------------------------------------------------------------------------------- /4_GAN_Image_Generator/data/epoch_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/4_GAN_Image_Generator/data/epoch_1.png -------------------------------------------------------------------------------- /4_GAN_Image_Generator/data/epoch_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/4_GAN_Image_Generator/data/epoch_10.png -------------------------------------------------------------------------------- /4_GAN_Image_Generator/data/epoch_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/4_GAN_Image_Generator/data/epoch_15.png -------------------------------------------------------------------------------- /4_GAN_Image_Generator/data/epoch_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/4_GAN_Image_Generator/data/epoch_20.png -------------------------------------------------------------------------------- /4_GAN_Image_Generator/data/epoch_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/4_GAN_Image_Generator/data/epoch_25.png -------------------------------------------------------------------------------- /4_GAN_Image_Generator/data/epoch_30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/4_GAN_Image_Generator/data/epoch_30.png -------------------------------------------------------------------------------- /4_GAN_Image_Generator/data/epoch_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/4_GAN_Image_Generator/data/epoch_5.png -------------------------------------------------------------------------------- /4_GAN_Image_Generator/digit_generate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "28b2116b-d8db-4ecf-931b-6faf03950b26", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "device(type='cuda')" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "import matplotlib.pyplot as plt\n", 24 | 
"import torchvision.transforms as transforms\n", 25 | "from PIL import Image\n", 26 | "\n", 27 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 28 | "device" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "cb810376-7bd1-4ce1-8e70-f67001d1f27b", 34 | "metadata": {}, 35 | "source": [ 36 | "### 加载模型" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "id": "236d9a7f-674f-42c5-9a81-587ddd6aac18", 42 | "metadata": {}, 43 | "source": [ 44 | "加载生成器模型" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "id": "5dc72b2e-ec95-4b68-926f-2e02e2652be8", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stderr", 55 | "output_type": "stream", 56 | "text": [ 57 | "C:\\Users\\PC\\AppData\\Local\\Temp\\ipykernel_26204\\4019217968.py:69: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", 58 | " generator.load_state_dict(torch.load(model_path))\n" 59 | ] 60 | }, 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "" 65 | ] 66 | }, 67 | "execution_count": 2, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "class Generator(nn.Module):\n", 74 | " def __init__(self, input_dim=100, output_dim=1, class_num=10):\n", 75 | " '''\n", 76 | " 初始化生成网络\n", 77 | " :param input_dim:输入随机噪声的维度,(随机噪声是为了增加输出多样性)\n", 78 | " :param output_dim:生成图像的通道数(灰度图为1,RGB图为3)\n", 79 | " :param class_num:图像种类\n", 80 | " '''\n", 81 | " super(Generator, self).__init__()\n", 82 | " \"\"\"\n", 83 | " 为什么需要拼接随机噪声和条件向量?\n", 84 | " 拼接随机噪声和条件向量的目的是将两种信息结合起来,作为生成器的输入:\n", 85 | " 随机噪声:提供生成数据的随机性。\n", 86 | " 条件向量:提供生成数据的条件信息。\n", 87 | " 通过拼接,生成器可以根据条件向量生成符合特定条件的数据, 同时确保每次生成的数据会有所不同\n", 88 | " \"\"\"\n", 89 | " self.input_dim = input_dim\n", 90 | " self.class_num = class_num\n", 91 | " self.output_dim = output_dim\n", 92 | " \n", 93 | " # 嵌入层处理条件向量(类别标签), 提高条件信息的表达能力\n", 94 | " self.label_emb = nn.Embedding(class_num, class_num)\n", 95 | " \n", 96 | " # 全连接层,将输入向量映射到高维空间,然后通过反卷积层生成图像\n", 97 | " self.fc = nn.Sequential(\n", 98 | " nn.Linear(self.input_dim + self.class_num, 256),\n", 99 | " nn.LeakyReLU(0.2, inplace=True),\n", 100 | " nn.Linear(256, 512),\n", 101 | " nn.LeakyReLU(0.2, inplace=True),\n", 102 | " nn.Linear(512, 1024),\n", 103 | " nn.LeakyReLU(0.2, inplace=True),\n", 104 | " nn.Linear(1024, 128 * 7 * 7),\n", 105 | " nn.BatchNorm1d(128 * 7 * 7),\n", 106 | " nn.LeakyReLU(0.2, inplace=True)\n", 107 | " )\n", 108 | "\n", 109 | " # 反卷积层(转置卷积层),用于将高维特征图逐步上采样为最终图像\n", 110 | " self.deconv = nn.Sequential(\n", 111 | " nn.ConvTranspose2d(128, 128, 4, 2, 1), # 7x7 -> 14x14\n", 112 | " nn.BatchNorm2d(128),\n", 113 | " nn.LeakyReLU(0.2, inplace=True),\n", 114 | " nn.ConvTranspose2d(128, 64, 4, 2, 1), # 14x14 -> 
28x28\n", 115 | " nn.BatchNorm2d(64),\n", 116 | " nn.LeakyReLU(0.2, inplace=True),\n", 117 | " nn.Conv2d(64, self.output_dim, 3, 1, 1), # 保持尺寸不变,但细化特征\n", 118 | " nn.Tanh() # 激活函数,将输出值限制在 [-1, 1] 范围内,适合生成图像\n", 119 | " )\n", 120 | " \n", 121 | " def forward(self, noise, labels):\n", 122 | " # 标签处理\n", 123 | " label_embedding = self.label_emb(labels)\n", 124 | " \n", 125 | " # 拼接噪声和条件向量\n", 126 | " x = torch.cat([noise, label_embedding], dim=1)\n", 127 | " \n", 128 | " # 通过全连接层\n", 129 | " x = self.fc(x)\n", 130 | " \n", 131 | " # 重塑为特征图\n", 132 | " x = x.view(-1, 128, 7, 7)\n", 133 | " \n", 134 | " # 通过反卷积层生成图像\n", 135 | " x = self.deconv(x)\n", 136 | " \n", 137 | " return x\n", 138 | "\n", 139 | "generator = Generator().to(device)\n", 140 | "model_path = '../models/4_GAN_Image_Generator/MINIST_generator.pth'\n", 141 | "generator.load_state_dict(torch.load(model_path))" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "id": "992d8fda-ac50-441a-95b4-e4500b25368f", 147 | "metadata": {}, 148 | "source": [ 149 | "### 手写数字图像生成" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 3, 155 | "id": "810ace5f-ffa4-4dd6-9ff9-b88db3571ea8", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "def generate_digit_image(generator, digit):\n", 160 | " \"\"\"\n", 161 | " 生成指定数字的图片\n", 162 | " :param generator: 训练好的生成器模型\n", 163 | " :param digit: 要生成的数字 (0-9)\n", 164 | " :return: 生成的图片 (PIL 图像)\n", 165 | " \"\"\"\n", 166 | " generator.eval() # 设置为评估模式\n", 167 | " with torch.no_grad():\n", 168 | " # 生成随机噪声\n", 169 | " noise = torch.randn(1, generator.input_dim).to(device)\n", 170 | " \n", 171 | " # 创建标签\n", 172 | " label = torch.tensor([digit]).to(device)\n", 173 | " \n", 174 | " # 生成图片\n", 175 | " fake_image = generator(noise, label)\n", 176 | " \n", 177 | " # 将图片从 [-1, 1] 转换到 [0, 1]\n", 178 | " fake_image = (fake_image.squeeze().cpu() + 1) / 2.0\n", 179 | " \n", 180 | " # 将 2D 张量 (H, W) 转换为 3D 张量 (1, H, W)\n", 181 | " fake_image = fake_image.unsqueeze(0)\n", 182 | " \n", 183 | " # 转换为 PIL 图像\n", 184 | " fake_image = transforms.ToPILImage()(fake_image)\n", 185 | " \n", 186 | " return fake_image\n" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 4, 192 | "id": "1f5562b0-1932-4d67-b264-1bb4755c8b4c", 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "图片已保存到 ./data/demo.png\n" 200 | ] 201 | } 202 | ], 203 | "source": [ 204 | "# 生成一个包含 10x10 个不同数字的大图片,并保存到本地\n", 205 | "plt.figure(figsize=(10, 10)) # 设置画布大小\n", 206 | "plt.subplots_adjust(wspace=0.1, hspace=0.1) # 调整子图间距\n", 207 | "\n", 208 | "for i in range(10): # 行\n", 209 | " for j in range(10): # 列\n", 210 | " # 生成数字 j 的图片\n", 211 | " digit_image = generate_digit_image(generator, j)\n", 212 | " \n", 213 | " # 将图片添加到子图中\n", 214 | " ax = plt.subplot(10, 10, i * 10 + j + 1)\n", 215 | " ax.imshow(digit_image, cmap='gray')\n", 216 | " ax.axis('off') # 关闭坐标轴\n", 217 | "\n", 218 | "save_path = './data/demo.png'\n", 219 | "# 保存大图片\n", 220 | "plt.savefig(save_path, bbox_inches='tight')\n", 221 | "plt.close()\n", 222 | "print(f\"图片已保存到 {save_path}\")" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "cdca719b-16f7-4557-a5a2-3294f8377138", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [] 232 | } 233 | ], 234 | "metadata": { 235 | "kernelspec": { 236 | "display_name": "Python 3 (ipykernel)", 237 | "language": "python", 238 | "name": "python3" 239 | }, 240 | 
"language_info": { 241 | "codemirror_mode": { 242 | "name": "ipython", 243 | "version": 3 244 | }, 245 | "file_extension": ".py", 246 | "mimetype": "text/x-python", 247 | "name": "python", 248 | "nbconvert_exporter": "python", 249 | "pygments_lexer": "ipython3", 250 | "version": "3.8.20" 251 | } 252 | }, 253 | "nbformat": 4, 254 | "nbformat_minor": 5 255 | } 256 | -------------------------------------------------------------------------------- /4_GAN_Image_Generator/readme.md: -------------------------------------------------------------------------------- 1 | 参考文章:[GAN 原理 & pytorch代码实例 - 生成MINIST手写数字](https://blog.csdn.net/Lizhi_Tech/article/details/132108893) 2 | 3 | 4 | GAN 的训练结果很难通过损失直接体现出来,通过观察保存在`data`目录的中间结果来判断训练效果。 5 | 6 | `digit_generate`利用手写数字识别的模型来帮助生成质量更高的数字图片,但是效果似乎没有很好? -------------------------------------------------------------------------------- /5_RL_Snake/RL_Snake.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "a38863e6-6c15-4e60-9e35-6795b505c41e", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "pygame 2.2.0 (SDL 2.32.50, Python 3.8.20)\n", 14 | "Hello from the pygame community. https://www.pygame.org/contribute.html\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import random\n", 20 | "import pygame" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "id": "4cc4e015-f592-46c8-abb3-e3a5a1514753", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# 定义常量\n", 31 | "SCREEN_WIDTH = 800\n", 32 | "SCREEN_HEIGHT = 600\n", 33 | "POP_SIZE = 50 # 训练次数\n", 34 | "BLOCK_SIZE = 20\n", 35 | "\n", 36 | "# 定义颜色\n", 37 | "BLACK = (0, 0, 0)\n", 38 | "WHITE = (255, 255, 255)\n", 39 | "GREEN = (0, 255, 0)\n", 40 | "RED = (255, 0, 0)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "id": "8a4ecdbc-7838-4da3-950f-bc093e30f375", 46 | "metadata": {}, 47 | "source": [ 48 | "### 实现 Snake 类" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "id": "c2502651-799b-4511-984c-f899247d1e1b", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "class Snake:\n", 59 | " def __init__(self):\n", 60 | " \"\"\"初始化蛇\"\"\"\n", 61 | " self.length = 3\n", 62 | " # 初始化蛇身,从中心开始向左延伸\n", 63 | " self.positions = [\n", 64 | " (SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2),\n", 65 | " (SCREEN_WIDTH / 2 - BLOCK_SIZE, SCREEN_HEIGHT / 2),\n", 66 | " (SCREEN_WIDTH / 2 - 2 * BLOCK_SIZE, SCREEN_HEIGHT / 2)\n", 67 | " ]\n", 68 | " self.direction = random.choice([(0, 1), (0, -1), (1, 0), (-1, 0)])\n", 69 | " self.color = GREEN\n", 70 | " self.is_alive = True\n", 71 | "\n", 72 | " def get_head_position(self):\n", 73 | " \"\"\"获得蛇头的坐标\"\"\"\n", 74 | " return self.positions[0]\n", 75 | "\n", 76 | " def turn(self, point):\n", 77 | " \"\"\"\n", 78 | " 改变蛇移动方向\n", 79 | " 防止180度反向移动\n", 80 | " \"\"\"\n", 81 | " # 如果新方向与当前方向相反,则不改变\n", 82 | " if (point[0] * -1, point[1] * -1) == self.direction:\n", 83 | " return\n", 84 | " self.direction = point\n", 85 | "\n", 86 | " def move(self):\n", 87 | " \"\"\"移动蛇身\"\"\"\n", 88 | " cur = self.get_head_position()\n", 89 | " x, y = self.direction\n", 90 | " new = (cur[0] + (x * BLOCK_SIZE), cur[1] + (y * BLOCK_SIZE))\n", 91 | " \n", 92 | " # 检查是否撞墙\n", 93 | " if (new[0] < 0 or new[0] >= SCREEN_WIDTH or \n", 94 | " new[1] < 0 or new[1] >= SCREEN_HEIGHT):\n", 95 | " self.is_alive = False\n", 96 | " return\n", 97 | "\n", 98 
| " \"\"\"检查是否撞到自己\"\"\"\n", 99 | " head = self.get_head_position()\n", 100 | " if head in self.positions[1:]:\n", 101 | " self.is_alive = False\n", 102 | " return\n", 103 | " \n", 104 | " self.positions.insert(0, new)\n", 105 | " if len(self.positions) > self.length:\n", 106 | " self.positions.pop()\n", 107 | "\n", 108 | " def reset(self):\n", 109 | " \"\"\"重新开始\"\"\"\n", 110 | " self.__init__()\n", 111 | "\n", 112 | " def draw(self, surface):\n", 113 | " \"\"\"在画布上绘制蛇身\"\"\"\n", 114 | " for p in self.positions:\n", 115 | " r = pygame.Rect((p[0], p[1]), (BLOCK_SIZE, BLOCK_SIZE))\n", 116 | " pygame.draw.rect(surface, self.color, r)\n", 117 | " pygame.draw.rect(surface, BLACK, r, 1) # 绘制边框\n", 118 | "\n", 119 | " def grow(self):\n", 120 | " \"\"\"蛇身增长\"\"\"\n", 121 | " self.length += 1\n", 122 | "\n", 123 | " def check_is_alive(self):\n", 124 | " return self.is_alive" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "id": "a5345aa1-ee17-4948-929c-1944936c5b70", 130 | "metadata": {}, 131 | "source": [ 132 | "### 实现 Food 类" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 4, 138 | "id": "6ccaf002-a454-4b08-879c-f4a0e45141f4", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "class Food:\n", 143 | " def __init__(self, snake_positions=None):\n", 144 | " \"\"\"\n", 145 | " 初始化食物\n", 146 | " snake_positions: 可选参数,蛇身位置列表,用于避免食物生成在蛇身上\n", 147 | " \"\"\"\n", 148 | " self.color = RED\n", 149 | " self._generate_position(snake_positions)\n", 150 | "\n", 151 | " def _generate_position(self, snake_positions):\n", 152 | " \"\"\"生成食物位置,确保不与蛇身重叠\"\"\"\n", 153 | " while True:\n", 154 | " x = random.randrange(0, SCREEN_WIDTH, BLOCK_SIZE)\n", 155 | " y = random.randrange(0, SCREEN_HEIGHT, BLOCK_SIZE)\n", 156 | " self.position = (x, y)\n", 157 | " # 如果没有传入蛇身位置,或位置不在蛇身上,则跳出循环\n", 158 | " if snake_positions is None or self.position not in snake_positions:\n", 159 | " break\n", 160 | "\n", 161 | " def get_position(self):\n", 162 | " \"\"\"获得食物坐标\"\"\"\n", 163 | " return self.position\n", 164 | "\n", 165 | " def draw(self, surface):\n", 166 | " \"\"\"在画布上绘制食物\"\"\"\n", 167 | " r = pygame.Rect((self.position[0], self.position[1]), (BLOCK_SIZE, BLOCK_SIZE))\n", 168 | " pygame.draw.rect(surface, self.color, r)\n", 169 | " pygame.draw.rect(surface, BLACK, r, 1) # 绘制边框\n", 170 | "\n", 171 | " def respawn(self, snake_positions=None):\n", 172 | " \"\"\"重新生成食物位置\"\"\"\n", 173 | " self._generate_position(snake_positions)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "f1835258-8a2e-4687-b213-a60011d432df", 179 | "metadata": {}, 180 | "source": [ 181 | "### 构建训练蛇的神经网络" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "id": "7bd6e615-bb2c-4aa3-9be3-8a9686325ae5", 187 | "metadata": {}, 188 | "source": [ 189 | "#### Q 值是什么?\n", 190 | "Q 值和 return 是强化学习中的两个重要概念,但它们并不相同。\n", 191 | "- Q 值表示在某个状态下采取某个动作后,未来可能获得的总奖励的期望值(模型预测得到)。\n", 192 | "- Return 是从当前时刻开始,未来所有奖励的总和(历史经验数据得到)。\n", 193 | "- Q 值用于指导策略(选择动作),而 return 用于评估策略的性能。\n", 194 | "\n", 195 | "\n", 196 | "#### 目标策略网络与当前策略模型的区别与联系,为什么要用两个模型?\n", 197 | "在Deep Q Learning(DQN)中,使用两个神经网络模型是为了解决训练不稳定的问题。 \n", 198 | "- 当前策略网络 (self.model):用于预测当前状态的 Q 值,并通过优化器更新权重。\n", 199 | "- 目标策略网络 (self.target_model):用于计算下一个状态的 Q 值。\n", 200 | "- 目标策略网络的权重是从当前策略网络定期同步的,因此目标 Q 值在一段时间内是稳定的,这有助于模型更好地收敛。" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 5, 206 | "id": "f8b44783-c1b2-49da-873f-cfcfe49744b7", 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "import 
numpy as np\n", 211 | "import torch\n", 212 | "import torch.nn as nn\n", 213 | "import torch.optim as optim\n", 214 | "from collections import deque\n", 215 | "import os\n", 216 | "\n", 217 | "class SnakeAI:\n", 218 | " def __init__(self, buffer_size=1000, batch_size=32):\n", 219 | " \"\"\"\n", 220 | " 初始化 SnakeAI 类。\n", 221 | " 参数:\n", 222 | " buffer_size: 经验回放缓冲区的大小。\n", 223 | " batch_size: 每次训练时从缓冲区中采样的批次大小。\n", 224 | " \"\"\"\n", 225 | " # 设置参数\n", 226 | " self.gamma = 0.99 # 折扣因子,用于计算未来奖励的重要性\n", 227 | " self.input_size = 12 # 输入状态的维度(Game类get_state()中定义)\n", 228 | " self.output_size = 4 # 输出动作的维度(上下左右四个方向)\n", 229 | " self.hidden_size = 100 # 神经网络隐藏层的大小\n", 230 | " self.batch_size = batch_size # 训练批次大小\n", 231 | " self.update_freq = 1000 # 目标网络更新频率\n", 232 | " self.train_steps = 0 # 训练步数计数器\n", 233 | " self.epsilon = 1.0 # 初始 epsilon epsilon表示蛇移动的随机性\n", 234 | " self.epsilon_min = 0.01 # 最小 epsilon\n", 235 | " self.epsilon_decay = 0.995 # 衰减率\n", 236 | "\n", 237 | " # 创建神经网络模型\n", 238 | " self.model = self.build_model() # 当前策略网络,用于预测动作\n", 239 | " self.target_model = self.build_model() # 目标策略网络,用于计算训练目标\n", 240 | " # 检查预训练权重文件是否存在\n", 241 | " weight_path = '../models/5_RL_Snake/best_weights.pth'\n", 242 | " if os.path.exists(weight_path):\n", 243 | " self.model.load_state_dict(torch.load(weight_path))\n", 244 | " self.target_model.load_state_dict(torch.load(weight_path))\n", 245 | " print(f\"Loaded pre-trained weights from {weight_path}\")\n", 246 | " \n", 247 | " self.optimizer = optim.Adam(self.model.parameters(), lr=0.001) # 优化器,用于更新模型参数\n", 248 | " self.criterion = nn.MSELoss() # 损失函数,用于计算预测值与目标值的差距\n", 249 | "\n", 250 | " # 经验回放缓冲区,用于存储游戏经验(状态、动作、奖励等)\n", 251 | " self.buffer = deque(maxlen=buffer_size)\n", 252 | "\n", 253 | " # 同步目标网络的权重,确保初始时目标网络与当前网络一致\n", 254 | " self.update_target_model()\n", 255 | "\n", 256 | " def build_model(self):\n", 257 | " \"\"\"\n", 258 | " 构建神经网络模型。\n", 259 | " 返回:\n", 260 | " 一个包含输入层、隐藏层和输出层的神经网络模型。\n", 261 | " \"\"\"\n", 262 | " model = nn.Sequential(\n", 263 | " nn.Linear(self.input_size, self.hidden_size), # 输入层到隐藏层\n", 264 | " nn.ReLU(), # 激活函数\n", 265 | " nn.Linear(self.hidden_size, self.hidden_size), # 隐藏层到隐藏层\n", 266 | " nn.ReLU(), # 激活函数\n", 267 | " nn.Linear(self.hidden_size, self.hidden_size), # 隐藏层到隐藏层\n", 268 | " nn.ReLU(), # 激活函数\n", 269 | " nn.Linear(self.hidden_size, self.output_size), # 隐藏层到输出层\n", 270 | " )\n", 271 | " return model\n", 272 | "\n", 273 | " def update_target_model(self):\n", 274 | " \"\"\"\n", 275 | " 更新目标策略网络的权重。\n", 276 | " 将当前策略网络的权重复制到目标策略网络中。\n", 277 | " \"\"\"\n", 278 | " self.target_model.load_state_dict(self.model.state_dict())\n", 279 | "\n", 280 | " def get_action(self, state):\n", 281 | " \"\"\"\n", 282 | " 根据当前状态选择动作。\n", 283 | " 使用 epsilon-greedy 策略,平衡探索与利用\n", 284 | " 参数:\n", 285 | " state: 当前游戏状态。\n", 286 | " epsilon: 探索概率,用于控制随机探索与利用的平衡。\n", 287 | " 返回:\n", 288 | " 选择的动作。\n", 289 | " \"\"\"\n", 290 | " if random.random() < self.epsilon:\n", 291 | " # 随机选择一个动作(探索)\n", 292 | " return random.randint(0, self.output_size - 1)\n", 293 | " else:\n", 294 | " # 使用模型预测动作(利用)\n", 295 | " state = torch.FloatTensor(state).unsqueeze(0) # 将状态转换为张量\n", 296 | " with torch.no_grad():\n", 297 | " q_values = self.model(state) # 获取 Q 值\n", 298 | " return torch.argmax(q_values).item() # 选择 Q 值最大的动作\n", 299 | "\n", 300 | " def train_model(self):\n", 301 | " \"\"\"\n", 302 | " 使用经验回放进行模型训练。\n", 303 | " 从缓冲区中随机采样一个批次的数据,计算损失并更新模型。\n", 304 | " \"\"\"\n", 305 | " if len(self.buffer) < self.batch_size:\n", 306 | " return # 
如果缓冲区中的数据不足,则跳过训练\n", 307 | "\n", 308 | " # 从缓冲区中随机采样一个批次的数据\n", 309 | " batch = random.sample(self.buffer, self.batch_size)\n", 310 | "\n", 311 | " # 解析批次数据\n", 312 | " states = torch.FloatTensor([sample[0] for sample in batch]) # 当前状态\n", 313 | " actions = torch.LongTensor([sample[1] for sample in batch]) # 执行的动作\n", 314 | " rewards = torch.FloatTensor([sample[2] for sample in batch]) # 获得的奖励\n", 315 | " next_states = torch.FloatTensor([sample[3] for sample in batch]) # 下一个状态\n", 316 | " dones = torch.FloatTensor([sample[4] for sample in batch]) # 是否结束\n", 317 | "\n", 318 | " # 计算当前 Q 值\n", 319 | " current_q_values = self.model(states).gather(1, actions.unsqueeze(1))\n", 320 | "\n", 321 | " # 计算目标 Q 值\n", 322 | " with torch.no_grad():\n", 323 | " next_q_values = self.target_model(next_states).max(1)[0]\n", 324 | " target_q_values = rewards + self.gamma * next_q_values * (1 - dones)\n", 325 | "\n", 326 | " # 计算损失并更新模型\n", 327 | " loss = self.criterion(current_q_values.squeeze(), target_q_values)\n", 328 | " self.optimizer.zero_grad() # 清空梯度\n", 329 | " loss.backward() # 反向传播\n", 330 | " self.optimizer.step() # 更新模型参数\n", 331 | "\n", 332 | " # 更新目标网络\n", 333 | " self.train_steps += 1\n", 334 | " if self.train_steps % self.update_freq == 0:\n", 335 | " self.update_target_model()\n", 336 | " # 更新epsilon\n", 337 | " if self.epsilon > self.epsilon_min:\n", 338 | " self.epsilon *= self.epsilon_decay\n", 339 | "\n", 340 | " def add_experience(self, state, action, reward, next_state, done):\n", 341 | " \"\"\"\n", 342 | " 将经验添加到经验回放缓冲区中。\n", 343 | " 参数:\n", 344 | " state: 当前状态。\n", 345 | " action: 执行的动作。\n", 346 | " reward: 获得的奖励。\n", 347 | " next_state: 下一个状态。\n", 348 | " done: 是否结束。\n", 349 | " \"\"\"\n", 350 | " self.buffer.append((state, action, reward, next_state, done))" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "id": "5c4707d2-d15d-41c0-97f4-f843912ab8e3", 356 | "metadata": {}, 357 | "source": [ 358 | "### 主游戏逻辑" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "id": "c67db4bc-c56c-45c5-87b5-b17340e10f30", 365 | "metadata": {}, 366 | "outputs": [ 367 | { 368 | "name": "stderr", 369 | "output_type": "stream", 370 | "text": [ 371 | "C:\\Users\\PC\\AppData\\Local\\Temp\\ipykernel_3160\\2476721932.py:34: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", 372 | " self.model.load_state_dict(torch.load(weight_path))\n", 373 | "C:\\Users\\PC\\AppData\\Local\\Temp\\ipykernel_3160\\2476721932.py:35: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. 
It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", 374 | " self.target_model.load_state_dict(torch.load(weight_path))\n" 375 | ] 376 | }, 377 | { 378 | "name": "stdout", 379 | "output_type": "stream", 380 | "text": [ 381 | "Loaded pre-trained weights from ../models/5_RL_Snake/best_weights.pth\n", 382 | "Episode 1/50, Score: 0, Best: 0\n" 383 | ] 384 | }, 385 | { 386 | "name": "stderr", 387 | "output_type": "stream", 388 | "text": [ 389 | "C:\\Users\\PC\\AppData\\Local\\Temp\\ipykernel_3160\\2476721932.py:103: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\torch\\csrc\\utils\\tensor_new.cpp:281.)\n", 390 | " states = torch.FloatTensor([sample[0] for sample in batch]) # 当前状态\n" 391 | ] 392 | }, 393 | { 394 | "name": "stdout", 395 | "output_type": "stream", 396 | "text": [ 397 | "Episode 2/50, Score: 1, Best: 1\n", 398 | "Episode 3/50, Score: 1, Best: 1\n", 399 | "Episode 4/50, Score: 1, Best: 1\n", 400 | "Episode 5/50, Score: 1, Best: 1\n", 401 | "Episode 6/50, Score: 35, Best: 35\n", 402 | "Episode 7/50, Score: 17, Best: 35\n", 403 | "Episode 8/50, Score: 11, Best: 35\n", 404 | "Episode 9/50, Score: 1, Best: 35\n", 405 | "Episode 10/50, Score: 22, Best: 35\n", 406 | "Episode 11/50, Score: 42, Best: 42\n", 407 | "Episode 12/50, Score: 20, Best: 42\n", 408 | "Episode 13/50, Score: 6, Best: 42\n", 409 | "Episode 14/50, Score: 13, Best: 42\n", 410 | "Episode 15/50, Score: 40, Best: 42\n", 411 | "Episode 16/50, Score: 31, Best: 42\n" 412 | ] 413 | } 414 | ], 415 | "source": [ 416 | "class Game:\n", 417 | " def __init__(self, buffer_size=10000, batch_size=64):\n", 418 | " \"\"\"初始化游戏\"\"\"\n", 419 | " pygame.init()\n", 420 | " self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))\n", 421 | " pygame.display.set_caption(\"Snake AI Training\")\n", 422 | " self.clock = pygame.time.Clock()\n", 423 | " self.snake = Snake()\n", 424 | " self.food = Food(self.snake.positions) # 确保食物不生成在蛇身上\n", 425 | " self.ai_player = SnakeAI(buffer_size=10000, batch_size=64)\n", 426 | " self.score = 0\n", 427 | " self.best_score = 0\n", 428 | " self.scores = []\n", 429 | " self.steps = 0 # 未吃到食物的累计步数\n", 430 | " # 初始化字体\n", 431 | " self.font = pygame.font.SysFont(\"Arial\", 24) # 使用 Arial 字体,大小 24\n", 432 | "\n", 433 | " def get_direction(self, action):\n", 434 | " \"\"\"将动作索引转换为方向\"\"\"\n", 435 | " directions = [(0, -1), (0, 1), (-1, 0), (1, 0)] # 上、下、左、右\n", 436 | " return directions[action]\n", 437 | "\n", 438 | " def get_state(self):\n", 439 | " \"\"\"获取当前游戏状态\"\"\"\n", 440 | " head = self.snake.get_head_position()\n", 441 | " food = self.food.position\n", 442 | "\n", 443 | " left = (head[0] - BLOCK_SIZE, head[1])\n", 444 | " right = 
(head[0] + BLOCK_SIZE, head[1])\n", 445 | " up = (head[0], head[1] - BLOCK_SIZE)\n", 446 | " down = (head[0], head[1] + BLOCK_SIZE)\n", 447 | "\n", 448 | " # 检查边界\n", 449 | " danger_left = left[0] < 0 or left in self.snake.positions[1:]\n", 450 | " danger_right = right[0] >= SCREEN_WIDTH or right in self.snake.positions[1:]\n", 451 | " danger_up = up[1] < 0 or up in self.snake.positions[1:]\n", 452 | " danger_down = down[1] >= SCREEN_HEIGHT or down in self.snake.positions[1:]\n", 453 | "\n", 454 | " state = [\n", 455 | " danger_left, danger_right, danger_up, danger_down, # 四个方向的危险\n", 456 | " food[0] < head[0], food[0] > head[0], # 食物相对位置(左右)\n", 457 | " food[1] < head[1], food[1] > head[1], # 食物相对位置(上下)\n", 458 | " self.snake.direction == (0, -1), # 当前方向(上)\n", 459 | " self.snake.direction == (0, 1), # 当前方向(下)\n", 460 | " self.snake.direction == (-1, 0), # 当前方向(左)\n", 461 | " self.snake.direction == (1, 0) # 当前方向(右)\n", 462 | " ]\n", 463 | " return np.array(state, dtype=np.float32)\n", 464 | "\n", 465 | " def update(self):\n", 466 | " \"\"\"更新游戏状态和AI训练\"\"\"\n", 467 | " state = self.get_state()\n", 468 | " action = self.ai_player.get_action(state)\n", 469 | " old_direction = self.snake.direction\n", 470 | " new_direction = self.get_direction(action)\n", 471 | " \n", 472 | " # 更新蛇的方向\n", 473 | " self.snake.turn(new_direction)\n", 474 | "\n", 475 | " # 与食物的距离\n", 476 | " old_distance = np.sqrt(np.sum((np.array(self.snake.get_head_position()) - np.array(self.food.position)) ** 2))\n", 477 | " self.snake.move()\n", 478 | " \n", 479 | " # 检查游戏结束条件\n", 480 | " done = False\n", 481 | " reward = 0\n", 482 | " \n", 483 | " # 吃到食物\n", 484 | " if self.snake.get_head_position() == self.food.position:\n", 485 | " self.steps = 0 # 重置步数计数器\n", 486 | " self.score += 1\n", 487 | " self.snake.grow()\n", 488 | " self.food.respawn(self.snake.positions) # 重新生成食物\n", 489 | " reward += 10\n", 490 | " # 撞墙或撞自己\n", 491 | " elif not self.snake.check_is_alive():\n", 492 | " self.scores.append(self.snake.length)\n", 493 | " done = True\n", 494 | " reward -= 20\n", 495 | " # 计算距离变化的奖励\n", 496 | " else:\n", 497 | " new_distance = np.sqrt(np.sum((np.array(self.snake.get_head_position()) - np.array(self.food.position)) ** 2))\n", 498 | " reward += 0.2 if new_distance < old_distance else -0.1\n", 499 | " if self.steps > 10: # 长时间未吃到食物\n", 500 | " reward -= 0.1\n", 501 | "\n", 502 | " next_state = self.get_state()\n", 503 | " self.ai_player.add_experience(state, action, reward, next_state, done)\n", 504 | " self.ai_player.train_model() \n", 505 | " \n", 506 | " return done\n", 507 | "\n", 508 | " def run(self):\n", 509 | " \"\"\"主游戏循环\"\"\"\n", 510 | " for episode in range(POP_SIZE):\n", 511 | " self.snake.reset()\n", 512 | " self.food = Food(self.snake.positions)\n", 513 | " self.score = 0\n", 514 | " self.steps = 0\n", 515 | " done = False\n", 516 | "\n", 517 | " while not done:\n", 518 | " for event in pygame.event.get():\n", 519 | " if event.type == pygame.QUIT:\n", 520 | " pygame.quit()\n", 521 | " return\n", 522 | "\n", 523 | " self.steps += 1\n", 524 | " done = self.update()\n", 525 | " \n", 526 | " # 渲染画面\n", 527 | " self.screen.fill(BLACK)\n", 528 | " self.snake.draw(self.screen)\n", 529 | " self.food.draw(self.screen)\n", 530 | " pygame.display.flip()\n", 531 | " # 渲染分数\n", 532 | " score_text = self.font.render(f\"Score: {self.score}\", True, WHITE) # 白色文字\n", 533 | " self.screen.blit(score_text, (10, 10)) # 显示在左上角 (10, 10)\n", 534 | " pygame.display.flip()\n", 535 | " \n", 536 | " # 控制帧率(训练时可以更快,观看时调慢)\n", 537 
| " # self.clock.tick(100 if episode < POP_SIZE - 1 else 10)\n", 538 | " self.clock.tick(80)\n", 539 | "\n", 540 | " # 更新最佳分数并保存模型\n", 541 | " if self.score > self.best_score:\n", 542 | " self.best_score = self.score\n", 543 | " torch.save(self.ai_player.model.state_dict(), '../models/5_RL_Snake/best_weights.pth')\n", 544 | " print(f\"Episode {episode + 1}/{POP_SIZE}, Score: {self.score}, Best: {self.best_score}\")\n", 545 | "\n", 546 | " pygame.quit()\n", 547 | "\n", 548 | "if __name__ == \"__main__\":\n", 549 | " game = Game(buffer_size=10000, batch_size=64)\n", 550 | " game.run()" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": null, 556 | "id": "3d4ec8e1-b7c4-41e2-b63d-63b2420de371", 557 | "metadata": {}, 558 | "outputs": [], 559 | "source": [] 560 | } 561 | ], 562 | "metadata": { 563 | "kernelspec": { 564 | "display_name": "Python 3 (ipykernel)", 565 | "language": "python", 566 | "name": "python3" 567 | }, 568 | "language_info": { 569 | "codemirror_mode": { 570 | "name": "ipython", 571 | "version": 3 572 | }, 573 | "file_extension": ".py", 574 | "mimetype": "text/x-python", 575 | "name": "python", 576 | "nbconvert_exporter": "python", 577 | "pygments_lexer": "ipython3", 578 | "version": "3.8.20" 579 | } 580 | }, 581 | "nbformat": 4, 582 | "nbformat_minor": 5 583 | } 584 | -------------------------------------------------------------------------------- /5_RL_Snake/data/AI_Snake.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/5_RL_Snake/data/AI_Snake.gif -------------------------------------------------------------------------------- /5_RL_Snake/readme.md: -------------------------------------------------------------------------------- 1 | 参考文章:[教你如何让 AI 赢得贪吃蛇游戏----强化学习(初探)](https://blog.csdn.net/qq_25218219/article/details/131382451) 2 | 3 | 一开始蛇会随机乱跑,当第一次吃到食物获得正向reward之后,蛇就会变得更加聪明,真正开始学习。 4 | 5 | 6 | 我一直都认为强化学习很难训练,生怕训练不出来,没想到训练几个epoch基本就能有不错的成绩了。可能是这个游戏比较简单吧 -------------------------------------------------------------------------------- /6_Adversarial_Attack/data/accuracy_vs_epsilon_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/6_Adversarial_Attack/data/accuracy_vs_epsilon_plot.png -------------------------------------------------------------------------------- /6_Adversarial_Attack/data/attack_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/6_Adversarial_Attack/data/attack_cat.jpg -------------------------------------------------------------------------------- /6_Adversarial_Attack/data/attack_dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/6_Adversarial_Attack/data/attack_dog.jpg -------------------------------------------------------------------------------- /6_Adversarial_Attack/data/cat.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/6_Adversarial_Attack/data/cat.jpg -------------------------------------------------------------------------------- /6_Adversarial_Attack/data/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/6_Adversarial_Attack/data/demo.jpg -------------------------------------------------------------------------------- /6_Adversarial_Attack/data/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/6_Adversarial_Attack/data/dog.jpg -------------------------------------------------------------------------------- /6_Adversarial_Attack/data/fgsm_examples_grid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/6_Adversarial_Attack/data/fgsm_examples_grid.png -------------------------------------------------------------------------------- /6_Adversarial_Attack/readme.md: -------------------------------------------------------------------------------- 1 | ## 参考文章: 2 | [对抗攻击(Adversarial Attack)](https://blog.csdn.net/ji_meng/article/details/123931315?spm=1001.2014.3001.5501)(基本介绍) 3 | [【对抗攻击代码实战】对抗样本的生成——FGSM](https://blog.csdn.net/ji_meng/article/details/124366646)(分别基于优化和梯度的白盒攻击示例) 4 | 5 | ## 目标: 6 | 对于一个输入样本 $x$,模型的预测为 $f(x)$,攻击者希望通过添加扰动 $\delta$ 得到对抗样本 $x' = x + \delta$,使得 $f(x') \neq f(x)$ 或达到特定的错误分类。 7 | 8 | ## 白盒攻击(white-box attack) 9 | 假设攻击者知道模型结构、权重等,针对性通过梯度生成对抗样本 10 | 为了更好地展示攻击效果,这里并没有使用数据增强的操作,把项目二的训练代码拷贝了过来,删除数据增强的操作,重新训练了一个猫狗分类模型用于攻击。 11 | 需要先运行`train.ipynb`文件,得到训练好的模型,本项目基于此模型进行攻击。 12 | 13 | ### 基于单样例的 FGSM 攻击 (Fast Gradient Sign Method) 14 | 15 | 16 | ### 基于多样例的 PGD 攻击 (Projected Gradient Descent) 17 | 18 | 19 | ### 黑盒攻击(black-box attack) 20 | 21 | 不知道模型信息。攻击者只是一个标准用户,只知道模型的输出(标签或置信度分数) -------------------------------------------------------------------------------- /7_RL_Pacman/RLAgents.py: -------------------------------------------------------------------------------- 1 | from game import Agent 2 | from game import Directions 3 | from pacman import GameState 4 | import random 5 | import os 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from collections import deque, namedtuple 12 | 13 | # 定义常量 14 | BUFFER_SIZE = 10000 # 经验回放缓冲区大小 15 | BATCH_SIZE = 64 # 小批量训练大小 16 | GAMMA = 0.99 # 折扣因子 17 | TAU = 1e-3 # 目标网络软更新参数 18 | LR = 1e-4 # 学习率 19 | UPDATE_EVERY = 10 # 每隔多少步更新网络 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | print(device) 23 | 24 | class DQNetwork(nn.Module): 25 | """Deep Q Network架构""" 26 | 27 | def __init__(self, state_size, action_size, seed=0): 28 | """初始化参数和构建模型""" 29 | super(DQNetwork, self).__init__() 30 | self.seed = torch.manual_seed(seed) 31 | 32 | # 定义网络层 33 | self.fc1 = nn.Linear(state_size, 256) 34 | self.bn1 = nn.BatchNorm1d(256) 35 | self.dropout1 = nn.Dropout(0.2) 36 | 37 | self.fc2 = nn.Linear(256, 256) 38 | self.bn2 = nn.BatchNorm1d(256) 39 | self.dropout2 = nn.Dropout(0.2) 40 | 41 | self.fc3 = nn.Linear(256, 128) 42 | self.bn3 = 
nn.BatchNorm1d(128) 43 | self.dropout3 = nn.Dropout(0.2) 44 | 45 | self.fc4 = nn.Linear(128, action_size) 46 | 47 | def forward(self, state): 48 | x = F.relu(self.bn1(self.fc1(state))) 49 | x = self.dropout1(x) 50 | x = F.relu(self.bn2(self.fc2(x))) 51 | x = self.dropout2(x) 52 | x = F.relu(self.bn3(self.fc3(x))) 53 | x = self.dropout3(x) 54 | return self.fc4(x) 55 | 56 | class ReplayBuffer: 57 | """固定大小的经验回放缓冲区""" 58 | 59 | def __init__(self, buffer_size, batch_size, seed=0): 60 | """初始化经验回放缓冲区""" 61 | self.memory = deque(maxlen=buffer_size) 62 | self.batch_size = batch_size 63 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) 64 | self.seed = random.seed(seed) 65 | 66 | def add(self, state, action, reward, next_state, done): 67 | """添加新经验到缓冲区""" 68 | e = self.experience(state, action, reward, next_state, done) 69 | self.memory.append(e) 70 | 71 | def sample(self): 72 | """随机抽取一批经验""" 73 | experiences = random.sample(self.memory, k=self.batch_size) 74 | 75 | states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device) 76 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device) 77 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device) 78 | next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device) 79 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device) 80 | 81 | return (states, actions, rewards, next_states, dones) 82 | 83 | def __len__(self): 84 | """返回缓冲区当前大小""" 85 | return len(self.memory) 86 | 87 | class RLAgent(Agent): 88 | """ 89 | 基于DQN的强化学习Pacman智能体 90 | """ 91 | 92 | def __init__(self, index=0, training_mode=True, load_model='../models/7_RL_Pacman/pacman_rl_model.pth'): 93 | """初始化智能体""" 94 | super(RLAgent, self).__init__() 95 | 96 | self.index = index 97 | self.training_mode = training_mode # 是否处于训练模式 98 | 99 | # 定义行动空间 100 | self.actions = [Directions.NORTH, Directions.SOUTH, Directions.EAST, Directions.WEST, Directions.STOP] 101 | self.action_size = len(self.actions) 102 | 103 | # 状态大小(根据特征提取函数确定) 104 | self.state_features = self._extract_features(None) 105 | self.state_size = len(self.state_features) 106 | print(self.state_size) 107 | 108 | # 初始化Q网络和目标Q网络 109 | self.qnetwork_local = DQNetwork(self.state_size, self.action_size).to(device) 110 | self.qnetwork_target = DQNetwork(self.state_size, self.action_size).to(device) 111 | 112 | # 如果提供了模型路径,加载模型 113 | if load_model and os.path.exists(load_model): 114 | self.qnetwork_local.load_state_dict(torch.load(load_model)) 115 | self.qnetwork_target.load_state_dict(torch.load(load_model)) 116 | print(f"Loaded model from {load_model}") 117 | 118 | # 初始化优化器 119 | self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR) 120 | 121 | # 初始化经验回放缓冲区 122 | self.memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE) 123 | 124 | # 初始化时间步 125 | self.t_step = 0 126 | 127 | # 探索参数(增加agent的探索能力) 128 | self.epsilon = 1.0 if not training_mode else 1.0 # 初始探索率 129 | self.epsilon_decay = 0.995 # 探索率衰减 130 | self.epsilon_min = 0.05 # 最小探索率 131 | 132 | # 跟踪上一个状态、动作和奖励 133 | self.last_state = None 134 | self.last_action = None 135 | self.last_score = 0 136 | 137 | def getAction(self, state): 138 | """根据当前游戏状态返回行动""" 139 | # 提取状态特征 140 | current_state = self._extract_features(state) 141 | 142 | # 获取合法动作 143 | legal = 
state.getLegalActions(self.index) 144 | 145 | # 如果是训练模式的第一次调用,初始化last_state 146 | if self.training_mode and self.last_state is None: 147 | self.last_state = current_state 148 | self.last_score = state.getScore() 149 | 150 | # 选择动作 151 | action_idx = self._select_action(current_state) 152 | move = self.actions[action_idx] 153 | 154 | # 如果选择的动作不合法,随机选择一个合法动作 155 | if move not in legal: 156 | if len(legal) > 0: 157 | move = random.choice(legal) 158 | else: 159 | move = Directions.STOP 160 | 161 | # 如果在训练模式下并且不是第一步 162 | if self.training_mode and self.last_action is not None: 163 | # 计算奖励(当前分数与上一步分数的差值) 164 | current_score = state.getScore() 165 | reward = current_score - self.last_score 166 | 167 | # 判断游戏是否结束 168 | done = state.isWin() or state.isLose() 169 | 170 | # 将经验添加到回放缓冲区(状态、动作、奖励、下一状态、是否结束) 171 | self.memory.add(self.last_state, 172 | self.actions.index(self.last_action), 173 | reward, 174 | current_state, 175 | done) 176 | 177 | # 定期学习 178 | self._learn() 179 | 180 | # 在游戏结束时进行额外的学习并保存模型参数 181 | if done: 182 | # if len(self.memory) > BATCH_SIZE: 183 | # for _ in range(10): # 多学习几次 184 | # self._learn() 185 | self.final(state) 186 | 187 | return move 188 | 189 | # 更新状态 190 | self.last_state = current_state 191 | self.last_action = move 192 | self.last_score = state.getScore() 193 | 194 | return move 195 | 196 | def _select_action(self, state): 197 | """选择动作(epsilon贪婪策略)""" 198 | # 将状态转换为tensor 199 | state = torch.from_numpy(state).float().unsqueeze(0).to(device) 200 | 201 | # epsilon策略 202 | if random.random() > self.epsilon: 203 | self.qnetwork_local.eval() 204 | with torch.no_grad(): 205 | action_values = self.qnetwork_local(state) 206 | self.qnetwork_local.train() 207 | 208 | # 选择最佳动作 209 | return torch.argmax(action_values, dim=1).squeeze().cpu().numpy() 210 | 211 | else: 212 | # 随机选择动作 213 | return random.randrange(self.action_size) 214 | 215 | def _learn(self): 216 | """从经验中学习,更新Q网络""" 217 | # 增加时间步并检查是否应该学习 218 | self.t_step = (self.t_step + 1) % UPDATE_EVERY 219 | if self.t_step != 0 or len(self.memory) < BATCH_SIZE: 220 | return 221 | 222 | # 从经验回放缓冲区中采样 223 | experiences = self.memory.sample() 224 | states, actions, rewards, next_states, dones = experiences 225 | 226 | # 获取目标值 227 | self.qnetwork_target.eval() 228 | with torch.no_grad(): 229 | Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1) 230 | self.qnetwork_target.train() 231 | 232 | # 计算目标Q值 233 | Q_targets = rewards + (GAMMA * Q_targets_next * (1 - dones)) 234 | 235 | # 获取当前Q值估计 236 | Q_expected = self.qnetwork_local(states).gather(1, actions) 237 | 238 | # 使用Huber损失函数 239 | loss = F.smooth_l1_loss(Q_expected, Q_targets) 240 | 241 | # 优化模型 242 | self.optimizer.zero_grad() 243 | loss.backward() 244 | self.optimizer.step() 245 | 246 | # 软更新目标网络 247 | self._soft_update(self.qnetwork_local, self.qnetwork_target) 248 | 249 | # 衰减探索率 250 | self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay) 251 | 252 | def _soft_update(self, local_model, target_model): 253 | """软更新目标网络参数""" 254 | for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): 255 | target_param.data.copy_(TAU * local_param.data + (1.0 - TAU) * target_param.data) 256 | 257 | def _extract_features(self, state): 258 | """从游戏状态提取特征向量""" 259 | if state is None: 260 | # 使用初始状态来确定特征维度 261 | layout_name = 'mediumClassic' # 可以根据需要修改为其他布局 262 | from layout import getLayout 263 | layout = getLayout(layout_name) 264 | 265 | init_state = GameState() 266 | init_state.initialize(layout, 2) # 
假设有2个幽灵 267 | return self._extract_features(init_state) 268 | 269 | # 获取地图尺寸 270 | walls = state.getWalls() 271 | width, height = walls.width, walls.height 272 | 273 | # 创建一个二维矩阵表示地图状态 274 | # 0: blank, 1: walls, 2: food, 3: capsules, 4: Pacman, 5: ghosts 275 | grid_state = np.zeros((width, height)) 276 | 277 | # 填充墙壁 278 | for x in range(width): 279 | for y in range(height): 280 | if walls[x][y]: 281 | grid_state[x][y] = 1 282 | 283 | # 填充食物 284 | food = state.getFood() 285 | for x in range(width): 286 | for y in range(height): 287 | if food[x][y]: 288 | grid_state[x][y] = 2 289 | 290 | # 填充胶囊 291 | capsules = state.getCapsules() 292 | for x, y in capsules: 293 | grid_state[int(x)][int(y)] = 3 294 | 295 | # 填充Pacman 296 | pacman_x, pacman_y = state.getPacmanPosition() 297 | grid_state[int(pacman_x)][int(pacman_y)] = 4 298 | 299 | # 填充幽灵 300 | ghost_states = state.getGhostStates() 301 | for ghost in ghost_states: 302 | ghost_x, ghost_y = ghost.getPosition() 303 | grid_state[int(ghost_x)][int(ghost_y)] = 5 304 | 305 | # 展平为一维向量 306 | grid_features = grid_state.flatten() 307 | 308 | # 添加额外的非空间特征 309 | 310 | # 1. 得分 311 | score_enc = np.array([state.getScore()]) 312 | 313 | # 2. 剩余食物数量 314 | food_count_enc = np.array([state.getNumFood()]) 315 | 316 | # 3. 剩余胶囊数量 317 | capsule_count_enc = np.array([len(capsules)]) # 假设最多4个胶囊 318 | 319 | # 将所有特征连接成一个向量 320 | features = np.concatenate([ 321 | grid_features, # 地图状态 322 | score_enc, # 得分 323 | food_count_enc, # 剩余食物数量 324 | capsule_count_enc, # 剩余胶囊数量 325 | ]) 326 | 327 | return features.astype(np.float32) 328 | 329 | def save_model(self, filename): 330 | """保存模型""" 331 | torch.save(self.qnetwork_local.state_dict(), filename) 332 | print(f"Model saved to {filename}") 333 | 334 | def final(self, state): 335 | """游戏结束时调用""" 336 | if self.training_mode: 337 | # 保存模型 338 | self.save_model('../models/7_RL_Pacman/pacman_rl_model.pth') 339 | 340 | 341 | # 训练模式智能体 342 | class TrainingRLAgent(RLAgent): 343 | def __init__(self, index=0): 344 | super(TrainingRLAgent, self).__init__(index, training_mode=True) 345 | 346 | # 测试模式智能体(加载训练好的模型) 347 | class TestedRLAgent(RLAgent): 348 | def __init__(self, index=0): 349 | # 加载训练好的模型 350 | super(TestedRLAgent, self).__init__(index, training_mode=False, load_model='../models/7_RL_Pacman/pacman_rl_model.pth') 351 | -------------------------------------------------------------------------------- /7_RL_Pacman/data/example_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/7_RL_Pacman/data/example_image.png -------------------------------------------------------------------------------- /7_RL_Pacman/data/loadAgent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/king-wang123/PyTorch-DeepLearning-Practice-Projects/0a8a6987a505960dd92806f116a6b9ddd7600be1/7_RL_Pacman/data/loadAgent.png -------------------------------------------------------------------------------- /7_RL_Pacman/game.py: -------------------------------------------------------------------------------- 1 | # game.py 2 | # ------- 3 | # Licensing Information: You are free to use or extend these projects for 4 | # educational purposes provided that (1) you do not distribute or publish 5 | # solutions, (2) you retain this notice, and (3) you provide clear 6 | # attribution to UC Berkeley, including a link to 
http://ai.berkeley.edu. 7 | # 8 | # Attribution Information: The Pacman AI projects were developed at UC Berkeley. 9 | # The core projects and autograders were primarily created by John DeNero 10 | # (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 11 | # Student side autograding was added by Brad Miller, Nick Hay, and 12 | # Pieter Abbeel (pabbeel@cs.berkeley.edu). 13 | 14 | 15 | # game.py 16 | # ------- 17 | # Licensing Information: Please do not distribute or publish solutions to this 18 | # project. You are free to use and extend these projects for educational 19 | # purposes. The Pacman AI projects were developed at UC Berkeley, primarily by 20 | # John DeNero (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 21 | # For more info, see http://inst.eecs.berkeley.edu/~cs188/sp09/pacman.html 22 | 23 | from util import * 24 | import time, os 25 | import traceback 26 | import sys 27 | 28 | ####################### 29 | # Parts worth reading # 30 | ####################### 31 | 32 | class Agent: 33 | """ 34 | An agent must define a getAction method, but may also define the 35 | following methods which will be called if they exist: 36 | 37 | def registerInitialState(self, state): # inspects the starting state 38 | """ 39 | def __init__(self, index=0): 40 | self.index = index 41 | 42 | def getAction(self, state): 43 | """ 44 | The Agent will receive a GameState (from either {pacman, capture, sonar}.py) and 45 | must return an action from Directions.{North, South, East, West, Stop} 46 | """ 47 | raiseNotDefined() 48 | 49 | class Directions: 50 | NORTH = 'North' 51 | SOUTH = 'South' 52 | EAST = 'East' 53 | WEST = 'West' 54 | STOP = 'Stop' 55 | 56 | LEFT = {NORTH: WEST, 57 | SOUTH: EAST, 58 | EAST: NORTH, 59 | WEST: SOUTH, 60 | STOP: STOP} 61 | 62 | RIGHT = dict([(y,x) for x, y in LEFT.items()]) 63 | 64 | REVERSE = {NORTH: SOUTH, 65 | SOUTH: NORTH, 66 | EAST: WEST, 67 | WEST: EAST, 68 | STOP: STOP} 69 | 70 | class Configuration: 71 | """ 72 | A Configuration holds the (x,y) coordinate of a character, along with its 73 | traveling direction. 74 | 75 | The convention for positions, like a graph, is that (0,0) is the lower left corner, x increases 76 | horizontally and y increases vertically. Therefore, north is the direction of increasing y, or (0,1). 77 | """ 78 | 79 | def __init__(self, pos, direction): 80 | self.pos = pos 81 | self.direction = direction 82 | 83 | def getPosition(self): 84 | return (self.pos) 85 | 86 | def getDirection(self): 87 | return self.direction 88 | 89 | def isInteger(self): 90 | x,y = self.pos 91 | return x == int(x) and y == int(y) 92 | 93 | def __eq__(self, other): 94 | if other == None: return False 95 | return (self.pos == other.pos and self.direction == other.direction) 96 | 97 | def __hash__(self): 98 | x = hash(self.pos) 99 | y = hash(self.direction) 100 | return hash(x + 13 * y) 101 | 102 | def __str__(self): 103 | return "(x,y)="+str(self.pos)+", "+str(self.direction) 104 | 105 | def generateSuccessor(self, vector): 106 | """ 107 | Generates a new configuration reached by translating the current 108 | configuration by the action vector. This is a low-level call and does 109 | not attempt to respect the legality of the movement. 110 | 111 | Actions are movement vectors. 
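For example, a Configuration at (1, 1) translated by the vector (0, 1) yields a new Configuration at (1, 2) facing North, while the (0, 0) stop vector keeps the current heading.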
112 | """ 113 | x, y= self.pos 114 | dx, dy = vector 115 | direction = Actions.vectorToDirection(vector) 116 | if direction == Directions.STOP: 117 | direction = self.direction # There is no stop direction 118 | return Configuration((x + dx, y+dy), direction) 119 | 120 | class AgentState: 121 | """ 122 | AgentStates hold the state of an agent (configuration, speed, scared, etc). 123 | """ 124 | 125 | def __init__( self, startConfiguration, isPacman ): 126 | self.start = startConfiguration 127 | self.configuration = startConfiguration 128 | self.isPacman = isPacman 129 | self.scaredTimer = 0 130 | self.numCarrying = 0 131 | self.numReturned = 0 132 | 133 | def __str__( self ): 134 | if self.isPacman: 135 | return "Pacman: " + str( self.configuration ) 136 | else: 137 | return "Ghost: " + str( self.configuration ) 138 | 139 | def __eq__( self, other ): 140 | if other == None: 141 | return False 142 | return self.configuration == other.configuration and self.scaredTimer == other.scaredTimer 143 | 144 | def __hash__(self): 145 | return hash(hash(self.configuration) + 13 * hash(self.scaredTimer)) 146 | 147 | def copy( self ): 148 | state = AgentState( self.start, self.isPacman ) 149 | state.configuration = self.configuration 150 | state.scaredTimer = self.scaredTimer 151 | state.numCarrying = self.numCarrying 152 | state.numReturned = self.numReturned 153 | return state 154 | 155 | def getPosition(self): 156 | if self.configuration == None: return None 157 | return self.configuration.getPosition() 158 | 159 | def getDirection(self): 160 | return self.configuration.getDirection() 161 | 162 | class Grid: 163 | """ 164 | A 2-dimensional array of objects backed by a list of lists. Data is accessed 165 | via grid[x][y] where (x,y) are positions on a Pacman map with x horizontal, 166 | y vertical and the origin (0,0) in the bottom left corner. 167 | 168 | The __str__ method constructs an output that is oriented like a pacman board. 
169 | """ 170 | def __init__(self, width, height, initialValue=False, bitRepresentation=None): 171 | if initialValue not in [False, True]: raise Exception('Grids can only contain booleans') 172 | self.CELLS_PER_INT = 30 173 | 174 | self.width = width 175 | self.height = height 176 | self.data = [[initialValue for y in range(height)] for x in range(width)] 177 | if bitRepresentation: 178 | self._unpackBits(bitRepresentation) 179 | 180 | def __getitem__(self, i): 181 | return self.data[i] 182 | 183 | def __setitem__(self, key, item): 184 | self.data[key] = item 185 | 186 | def __str__(self): 187 | out = [[str(self.data[x][y])[0] for x in range(self.width)] for y in range(self.height)] 188 | out.reverse() 189 | return '\n'.join([''.join(x) for x in out]) 190 | 191 | def __eq__(self, other): 192 | if other == None: return False 193 | return self.data == other.data 194 | 195 | def __hash__(self): 196 | # return hash(str(self)) 197 | base = 1 198 | h = 0 199 | for l in self.data: 200 | for i in l: 201 | if i: 202 | h += base 203 | base *= 2 204 | return hash(h) 205 | 206 | def copy(self): 207 | g = Grid(self.width, self.height) 208 | g.data = [x[:] for x in self.data] 209 | return g 210 | 211 | def deepCopy(self): 212 | return self.copy() 213 | 214 | def shallowCopy(self): 215 | g = Grid(self.width, self.height) 216 | g.data = self.data 217 | return g 218 | 219 | def count(self, item =True ): 220 | return sum([x.count(item) for x in self.data]) 221 | 222 | def asList(self, key = True): 223 | list = [] 224 | for x in range(self.width): 225 | for y in range(self.height): 226 | if self[x][y] == key: list.append( (x,y) ) 227 | return list 228 | 229 | def packBits(self): 230 | """ 231 | Returns an efficient int list representation 232 | 233 | (width, height, bitPackedInts...) 234 | """ 235 | bits = [self.width, self.height] 236 | currentInt = 0 237 | for i in range(self.height * self.width): 238 | bit = self.CELLS_PER_INT - (i % self.CELLS_PER_INT) - 1 239 | x, y = self._cellIndexToPosition(i) 240 | if self[x][y]: 241 | currentInt += 2 ** bit 242 | if (i + 1) % self.CELLS_PER_INT == 0: 243 | bits.append(currentInt) 244 | currentInt = 0 245 | bits.append(currentInt) 246 | return tuple(bits) 247 | 248 | def _cellIndexToPosition(self, index): 249 | x = index // self.height 250 | y = index % self.height 251 | return x, y 252 | 253 | def _unpackBits(self, bits): 254 | """ 255 | Fills in data from a bit-level representation 256 | """ 257 | cell = 0 258 | for packed in bits: 259 | for bit in self._unpackInt(packed, self.CELLS_PER_INT): 260 | if cell == self.width * self.height: break 261 | x, y = self._cellIndexToPosition(cell) 262 | self[x][y] = bit 263 | cell += 1 264 | 265 | def _unpackInt(self, packed, size): 266 | bools = [] 267 | if packed < 0: raise ValueError("must be a positive integer") 268 | for i in range(size): 269 | n = 2 ** (self.CELLS_PER_INT - i - 1) 270 | if packed >= n: 271 | bools.append(True) 272 | packed -= n 273 | else: 274 | bools.append(False) 275 | return bools 276 | 277 | def reconstituteGrid(bitRep): 278 | if type(bitRep) is not type((1,2)): 279 | return bitRep 280 | width, height = bitRep[:2] 281 | return Grid(width, height, bitRepresentation= bitRep[2:]) 282 | 283 | #################################### 284 | # Parts you shouldn't have to read # 285 | #################################### 286 | 287 | class Actions: 288 | """ 289 | A collection of static methods for manipulating move actions. 
290 | """ 291 | # Directions 292 | _directions = {Directions.NORTH: (0, 1), 293 | Directions.SOUTH: (0, -1), 294 | Directions.EAST: (1, 0), 295 | Directions.WEST: (-1, 0), 296 | Directions.STOP: (0, 0)} 297 | 298 | _directionsAsList = _directions.items() 299 | 300 | TOLERANCE = .001 301 | 302 | def reverseDirection(action): 303 | if action == Directions.NORTH: 304 | return Directions.SOUTH 305 | if action == Directions.SOUTH: 306 | return Directions.NORTH 307 | if action == Directions.EAST: 308 | return Directions.WEST 309 | if action == Directions.WEST: 310 | return Directions.EAST 311 | return action 312 | reverseDirection = staticmethod(reverseDirection) 313 | 314 | def vectorToDirection(vector): 315 | dx, dy = vector 316 | if dy > 0: 317 | return Directions.NORTH 318 | if dy < 0: 319 | return Directions.SOUTH 320 | if dx < 0: 321 | return Directions.WEST 322 | if dx > 0: 323 | return Directions.EAST 324 | return Directions.STOP 325 | vectorToDirection = staticmethod(vectorToDirection) 326 | 327 | def directionToVector(direction, speed = 1.0): 328 | dx, dy = Actions._directions[direction] 329 | return (dx * speed, dy * speed) 330 | directionToVector = staticmethod(directionToVector) 331 | 332 | def getPossibleActions(config, walls): 333 | possible = [] 334 | x, y = config.pos 335 | x_int, y_int = int(x + 0.5), int(y + 0.5) 336 | 337 | # In between grid points, all agents must continue straight 338 | if (abs(x - x_int) + abs(y - y_int) > Actions.TOLERANCE): 339 | return [config.getDirection()] 340 | 341 | for dir, vec in Actions._directionsAsList: 342 | dx, dy = vec 343 | next_y = y_int + dy 344 | next_x = x_int + dx 345 | if not walls[next_x][next_y]: possible.append(dir) 346 | 347 | return possible 348 | 349 | getPossibleActions = staticmethod(getPossibleActions) 350 | 351 | def getLegalNeighbors(position, walls): 352 | x,y = position 353 | x_int, y_int = int(x + 0.5), int(y + 0.5) 354 | neighbors = [] 355 | for dir, vec in Actions._directionsAsList: 356 | dx, dy = vec 357 | next_x = x_int + dx 358 | if next_x < 0 or next_x == walls.width: continue 359 | next_y = y_int + dy 360 | if next_y < 0 or next_y == walls.height: continue 361 | if not walls[next_x][next_y]: neighbors.append((next_x, next_y)) 362 | return neighbors 363 | getLegalNeighbors = staticmethod(getLegalNeighbors) 364 | 365 | def getSuccessor(position, action): 366 | dx, dy = Actions.directionToVector(action) 367 | x, y = position 368 | return (x + dx, y + dy) 369 | getSuccessor = staticmethod(getSuccessor) 370 | 371 | class GameStateData: 372 | """ 373 | 374 | """ 375 | def __init__( self, prevState = None ): 376 | """ 377 | Generates a new data packet by copying information from its predecessor. 
378 | """ 379 | if prevState != None: 380 | self.food = prevState.food.shallowCopy() 381 | self.capsules = prevState.capsules[:] 382 | self.agentStates = self.copyAgentStates( prevState.agentStates ) 383 | self.layout = prevState.layout 384 | self._eaten = prevState._eaten 385 | self.score = prevState.score 386 | 387 | self._foodEaten = None 388 | self._foodAdded = None 389 | self._capsuleEaten = None 390 | self._agentMoved = None 391 | self._lose = False 392 | self._win = False 393 | self.scoreChange = 0 394 | 395 | def deepCopy( self ): 396 | state = GameStateData( self ) 397 | state.food = self.food.deepCopy() 398 | state.layout = self.layout.deepCopy() 399 | state._agentMoved = self._agentMoved 400 | state._foodEaten = self._foodEaten 401 | state._foodAdded = self._foodAdded 402 | state._capsuleEaten = self._capsuleEaten 403 | return state 404 | 405 | def copyAgentStates( self, agentStates ): 406 | copiedStates = [] 407 | for agentState in agentStates: 408 | copiedStates.append( agentState.copy() ) 409 | return copiedStates 410 | 411 | def __eq__( self, other ): 412 | """ 413 | Allows two states to be compared. 414 | """ 415 | if other == None: return False 416 | # TODO Check for type of other 417 | if not self.agentStates == other.agentStates: return False 418 | if not self.food == other.food: return False 419 | if not self.capsules == other.capsules: return False 420 | if not self.score == other.score: return False 421 | return True 422 | 423 | def __hash__( self ): 424 | """ 425 | Allows states to be keys of dictionaries. 426 | """ 427 | for i, state in enumerate( self.agentStates ): 428 | try: 429 | int(hash(state)) 430 | except TypeError as e: 431 | print(e) 432 | #hash(state) 433 | return int((hash(tuple(self.agentStates)) + 13*hash(self.food) + 113* hash(tuple(self.capsules)) + 7 * hash(self.score)) % 1048575 ) 434 | 435 | def __str__( self ): 436 | width, height = self.layout.width, self.layout.height 437 | map = Grid(width, height) 438 | if type(self.food) == type((1,2)): 439 | self.food = reconstituteGrid(self.food) 440 | for x in range(width): 441 | for y in range(height): 442 | food, walls = self.food, self.layout.walls 443 | map[x][y] = self._foodWallStr(food[x][y], walls[x][y]) 444 | 445 | for agentState in self.agentStates: 446 | if agentState == None: continue 447 | if agentState.configuration == None: continue 448 | x,y = [int( i ) for i in nearestPoint( agentState.configuration.pos )] 449 | agent_dir = agentState.configuration.direction 450 | if agentState.isPacman: 451 | map[x][y] = self._pacStr( agent_dir ) 452 | else: 453 | map[x][y] = self._ghostStr( agent_dir ) 454 | 455 | for x, y in self.capsules: 456 | map[x][y] = 'o' 457 | 458 | return str(map) + ("\nScore: %d\n" % self.score) 459 | 460 | def _foodWallStr( self, hasFood, hasWall ): 461 | if hasFood: 462 | return '.' 463 | elif hasWall: 464 | return '%' 465 | else: 466 | return ' ' 467 | 468 | def _pacStr( self, dir ): 469 | if dir == Directions.NORTH: 470 | return 'v' 471 | if dir == Directions.SOUTH: 472 | return '^' 473 | if dir == Directions.WEST: 474 | return '>' 475 | return '<' 476 | 477 | def _ghostStr( self, dir ): 478 | return 'G' 479 | if dir == Directions.NORTH: 480 | return 'M' 481 | if dir == Directions.SOUTH: 482 | return 'W' 483 | if dir == Directions.WEST: 484 | return '3' 485 | return 'E' 486 | 487 | def initialize( self, layout, numGhostAgents ): 488 | """ 489 | Creates an initial game state from a layout array (see layout.py). 
490 | """ 491 | self.food = layout.food.copy() 492 | #self.capsules = [] 493 | self.capsules = layout.capsules[:] 494 | self.layout = layout 495 | self.score = 0 496 | self.scoreChange = 0 497 | 498 | self.agentStates = [] 499 | numGhosts = 0 500 | for isPacman, pos in layout.agentPositions: 501 | if not isPacman: 502 | if numGhosts == numGhostAgents: continue # Max ghosts reached already 503 | else: numGhosts += 1 504 | self.agentStates.append( AgentState( Configuration( pos, Directions.STOP), isPacman) ) 505 | self._eaten = [False for a in self.agentStates] 506 | 507 | try: 508 | import boinc 509 | _BOINC_ENABLED = True 510 | except: 511 | _BOINC_ENABLED = False 512 | 513 | class Game: 514 | """ 515 | The Game manages the control flow, soliciting actions from agents. 516 | """ 517 | 518 | def __init__( self, agents, display, rules, startingIndex=0, muteAgents=False, catchExceptions=False ): 519 | self.agentCrashed = False 520 | self.agents = agents 521 | self.display = display 522 | self.rules = rules 523 | self.startingIndex = startingIndex 524 | self.gameOver = False 525 | self.muteAgents = muteAgents 526 | self.catchExceptions = catchExceptions 527 | self.moveHistory = [] 528 | self.totalAgentTimes = [0 for agent in agents] 529 | self.totalAgentTimeWarnings = [0 for agent in agents] 530 | self.agentTimeout = False 531 | import io 532 | self.agentOutput = [io.StringIO() for agent in agents] 533 | 534 | def getProgress(self): 535 | if self.gameOver: 536 | return 1.0 537 | else: 538 | return self.rules.getProgress(self) 539 | 540 | def _agentCrash( self, agentIndex, quiet=False): 541 | "Helper method for handling agent crashes" 542 | if not quiet: traceback.print_exc() 543 | self.gameOver = True 544 | self.agentCrashed = True 545 | self.rules.agentCrash(self, agentIndex) 546 | 547 | OLD_STDOUT = None 548 | OLD_STDERR = None 549 | 550 | def mute(self, agentIndex): 551 | if not self.muteAgents: return 552 | global OLD_STDOUT, OLD_STDERR 553 | import io 554 | OLD_STDOUT = sys.stdout 555 | OLD_STDERR = sys.stderr 556 | sys.stdout = self.agentOutput[agentIndex] 557 | sys.stderr = self.agentOutput[agentIndex] 558 | 559 | def unmute(self): 560 | if not self.muteAgents: return 561 | global OLD_STDOUT, OLD_STDERR 562 | # Revert stdout/stderr to originals 563 | sys.stdout = OLD_STDOUT 564 | sys.stderr = OLD_STDERR 565 | 566 | 567 | def run( self ): 568 | """ 569 | Main control loop for game play. 570 | """ 571 | self.display.initialize(self.state.data) 572 | self.numMoves = 0 573 | 574 | ###self.display.initialize(self.state.makeObservation(1).data) 575 | # inform learning agents of the game start 576 | for i in range(len(self.agents)): 577 | agent = self.agents[i] 578 | if not agent: 579 | self.mute(i) 580 | # this is a null agent, meaning it failed to load 581 | # the other team wins 582 | print("Agent %d failed to load" % i, file=sys.stderr) 583 | self.unmute() 584 | self._agentCrash(i, quiet=True) 585 | return 586 | if ("registerInitialState" in dir(agent)): 587 | self.mute(i) 588 | if self.catchExceptions: 589 | try: 590 | timed_func = TimeoutFunction(agent.registerInitialState, int(self.rules.getMaxStartupTime(i))) 591 | try: 592 | start_time = time.time() 593 | timed_func(self.state.deepCopy()) 594 | time_taken = time.time() - start_time 595 | self.totalAgentTimes[i] += time_taken 596 | except TimeoutFunctionException: 597 | print("Agent %d ran out of time on startup!" 
% i, file=sys.stderr) 598 | self.unmute() 599 | self.agentTimeout = True 600 | self._agentCrash(i, quiet=True) 601 | return 602 | except Exception as data: 603 | self._agentCrash(i, quiet=False) 604 | self.unmute() 605 | return 606 | else: 607 | agent.registerInitialState(self.state.deepCopy()) 608 | ## TODO: could this exceed the total time 609 | self.unmute() 610 | 611 | agentIndex = self.startingIndex 612 | numAgents = len( self.agents ) 613 | 614 | while not self.gameOver: 615 | # Fetch the next agent 616 | agent = self.agents[agentIndex] 617 | move_time = 0 618 | skip_action = False 619 | # Generate an observation of the state 620 | if 'observationFunction' in dir( agent ): 621 | self.mute(agentIndex) 622 | if self.catchExceptions: 623 | try: 624 | timed_func = TimeoutFunction(agent.observationFunction, int(self.rules.getMoveTimeout(agentIndex))) 625 | try: 626 | start_time = time.time() 627 | observation = timed_func(self.state.deepCopy()) 628 | except TimeoutFunctionException: 629 | skip_action = True 630 | move_time += time.time() - start_time 631 | self.unmute() 632 | except Exception as data: 633 | self._agentCrash(agentIndex, quiet=False) 634 | self.unmute() 635 | return 636 | else: 637 | observation = agent.observationFunction(self.state.deepCopy()) 638 | self.unmute() 639 | else: 640 | observation = self.state.deepCopy() 641 | 642 | # Solicit an action 643 | action = None 644 | self.mute(agentIndex) 645 | if self.catchExceptions: 646 | try: 647 | timed_func = TimeoutFunction(agent.getAction, int(self.rules.getMoveTimeout(agentIndex)) - int(move_time)) 648 | try: 649 | start_time = time.time() 650 | if skip_action: 651 | raise TimeoutFunctionException() 652 | action = timed_func( observation ) 653 | except TimeoutFunctionException: 654 | print("Agent %d timed out on a single move!" % agentIndex, file=sys.stderr) 655 | self.agentTimeout = True 656 | self._agentCrash(agentIndex, quiet=True) 657 | self.unmute() 658 | return 659 | 660 | move_time += time.time() - start_time 661 | 662 | if move_time > self.rules.getMoveWarningTime(agentIndex): 663 | self.totalAgentTimeWarnings[agentIndex] += 1 664 | print("Agent %d took too long to make a move! This is warning %d" % (agentIndex, self.totalAgentTimeWarnings[agentIndex]), file=sys.stderr) 665 | if self.totalAgentTimeWarnings[agentIndex] > self.rules.getMaxTimeWarnings(agentIndex): 666 | print("Agent %d exceeded the maximum number of warnings: %d" % (agentIndex, self.totalAgentTimeWarnings[agentIndex]), file=sys.stderr) 667 | self.agentTimeout = True 668 | self._agentCrash(agentIndex, quiet=True) 669 | self.unmute() 670 | return 671 | 672 | self.totalAgentTimes[agentIndex] += move_time 673 | #print("Agent: %d, time: %f, total: %f" % (agentIndex, move_time, self.totalAgentTimes[agentIndex])) 674 | if self.totalAgentTimes[agentIndex] > self.rules.getMaxTotalTime(agentIndex): 675 | print("Agent %d ran out of time! 
(time: %1.2f)" % (agentIndex, self.totalAgentTimes[agentIndex]), file=sys.stderr) 676 | self.agentTimeout = True 677 | self._agentCrash(agentIndex, quiet=True) 678 | self.unmute() 679 | return 680 | self.unmute() 681 | except Exception as data: 682 | self._agentCrash(agentIndex) 683 | self.unmute() 684 | return 685 | else: 686 | action = agent.getAction(observation) 687 | # print(f'Agent {agentIndex} ; observation: {observation} ; action: {action}') 688 | self.unmute() 689 | 690 | # Execute the action 691 | self.moveHistory.append( (agentIndex, action) ) 692 | if self.catchExceptions: 693 | try: 694 | self.state = self.state.generateSuccessor( agentIndex, action ) 695 | except Exception as data: 696 | self.mute(agentIndex) 697 | self._agentCrash(agentIndex) 698 | self.unmute() 699 | return 700 | else: 701 | self.state = self.state.generateSuccessor( agentIndex, action ) 702 | 703 | # Change the display 704 | self.display.update( self.state.data ) 705 | ###idx = agentIndex - agentIndex % 2 + 1 706 | ###self.display.update( self.state.makeObservation(idx).data ) 707 | 708 | # Allow for game specific conditions (winning, losing, etc.) 709 | self.rules.process(self.state, self) 710 | # Track progress 711 | if agentIndex == numAgents + 1: self.numMoves += 1 712 | # Next agent 713 | agentIndex = ( agentIndex + 1 ) % numAgents 714 | 715 | if _BOINC_ENABLED: 716 | boinc.set_fraction_done(self.getProgress()) 717 | 718 | # inform a learning agent of the game result 719 | for agentIndex, agent in enumerate(self.agents): 720 | if "final" in dir( agent ) : 721 | try: 722 | self.mute(agentIndex) 723 | agent.final( self.state ) 724 | self.unmute() 725 | except Exception as data: 726 | if not self.catchExceptions: raise data 727 | self._agentCrash(agentIndex) 728 | self.unmute() 729 | return 730 | self.display.finish() 731 | -------------------------------------------------------------------------------- /7_RL_Pacman/ghostAgents.py: -------------------------------------------------------------------------------- 1 | # ghostAgents.py 2 | # -------------- 3 | # Licensing Information: You are free to use or extend these projects for 4 | # educational purposes provided that (1) you do not distribute or publish 5 | # solutions, (2) you retain this notice, and (3) you provide clear 6 | # attribution to UC Berkeley, including a link to http://ai.berkeley.edu. 7 | # 8 | # Attribution Information: The Pacman AI projects were developed at UC Berkeley. 9 | # The core projects and autograders were primarily created by John DeNero 10 | # (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 11 | # Student side autograding was added by Brad Miller, Nick Hay, and 12 | # Pieter Abbeel (pabbeel@cs.berkeley.edu). 13 | 14 | 15 | from game import Agent 16 | from game import Actions 17 | from game import Directions 18 | import random 19 | from util import manhattanDistance 20 | import util 21 | 22 | class GhostAgent( Agent ): 23 | def __init__( self, index ): 24 | self.index = index 25 | 26 | def getAction( self, state ): 27 | dist = self.getDistribution(state) 28 | if len(dist) == 0: 29 | return Directions.STOP 30 | else: 31 | return util.chooseFromDistribution( dist ) 32 | 33 | def getDistribution(self, state): 34 | "Returns a Counter encoding a distribution over actions from the provided state." 35 | util.raiseNotDefined() 36 | 37 | class RandomGhost( GhostAgent ): 38 | "A ghost that chooses a legal action uniformly at random." 
39 | def getDistribution( self, state ): 40 | dist = util.Counter() 41 | for a in state.getLegalActions( self.index ): dist[a] = 1.0 42 | dist.normalize() 43 | return dist 44 | 45 | class DirectionalGhost( GhostAgent ): 46 | "A ghost that prefers to rush Pacman, or flee when scared." 47 | def __init__( self, index, prob_attack=0.8, prob_scaredFlee=0.8 ): 48 | self.index = index 49 | self.prob_attack = prob_attack 50 | self.prob_scaredFlee = prob_scaredFlee 51 | 52 | def getDistribution( self, state ): 53 | # Read variables from state 54 | ghostState = state.getGhostState( self.index ) 55 | legalActions = state.getLegalActions( self.index ) 56 | pos = state.getGhostPosition( self.index ) 57 | isScared = ghostState.scaredTimer > 0 58 | 59 | speed = 1 60 | if isScared: speed = 0.5 61 | 62 | actionVectors = [Actions.directionToVector( a, speed ) for a in legalActions] 63 | newPositions = [( pos[0]+a[0], pos[1]+a[1] ) for a in actionVectors] 64 | pacmanPosition = state.getPacmanPosition() 65 | 66 | # Select best actions given the state 67 | distancesToPacman = [manhattanDistance( pos, pacmanPosition ) for pos in newPositions] 68 | if isScared: 69 | bestScore = max( distancesToPacman ) 70 | bestProb = self.prob_scaredFlee 71 | else: 72 | bestScore = min( distancesToPacman ) 73 | bestProb = self.prob_attack 74 | bestActions = [action for action, distance in zip( legalActions, distancesToPacman ) if distance == bestScore] 75 | 76 | # Construct distribution 77 | dist = util.Counter() 78 | for a in bestActions: dist[a] = bestProb / len(bestActions) 79 | for a in legalActions: dist[a] += ( 1-bestProb ) / len(legalActions) 80 | dist.normalize() 81 | return dist 82 | -------------------------------------------------------------------------------- /7_RL_Pacman/graphicsUtils.py: -------------------------------------------------------------------------------- 1 | # graphicsUtils.py 2 | # ---------------- 3 | # Licensing Information: You are free to use or extend these projects for 4 | # educational purposes provided that (1) you do not distribute or publish 5 | # solutions, (2) you retain this notice, and (3) you provide clear 6 | # attribution to UC Berkeley, including a link to http://ai.berkeley.edu. 7 | # 8 | # Attribution Information: The Pacman AI projects were developed at UC Berkeley. 9 | # The core projects and autograders were primarily created by John DeNero 10 | # (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 11 | # Student side autograding was added by Brad Miller, Nick Hay, and 12 | # Pieter Abbeel (pabbeel@cs.berkeley.edu). 
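# graphicsUtils: low-level Tkinter helpers (window setup, polygon/circle/text drawing, keyboard and mouse handling)
# that the graphical display and the keyboard agents build on.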
13 | 14 | 15 | import sys 16 | import math 17 | import random 18 | import string 19 | import time 20 | import types 21 | import tkinter 22 | import os.path 23 | 24 | _Windows = sys.platform == 'win32' # True if on Win95/98/NT 25 | 26 | _root_window = None # The root window for graphics output 27 | _canvas = None # The canvas which holds graphics 28 | _canvas_xs = None # Size of canvas object 29 | _canvas_ys = None 30 | _canvas_x = None # Current position on canvas 31 | _canvas_y = None 32 | _canvas_col = None # Current colour (set to black below) 33 | _canvas_tsize = 12 34 | _canvas_tserifs = 0 35 | 36 | def formatColor(r, g, b): 37 | return '#%02x%02x%02x' % (int(r * 255), int(g * 255), int(b * 255)) 38 | 39 | def colorToVector(color): 40 | return list(map(lambda x: int(x, 16) / 256.0, [color[1:3], color[3:5], color[5:7]])) 41 | 42 | if _Windows: 43 | _canvas_tfonts = ['times new roman', 'lucida console'] 44 | else: 45 | _canvas_tfonts = ['times', 'lucidasans-24'] 46 | pass # XXX need defaults here 47 | 48 | def sleep(secs): 49 | global _root_window 50 | if _root_window == None: 51 | time.sleep(secs) 52 | else: 53 | _root_window.update_idletasks() 54 | _root_window.after(int(1000 * secs), _root_window.quit) 55 | _root_window.mainloop() 56 | 57 | def begin_graphics(width=640, height=480, color=formatColor(0, 0, 0), title=None): 58 | 59 | global _root_window, _canvas, _canvas_x, _canvas_y, _canvas_xs, _canvas_ys, _bg_color 60 | 61 | # Check for duplicate call 62 | if _root_window is not None: 63 | # Lose the window. 64 | _root_window.destroy() 65 | 66 | # Save the canvas size parameters 67 | _canvas_xs, _canvas_ys = width - 1, height - 1 68 | _canvas_x, _canvas_y = 0, _canvas_ys 69 | _bg_color = color 70 | 71 | # Create the root window 72 | _root_window = tkinter.Tk() 73 | _root_window.protocol('WM_DELETE_WINDOW', _destroy_window) 74 | _root_window.title(title or 'Graphics Window') 75 | _root_window.resizable(0, 0) 76 | 77 | # Create the canvas object 78 | try: 79 | _canvas = tkinter.Canvas(_root_window, width=width, height=height) 80 | _canvas.pack() 81 | draw_background() 82 | _canvas.update() 83 | except: 84 | _root_window = None 85 | raise 86 | 87 | # Bind to key-down and key-up events 88 | _root_window.bind( "", _keypress ) 89 | _root_window.bind( "", _keyrelease ) 90 | _root_window.bind( "", _clear_keys ) 91 | _root_window.bind( "", _clear_keys ) 92 | _root_window.bind( "", _leftclick ) 93 | _root_window.bind( "", _rightclick ) 94 | _root_window.bind( "", _rightclick ) 95 | _root_window.bind( "", _ctrl_leftclick) 96 | _clear_keys() 97 | 98 | _leftclick_loc = None 99 | _rightclick_loc = None 100 | _ctrl_leftclick_loc = None 101 | 102 | def _leftclick(event): 103 | global _leftclick_loc 104 | _leftclick_loc = (event.x, event.y) 105 | 106 | def _rightclick(event): 107 | global _rightclick_loc 108 | _rightclick_loc = (event.x, event.y) 109 | 110 | def _ctrl_leftclick(event): 111 | global _ctrl_leftclick_loc 112 | _ctrl_leftclick_loc = (event.x, event.y) 113 | 114 | def wait_for_click(): 115 | while True: 116 | global _leftclick_loc 117 | global _rightclick_loc 118 | global _ctrl_leftclick_loc 119 | if _leftclick_loc != None: 120 | val = _leftclick_loc 121 | _leftclick_loc = None 122 | return val, 'left' 123 | if _rightclick_loc != None: 124 | val = _rightclick_loc 125 | _rightclick_loc = None 126 | return val, 'right' 127 | if _ctrl_leftclick_loc != None: 128 | val = _ctrl_leftclick_loc 129 | _ctrl_leftclick_loc = None 130 | return val, 'ctrl_left' 131 | sleep(0.05) 132 | 133 | def 
draw_background(): 134 | corners = [(0,0), (0, _canvas_ys), (_canvas_xs, _canvas_ys), (_canvas_xs, 0)] 135 | polygon(corners, _bg_color, fillColor=_bg_color, filled=True, smoothed=False) 136 | 137 | def _destroy_window(event=None): 138 | sys.exit(0) 139 | # global _root_window 140 | # _root_window.destroy() 141 | # _root_window = None 142 | #print("DESTROY") 143 | 144 | def end_graphics(): 145 | global _root_window, _canvas, _mouse_enabled 146 | try: 147 | try: 148 | sleep(1) 149 | if _root_window != None: 150 | _root_window.destroy() 151 | except SystemExit as e: 152 | print('Ending graphics raised an exception:', e) 153 | finally: 154 | _root_window = None 155 | _canvas = None 156 | _mouse_enabled = 0 157 | _clear_keys() 158 | 159 | def clear_screen(background=None): 160 | global _canvas_x, _canvas_y 161 | _canvas.delete('all') 162 | draw_background() 163 | _canvas_x, _canvas_y = 0, _canvas_ys 164 | 165 | def polygon(coords, outlineColor, fillColor=None, filled=1, smoothed=1, behind=0, width=1): 166 | c = [] 167 | for coord in coords: 168 | c.append(coord[0]) 169 | c.append(coord[1]) 170 | if fillColor == None: fillColor = outlineColor 171 | if filled == 0: fillColor = "" 172 | poly = _canvas.create_polygon(c, outline=outlineColor, fill=fillColor, smooth=smoothed, width=width) 173 | if behind > 0: 174 | _canvas.tag_lower(poly, behind) # Higher should be more visible 175 | return poly 176 | 177 | def square(pos, r, color, filled=1, behind=0): 178 | x, y = pos 179 | coords = [(x - r, y - r), (x + r, y - r), (x + r, y + r), (x - r, y + r)] 180 | return polygon(coords, color, color, filled, 0, behind=behind) 181 | 182 | def circle(pos, r, outlineColor, fillColor=None, endpoints=None, style='pieslice', width=2): 183 | x, y = pos 184 | x0, x1 = x - r - 1, x + r 185 | y0, y1 = y - r - 1, y + r 186 | if endpoints == None: 187 | e = [0, 359] 188 | else: 189 | e = list(endpoints) 190 | while e[0] > e[1]: e[1] = e[1] + 360 191 | 192 | return _canvas.create_arc(x0, y0, x1, y1, outline=outlineColor, fill=fillColor or outlineColor, 193 | extent=e[1] - e[0], start=e[0], style=style, width=width) 194 | 195 | def image(pos, file="../../blueghost.gif"): 196 | x, y = pos 197 | # img = PhotoImage(file=file) 198 | return _canvas.create_image(x, y, image = tkinter.PhotoImage(file=file), anchor = tkinter.NW) 199 | 200 | 201 | def refresh(): 202 | _canvas.update_idletasks() 203 | 204 | def moveCircle(id, pos, r, endpoints=None): 205 | global _canvas_x, _canvas_y 206 | 207 | x, y = pos 208 | # x0, x1 = x - r, x + r + 1 209 | # y0, y1 = y - r, y + r + 1 210 | x0, x1 = x - r - 1, x + r 211 | y0, y1 = y - r - 1, y + r 212 | if endpoints == None: 213 | e = [0, 359] 214 | else: 215 | e = list(endpoints) 216 | while e[0] > e[1]: e[1] = e[1] + 360 217 | 218 | if os.path.isfile('flag'): 219 | edit(id, ('extent', e[1] - e[0])) 220 | else: 221 | edit(id, ('start', e[0]), ('extent', e[1] - e[0])) 222 | move_to(id, x0, y0) 223 | 224 | def edit(id, *args): 225 | _canvas.itemconfigure(id, **dict(args)) 226 | 227 | def text(pos, color, contents, font='Helvetica', size=12, style='normal', anchor="nw"): 228 | global _canvas_x, _canvas_y 229 | x, y = pos 230 | font = (font, str(size), style) 231 | return _canvas.create_text(x, y, fill=color, text=contents, font=font, anchor=anchor) 232 | 233 | def changeText(id, newText, font=None, size=12, style='normal'): 234 | _canvas.itemconfigure(id, text=newText) 235 | if font != None: 236 | _canvas.itemconfigure(id, font=(font, '-%d' % size, style)) 237 | 238 | def changeColor(id, 
newColor): 239 | _canvas.itemconfigure(id, fill=newColor) 240 | 241 | def line(here, there, color=formatColor(0, 0, 0), width=2): 242 | x0, y0 = here[0], here[1] 243 | x1, y1 = there[0], there[1] 244 | return _canvas.create_line(x0, y0, x1, y1, fill=color, width=width) 245 | 246 | ############################################################################## 247 | ### Keypress handling ######################################################## 248 | ############################################################################## 249 | 250 | # We bind to key-down and key-up events. 251 | 252 | _keysdown = {} 253 | _keyswaiting = {} 254 | # This holds an unprocessed key release. We delay key releases by up to 255 | # one call to keys_pressed() to get round a problem with auto repeat. 256 | _got_release = None 257 | 258 | def _keypress(event): 259 | global _got_release 260 | #remap_arrows(event) 261 | _keysdown[event.keysym] = 1 262 | _keyswaiting[event.keysym] = 1 263 | # print(event.char, event.keycode) 264 | _got_release = None 265 | 266 | def _keyrelease(event): 267 | global _got_release 268 | #remap_arrows(event) 269 | try: 270 | del _keysdown[event.keysym] 271 | except: 272 | pass 273 | _got_release = 1 274 | 275 | def remap_arrows(event): 276 | # TURN ARROW PRESSES INTO LETTERS (SHOULD BE IN KEYBOARD AGENT) 277 | if event.char in ['a', 's', 'd', 'w']: 278 | return 279 | if event.keycode in [37, 101]: # LEFT ARROW (win / x) 280 | event.char = 'a' 281 | if event.keycode in [38, 99]: # UP ARROW 282 | event.char = 'w' 283 | if event.keycode in [39, 102]: # RIGHT ARROW 284 | event.char = 'd' 285 | if event.keycode in [40, 104]: # DOWN ARROW 286 | event.char = 's' 287 | 288 | def _clear_keys(event=None): 289 | global _keysdown, _got_release, _keyswaiting 290 | _keysdown = {} 291 | _keyswaiting = {} 292 | _got_release = None 293 | 294 | def keys_pressed(d_o_e=lambda arg: _root_window.dooneevent(arg), 295 | d_w=tkinter._tkinter.DONT_WAIT): 296 | d_o_e(d_w) 297 | if _got_release: 298 | d_o_e(d_w) 299 | return _keysdown.keys() 300 | 301 | def keys_waiting(): 302 | global _keyswaiting 303 | keys = _keyswaiting.keys() 304 | _keyswaiting = {} 305 | return keys 306 | 307 | # Block for a list of keys... 
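# wait_for_keys (below) polls keys_pressed() every 50 ms and returns as soon as at least one key is down.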
308 | 309 | def wait_for_keys(): 310 | keys = [] 311 | while keys == []: 312 | keys = keys_pressed() 313 | sleep(0.05) 314 | return keys 315 | 316 | def remove_from_screen(x, 317 | d_o_e=lambda arg: _root_window.dooneevent(arg), 318 | d_w=tkinter._tkinter.DONT_WAIT): 319 | _canvas.delete(x) 320 | d_o_e(d_w) 321 | 322 | def _adjust_coords(coord_list, x, y): 323 | for i in range(0, len(coord_list), 2): 324 | coord_list[i] = coord_list[i] + x 325 | coord_list[i + 1] = coord_list[i + 1] + y 326 | return coord_list 327 | 328 | def move_to(object, x, y=None, 329 | d_o_e=lambda arg: _root_window.dooneevent(arg), 330 | d_w=tkinter._tkinter.DONT_WAIT): 331 | if y is None: 332 | try: x, y = x 333 | except: raise 'incomprehensible coordinates' 334 | 335 | horiz = True 336 | newCoords = [] 337 | current_x, current_y = _canvas.coords(object)[0:2] # first point 338 | for coord in _canvas.coords(object): 339 | if horiz: 340 | inc = x - current_x 341 | else: 342 | inc = y - current_y 343 | horiz = not horiz 344 | 345 | newCoords.append(coord + inc) 346 | 347 | _canvas.coords(object, *newCoords) 348 | d_o_e(d_w) 349 | 350 | def move_by(object, x, y=None, 351 | d_o_e=lambda arg: _root_window.dooneevent(arg), 352 | d_w=tkinter._tkinter.DONT_WAIT, lift=False): 353 | if y is None: 354 | try: x, y = x 355 | except: raise Exception('incomprehensible coordinates') 356 | 357 | horiz = True 358 | newCoords = [] 359 | for coord in _canvas.coords(object): 360 | if horiz: 361 | inc = x 362 | else: 363 | inc = y 364 | horiz = not horiz 365 | 366 | newCoords.append(coord + inc) 367 | 368 | _canvas.coords(object, *newCoords) 369 | d_o_e(d_w) 370 | if lift: 371 | _canvas.tag_raise(object) 372 | 373 | def writePostscript(filename): 374 | "Writes the current canvas to a postscript file." 375 | psfile = open(filename, 'w') 376 | psfile.write(_canvas.postscript(pageanchor='sw', 377 | y='0.c', 378 | x='0.c')) 379 | psfile.close() 380 | 381 | ghost_shape = [ 382 | (0, - 0.5), 383 | (0.25, - 0.75), 384 | (0.5, - 0.5), 385 | (0.75, - 0.75), 386 | (0.75, 0.5), 387 | (0.5, 0.75), 388 | (- 0.5, 0.75), 389 | (- 0.75, 0.5), 390 | (- 0.75, - 0.75), 391 | (- 0.5, - 0.5), 392 | (- 0.25, - 0.75) 393 | ] 394 | 395 | if __name__ == '__main__': 396 | begin_graphics() 397 | clear_screen() 398 | ghost_shape = [(x * 10 + 20, y * 10 + 20) for x, y in ghost_shape] 399 | g = polygon(ghost_shape, formatColor(1, 1, 1)) 400 | move_to(g, (50, 50)) 401 | circle((150, 150), 20, formatColor(0.7, 0.3, 0.0), endpoints=[15, - 15]) 402 | sleep(2) 403 | -------------------------------------------------------------------------------- /7_RL_Pacman/keyboardAgents.py: -------------------------------------------------------------------------------- 1 | # keyboardAgents.py 2 | # ----------------- 3 | # Licensing Information: You are free to use or extend these projects for 4 | # educational purposes provided that (1) you do not distribute or publish 5 | # solutions, (2) you retain this notice, and (3) you provide clear 6 | # attribution to UC Berkeley, including a link to http://ai.berkeley.edu. 7 | # 8 | # Attribution Information: The Pacman AI projects were developed at UC Berkeley. 9 | # The core projects and autograders were primarily created by John DeNero 10 | # (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 11 | # Student side autograding was added by Brad Miller, Nick Hay, and 12 | # Pieter Abbeel (pabbeel@cs.berkeley.edu). 
13 | 14 | 15 | from game import Agent 16 | from game import Directions 17 | import random 18 | 19 | class KeyboardAgent(Agent): 20 | """ 21 | An agent controlled by the keyboard. 22 | """ 23 | # NOTE: Arrow keys also work. 24 | WEST_KEY = 'a' 25 | EAST_KEY = 'd' 26 | NORTH_KEY = 'w' 27 | SOUTH_KEY = 's' 28 | STOP_KEY = 'q' 29 | 30 | def __init__( self, index = 0 ): 31 | 32 | self.lastMove = Directions.STOP 33 | self.index = index 34 | self.keys = [] 35 | 36 | def getAction( self, state): 37 | # print(state) 38 | from graphicsUtils import keys_waiting 39 | from graphicsUtils import keys_pressed 40 | keys = list(keys_waiting()) + list(keys_pressed()) 41 | if keys != []: 42 | self.keys = keys 43 | 44 | legal = state.getLegalActions(self.index) 45 | move = self.getMove(legal) 46 | 47 | if move == Directions.STOP: 48 | # Try to move in the same direction as before 49 | if self.lastMove in legal: 50 | move = self.lastMove 51 | 52 | if (self.STOP_KEY in self.keys) and Directions.STOP in legal: move = Directions.STOP 53 | 54 | if move not in legal: 55 | move = random.choice(legal) 56 | 57 | self.lastMove = move 58 | return move 59 | 60 | def getMove(self, legal): 61 | move = Directions.STOP 62 | if (self.WEST_KEY in self.keys or 'Left' in self.keys) and Directions.WEST in legal: move = Directions.WEST 63 | if (self.EAST_KEY in self.keys or 'Right' in self.keys) and Directions.EAST in legal: move = Directions.EAST 64 | if (self.NORTH_KEY in self.keys or 'Up' in self.keys) and Directions.NORTH in legal: move = Directions.NORTH 65 | if (self.SOUTH_KEY in self.keys or 'Down' in self.keys) and Directions.SOUTH in legal: move = Directions.SOUTH 66 | return move 67 | 68 | class KeyboardAgent2(KeyboardAgent): 69 | """ 70 | A second agent controlled by the keyboard. 71 | """ 72 | # NOTE: Arrow keys also work. 73 | WEST_KEY = 'j' 74 | EAST_KEY = "l" 75 | NORTH_KEY = 'i' 76 | SOUTH_KEY = 'k' 77 | STOP_KEY = 'u' 78 | 79 | def getMove(self, legal): 80 | move = Directions.STOP 81 | if (self.WEST_KEY in self.keys) and Directions.WEST in legal: move = Directions.WEST 82 | if (self.EAST_KEY in self.keys) and Directions.EAST in legal: move = Directions.EAST 83 | if (self.NORTH_KEY in self.keys) and Directions.NORTH in legal: move = Directions.NORTH 84 | if (self.SOUTH_KEY in self.keys) and Directions.SOUTH in legal: move = Directions.SOUTH 85 | return move 86 | -------------------------------------------------------------------------------- /7_RL_Pacman/layout.py: -------------------------------------------------------------------------------- 1 | # layout.py 2 | # --------- 3 | # Licensing Information: You are free to use or extend these projects for 4 | # educational purposes provided that (1) you do not distribute or publish 5 | # solutions, (2) you retain this notice, and (3) you provide clear 6 | # attribution to UC Berkeley, including a link to http://ai.berkeley.edu. 7 | # 8 | # Attribution Information: The Pacman AI projects were developed at UC Berkeley. 9 | # The core projects and autograders were primarily created by John DeNero 10 | # (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 11 | # Student side autograding was added by Brad Miller, Nick Hay, and 12 | # Pieter Abbeel (pabbeel@cs.berkeley.edu). 13 | 14 | 15 | from util import manhattanDistance 16 | from game import Grid 17 | import os 18 | import random 19 | from functools import reduce 20 | 21 | VISIBILITY_MATRIX_CACHE = {} 22 | 23 | class Layout: 24 | """ 25 | A Layout manages the static information about the game board. 
26 | """ 27 | 28 | def __init__(self, layoutText): 29 | self.width = len(layoutText[0]) 30 | self.height= len(layoutText) 31 | self.walls = Grid(self.width, self.height, False) 32 | self.food = Grid(self.width, self.height, False) 33 | self.capsules = [] 34 | self.agentPositions = [] 35 | self.numGhosts = 0 36 | self.processLayoutText(layoutText) 37 | self.layoutText = layoutText 38 | self.totalFood = len(self.food.asList()) 39 | # self.initializeVisibilityMatrix() 40 | 41 | def getNumGhosts(self): 42 | return self.numGhosts 43 | 44 | def initializeVisibilityMatrix(self): 45 | global VISIBILITY_MATRIX_CACHE 46 | if reduce(str.__add__, self.layoutText) not in VISIBILITY_MATRIX_CACHE: 47 | from game import Directions 48 | vecs = [(-0.5,0), (0.5,0),(0,-0.5),(0,0.5)] 49 | dirs = [Directions.NORTH, Directions.SOUTH, Directions.WEST, Directions.EAST] 50 | vis = Grid(self.width, self.height, {Directions.NORTH:set(), Directions.SOUTH:set(), Directions.EAST:set(), Directions.WEST:set(), Directions.STOP:set()}) 51 | for x in range(self.width): 52 | for y in range(self.height): 53 | if self.walls[x][y] == False: 54 | for vec, direction in zip(vecs, dirs): 55 | dx, dy = vec 56 | nextx, nexty = x + dx, y + dy 57 | while (nextx + nexty) != int(nextx) + int(nexty) or not self.walls[int(nextx)][int(nexty)] : 58 | vis[x][y][direction].add((nextx, nexty)) 59 | nextx, nexty = x + dx, y + dy 60 | self.visibility = vis 61 | VISIBILITY_MATRIX_CACHE[reduce(str.__add__, self.layoutText)] = vis 62 | else: 63 | self.visibility = VISIBILITY_MATRIX_CACHE[reduce(str.__add__, self.layoutText)] 64 | 65 | def isWall(self, pos): 66 | x, col = pos 67 | return self.walls[x][col] 68 | 69 | def getRandomLegalPosition(self): 70 | x = random.choice(range(self.width)) 71 | y = random.choice(range(self.height)) 72 | while self.isWall( (x, y) ): 73 | x = random.choice(range(self.width)) 74 | y = random.choice(range(self.height)) 75 | return (x,y) 76 | 77 | def getRandomCorner(self): 78 | poses = [(1,1), (1, self.height - 2), (self.width - 2, 1), (self.width - 2, self.height - 2)] 79 | return random.choice(poses) 80 | 81 | def getFurthestCorner(self, pacPos): 82 | poses = [(1,1), (1, self.height - 2), (self.width - 2, 1), (self.width - 2, self.height - 2)] 83 | dist, pos = max([(manhattanDistance(p, pacPos), p) for p in poses]) 84 | return pos 85 | 86 | def isVisibleFrom(self, ghostPos, pacPos, pacDirection): 87 | row, col = [int(x) for x in pacPos] 88 | return ghostPos in self.visibility[row][col][pacDirection] 89 | 90 | def __str__(self): 91 | return "\n".join(self.layoutText) 92 | 93 | def deepCopy(self): 94 | return Layout(self.layoutText[:]) 95 | 96 | def processLayoutText(self, layoutText): 97 | """ 98 | Coordinates are flipped from the input format to the (x,y) convention here 99 | 100 | The shape of the maze. Each character 101 | represents a different type of object. 102 | % - Wall 103 | . - Food 104 | o - Capsule 105 | G - Ghost 106 | P - Pacman 107 | Other characters are ignored. 
108 | """ 109 | maxY = self.height - 1 110 | for y in range(self.height): 111 | for x in range(self.width): 112 | layoutChar = layoutText[maxY - y][x] 113 | self.processLayoutChar(x, y, layoutChar) 114 | self.agentPositions.sort() 115 | self.agentPositions = [ ( i == 0, pos) for i, pos in self.agentPositions] 116 | 117 | def processLayoutChar(self, x, y, layoutChar): 118 | if layoutChar == '%': 119 | self.walls[x][y] = True 120 | elif layoutChar == '.': 121 | self.food[x][y] = True 122 | elif layoutChar == 'o': 123 | self.capsules.append((x, y)) 124 | elif layoutChar == 'P': 125 | self.agentPositions.append( (0, (x, y) ) ) 126 | elif layoutChar in ['G']: 127 | self.agentPositions.append( (1, (x, y) ) ) 128 | self.numGhosts += 1 129 | elif layoutChar in ['1', '2', '3', '4']: 130 | self.agentPositions.append( (int(layoutChar), (x,y))) 131 | self.numGhosts += 1 132 | def getLayout(name, back = 2): 133 | if name.endswith('.lay'): 134 | layout = tryToLoad('layouts/' + name) 135 | if layout == None: layout = tryToLoad(name) 136 | else: 137 | layout = tryToLoad('layouts/' + name + '.lay') 138 | if layout == None: layout = tryToLoad(name + '.lay') 139 | if layout == None and back >= 0: 140 | curdir = os.path.abspath('.') 141 | os.chdir('..') 142 | layout = getLayout(name, back -1) 143 | os.chdir(curdir) 144 | return layout 145 | 146 | def tryToLoad(fullname): 147 | if(not os.path.exists(fullname)): return None 148 | f = open(fullname) 149 | try: return Layout([line.strip() for line in f]) 150 | finally: f.close() 151 | -------------------------------------------------------------------------------- /7_RL_Pacman/layouts/mediumClassic.lay: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%% 2 | %o...%........%....% 3 | %.%%.%.%%%%%%.%.%%.% 4 | %.%..............%.% 5 | %.%.%%.%% %%.%%.%.% 6 | %......%G G%......% 7 | %.%.%%.%%%%%%.%%.%.% 8 | %.%..............%.% 9 | %.%%.%.%%%%%%.%.%%.% 10 | %....%...P....%...o% 11 | %%%%%%%%%%%%%%%%%%%% 12 | -------------------------------------------------------------------------------- /7_RL_Pacman/pacman.py: -------------------------------------------------------------------------------- 1 | # pacman.py 2 | # --------- 3 | # Licensing Information: You are free to use or extend these projects for 4 | # educational purposes provided that (1) you do not distribute or publish 5 | # solutions, (2) you retain this notice, and (3) you provide clear 6 | # attribution to UC Berkeley, including a link to http://ai.berkeley.edu. 7 | # 8 | # Attribution Information: The Pacman AI projects were developed at UC Berkeley. 9 | # The core projects and autograders were primarily created by John DeNero 10 | # (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 11 | # Student side autograding was added by Brad Miller, Nick Hay, and 12 | # Pieter Abbeel (pabbeel@cs.berkeley.edu). 13 | 14 | 15 | """ 16 | Pacman.py holds the logic for the classic pacman game along with the main 17 | code to run a game. This file is divided into three sections: 18 | 19 | (i) Your interface to the pacman world: 20 | Pacman is a complex environment. You probably don't want to 21 | read through all of the code we wrote to make the game runs 22 | correctly. This section contains the parts of the code 23 | that you will need to understand in order to complete the 24 | project. There is also some code in game.py that you should 25 | understand. 
26 | 27 | (ii) The hidden secrets of pacman: 28 | This section contains all of the logic code that the pacman 29 | environment uses to decide who can move where, who dies when 30 | things collide, etc. You shouldn't need to read this section 31 | of code, but you can if you want. 32 | 33 | (iii) Framework to start a game: 34 | The final section contains the code for reading the command 35 | you use to set up the game, then starting up a new game, along with 36 | linking in all the external parts (agent functions, graphics). 37 | Check this section out to see all the options available to you. 38 | 39 | To play your first game, type 'python pacman.py' from the command line. 40 | The keys are 'a', 's', 'd', and 'w' to move (or arrow keys). Have fun! 41 | """ 42 | from game import GameStateData 43 | from game import Game 44 | from game import Directions 45 | from game import Actions 46 | from util import nearestPoint 47 | from util import manhattanDistance 48 | import util, layout 49 | import sys, types, time, random, os 50 | 51 | ################################################### 52 | # YOUR INTERFACE TO THE PACMAN WORLD: A GameState # 53 | ################################################### 54 | 55 | class GameState: 56 | """ 57 | A GameState specifies the full game state, including the food, capsules, 58 | agent configurations and score changes. 59 | 60 | GameStates are used by the Game object to capture the actual state of the game and 61 | can be used by agents to reason about the game. 62 | 63 | Much of the information in a GameState is stored in a GameStateData object. We 64 | strongly suggest that you access that data via the accessor methods below rather 65 | than referring to the GameStateData object directly. 66 | 67 | Note that in classic Pacman, Pacman is always agent 0. 68 | """ 69 | 70 | #################################################### 71 | # Accessor methods: use these to access state data # 72 | #################################################### 73 | 74 | # static variable keeps track of which states have had getLegalActions called 75 | explored = set() 76 | def getAndResetExplored(): 77 | tmp = GameState.explored.copy() 78 | GameState.explored = set() 79 | return tmp 80 | getAndResetExplored = staticmethod(getAndResetExplored) 81 | 82 | def getLegalActions( self, agentIndex=0 ): 83 | """ 84 | Returns the legal actions for the agent specified. 85 | """ 86 | # GameState.explored.add(self) 87 | if self.isWin() or self.isLose(): return [] 88 | 89 | if agentIndex == 0: # Pacman is moving 90 | return PacmanRules.getLegalActions( self ) 91 | else: 92 | return GhostRules.getLegalActions( self, agentIndex ) 93 | 94 | def generateSuccessor( self, agentIndex, action): 95 | """ 96 | Returns the successor state after the specified agent takes the action. 
97 | """ 98 | # Check that successors exist 99 | if self.isWin() or self.isLose(): raise Exception('Can\'t generate a successor of a terminal state.') 100 | 101 | # Copy current state 102 | state = GameState(self) 103 | 104 | # Let agent's logic deal with its action's effects on the board 105 | if agentIndex == 0: # Pacman is moving 106 | state.data._eaten = [False for i in range(state.getNumAgents())] 107 | PacmanRules.applyAction( state, action ) 108 | else: # A ghost is moving 109 | GhostRules.applyAction( state, action, agentIndex ) 110 | 111 | # Time passes 112 | if agentIndex == 0: 113 | state.data.scoreChange += -TIME_PENALTY # Penalty for waiting around 114 | else: 115 | GhostRules.decrementTimer( state.data.agentStates[agentIndex] ) 116 | 117 | # Resolve multi-agent effects 118 | GhostRules.checkDeath( state, agentIndex ) 119 | 120 | # Book keeping 121 | state.data._agentMoved = agentIndex 122 | state.data.score += state.data.scoreChange 123 | GameState.explored.add(self) 124 | GameState.explored.add(state) 125 | return state 126 | 127 | def getLegalPacmanActions( self ): 128 | return self.getLegalActions( 0 ) 129 | 130 | def generatePacmanSuccessor( self, action ): 131 | """ 132 | Generates the successor state after the specified pacman move 133 | """ 134 | return self.generateSuccessor( 0, action ) 135 | 136 | def getPacmanState( self ): 137 | """ 138 | Returns an AgentState object for pacman (in game.py) 139 | 140 | state.pos gives the current position 141 | state.direction gives the travel vector 142 | """ 143 | return self.data.agentStates[0].copy() 144 | 145 | def getPacmanPosition( self ): 146 | return self.data.agentStates[0].getPosition() 147 | 148 | def getGhostStates( self ): 149 | return self.data.agentStates[1:] 150 | 151 | def getGhostState( self, agentIndex ): 152 | if agentIndex == 0 or agentIndex >= self.getNumAgents(): 153 | raise Exception("Invalid index passed to getGhostState") 154 | return self.data.agentStates[agentIndex] 155 | 156 | def getGhostPosition( self, agentIndex ): 157 | if agentIndex == 0: 158 | raise Exception("Pacman's index passed to getGhostPosition") 159 | return self.data.agentStates[agentIndex].getPosition() 160 | 161 | def getGhostPositions(self): 162 | return [s.getPosition() for s in self.getGhostStates()] 163 | 164 | def getNumAgents( self ): 165 | return len( self.data.agentStates ) 166 | 167 | def getScore( self ): 168 | return float(self.data.score) 169 | 170 | def getCapsules(self): 171 | """ 172 | Returns a list of positions (x,y) of the remaining capsules. 173 | """ 174 | return self.data.capsules 175 | 176 | def getNumFood( self ): 177 | return self.data.food.count() 178 | 179 | def getFood(self): 180 | """ 181 | Returns a Grid of boolean food indicator variables. 182 | 183 | Grids can be accessed via list notation, so to check 184 | if there is food at (x,y), just call 185 | 186 | currentFood = state.getFood() 187 | if currentFood[x][y] == True: ... 188 | """ 189 | return self.data.food 190 | 191 | def getWalls(self): 192 | """ 193 | Returns a Grid of boolean wall indicator variables. 194 | 195 | Grids can be accessed via list notation, so to check 196 | if there is a wall at (x,y), just call 197 | 198 | walls = state.getWalls() 199 | if walls[x][y] == True: ... 
200 | """ 201 | return self.data.layout.walls 202 | 203 | def hasFood(self, x, y): 204 | return self.data.food[x][y] 205 | 206 | def hasWall(self, x, y): 207 | return self.data.layout.walls[x][y] 208 | 209 | def isLose( self ): 210 | return self.data._lose 211 | 212 | def isWin( self ): 213 | return self.data._win 214 | 215 | ############################################# 216 | # Helper methods: # 217 | # You shouldn't need to call these directly # 218 | ############################################# 219 | 220 | def __init__( self, prevState = None ): 221 | """ 222 | Generates a new state by copying information from its predecessor. 223 | """ 224 | if prevState != None: # Initial state 225 | self.data = GameStateData(prevState.data) 226 | else: 227 | self.data = GameStateData() 228 | 229 | def deepCopy( self ): 230 | state = GameState( self ) 231 | state.data = self.data.deepCopy() 232 | return state 233 | 234 | def __eq__( self, other ): 235 | """ 236 | Allows two states to be compared. 237 | """ 238 | return hasattr(other, 'data') and self.data == other.data 239 | 240 | def __hash__( self ): 241 | """ 242 | Allows states to be keys of dictionaries. 243 | """ 244 | return hash( self.data ) 245 | 246 | def __str__( self ): 247 | 248 | return str(self.data) 249 | 250 | def initialize( self, layout, numGhostAgents=1000 ): 251 | """ 252 | Creates an initial game state from a layout array (see layout.py). 253 | """ 254 | self.data.initialize(layout, numGhostAgents) 255 | 256 | ############################################################################ 257 | # THE HIDDEN SECRETS OF PACMAN # 258 | # # 259 | # You shouldn't need to look through the code in this section of the file. # 260 | ############################################################################ 261 | 262 | SCARED_TIME = 40 # Moves ghosts are scared 263 | COLLISION_TOLERANCE = 0.7 # How close ghosts must be to Pacman to kill 264 | TIME_PENALTY = 1 # Number of points lost each round 265 | 266 | class ClassicGameRules: 267 | """ 268 | These game rules manage the control flow of a game, deciding when 269 | and how the game starts and ends. 270 | """ 271 | def __init__(self, timeout=30): 272 | self.timeout = timeout 273 | 274 | def newGame( self, layout, pacmanAgent, ghostAgents, display, quiet = False, catchExceptions=False): 275 | agents = [pacmanAgent] + ghostAgents[:layout.getNumGhosts()] 276 | initState = GameState() 277 | initState.initialize( layout, len(ghostAgents) ) 278 | game = Game(agents, display, self, catchExceptions=catchExceptions) 279 | game.state = initState 280 | self.initialState = initState.deepCopy() 281 | self.quiet = quiet 282 | return game 283 | 284 | def process(self, state, game): 285 | """ 286 | Checks to see whether it is time to end the game. 287 | """ 288 | if state.isWin(): self.win(state, game) 289 | if state.isLose(): self.lose(state, game) 290 | 291 | def win( self, state, game ): 292 | if not self.quiet: print("Pacman emerges victorious! Score: %d" % state.data.score) 293 | game.gameOver = True 294 | 295 | def lose( self, state, game ): 296 | if not self.quiet: print("Pacman died! 
Score: %d" % state.data.score) 297 | game.gameOver = True 298 | 299 | def getProgress(self, game): 300 | return float(game.state.getNumFood()) / self.initialState.getNumFood() 301 | 302 | def agentCrash(self, game, agentIndex): 303 | if agentIndex == 0: 304 | print("Pacman crashed") 305 | else: 306 | print("A ghost crashed") 307 | 308 | def getMaxTotalTime(self, agentIndex): 309 | return self.timeout 310 | 311 | def getMaxStartupTime(self, agentIndex): 312 | return self.timeout 313 | 314 | def getMoveWarningTime(self, agentIndex): 315 | return self.timeout 316 | 317 | def getMoveTimeout(self, agentIndex): 318 | return self.timeout 319 | 320 | def getMaxTimeWarnings(self, agentIndex): 321 | return 0 322 | 323 | class PacmanRules: 324 | """ 325 | These functions govern how pacman interacts with his environment under 326 | the classic game rules. 327 | """ 328 | PACMAN_SPEED=1 329 | 330 | def getLegalActions( state ): 331 | """ 332 | Returns a list of possible actions. 333 | """ 334 | return Actions.getPossibleActions( state.getPacmanState().configuration, state.data.layout.walls ) 335 | getLegalActions = staticmethod( getLegalActions ) 336 | 337 | def applyAction( state, action ): 338 | """ 339 | Edits the state to reflect the results of the action. 340 | """ 341 | legal = PacmanRules.getLegalActions( state ) 342 | if action not in legal: 343 | raise Exception("Illegal action " + str(action)) 344 | 345 | pacmanState = state.data.agentStates[0] 346 | 347 | # Update Configuration 348 | vector = Actions.directionToVector( action, PacmanRules.PACMAN_SPEED ) 349 | pacmanState.configuration = pacmanState.configuration.generateSuccessor( vector ) 350 | 351 | # Eat 352 | next = pacmanState.configuration.getPosition() 353 | nearest = nearestPoint( next ) 354 | if manhattanDistance( nearest, next ) <= 0.5 : 355 | # Remove food 356 | PacmanRules.consume( nearest, state ) 357 | applyAction = staticmethod( applyAction ) 358 | 359 | def consume( position, state ): 360 | x,y = position 361 | # Eat food 362 | if state.data.food[x][y]: 363 | state.data.scoreChange += 10 364 | state.data.food = state.data.food.copy() 365 | state.data.food[x][y] = False 366 | state.data._foodEaten = position 367 | # TODO: cache numFood? 368 | numFood = state.getNumFood() 369 | if numFood == 0 and not state.data._lose: 370 | state.data.scoreChange += 500 371 | state.data._win = True 372 | # Eat capsule 373 | if( position in state.getCapsules() ): 374 | state.data.capsules.remove( position ) 375 | state.data._capsuleEaten = position 376 | # Reset all ghosts' scared timers 377 | for index in range( 1, len( state.data.agentStates ) ): 378 | state.data.agentStates[index].scaredTimer = SCARED_TIME 379 | consume = staticmethod( consume ) 380 | 381 | class GhostRules: 382 | """ 383 | These functions dictate how ghosts interact with their environment. 384 | """ 385 | GHOST_SPEED=1.0 386 | def getLegalActions( state, ghostIndex ): 387 | """ 388 | Ghosts cannot stop, and cannot turn around unless they 389 | reach a dead end, but can turn 90 degrees at intersections. 
390 | """ 391 | conf = state.getGhostState( ghostIndex ).configuration 392 | possibleActions = Actions.getPossibleActions( conf, state.data.layout.walls ) 393 | reverse = Actions.reverseDirection( conf.direction ) 394 | if Directions.STOP in possibleActions: 395 | possibleActions.remove( Directions.STOP ) 396 | if reverse in possibleActions and len( possibleActions ) > 1: 397 | possibleActions.remove( reverse ) 398 | return possibleActions 399 | getLegalActions = staticmethod( getLegalActions ) 400 | 401 | def applyAction( state, action, ghostIndex): 402 | 403 | legal = GhostRules.getLegalActions( state, ghostIndex ) 404 | if action not in legal: 405 | raise Exception("Illegal ghost action " + str(action)) 406 | 407 | ghostState = state.data.agentStates[ghostIndex] 408 | speed = GhostRules.GHOST_SPEED 409 | if ghostState.scaredTimer > 0: speed /= 2.0 410 | vector = Actions.directionToVector( action, speed ) 411 | ghostState.configuration = ghostState.configuration.generateSuccessor( vector ) 412 | applyAction = staticmethod( applyAction ) 413 | 414 | def decrementTimer( ghostState): 415 | timer = ghostState.scaredTimer 416 | if timer == 1: 417 | ghostState.configuration.pos = nearestPoint( ghostState.configuration.pos ) 418 | ghostState.scaredTimer = max( 0, timer - 1 ) 419 | decrementTimer = staticmethod( decrementTimer ) 420 | 421 | def checkDeath( state, agentIndex): 422 | pacmanPosition = state.getPacmanPosition() 423 | if agentIndex == 0: # Pacman just moved; Anyone can kill him 424 | for index in range( 1, len( state.data.agentStates ) ): 425 | ghostState = state.data.agentStates[index] 426 | ghostPosition = ghostState.configuration.getPosition() 427 | if GhostRules.canKill( pacmanPosition, ghostPosition ): 428 | GhostRules.collide( state, ghostState, index ) 429 | else: 430 | ghostState = state.data.agentStates[agentIndex] 431 | ghostPosition = ghostState.configuration.getPosition() 432 | if GhostRules.canKill( pacmanPosition, ghostPosition ): 433 | GhostRules.collide( state, ghostState, agentIndex ) 434 | checkDeath = staticmethod( checkDeath ) 435 | 436 | def collide( state, ghostState, agentIndex): 437 | if ghostState.scaredTimer > 0: 438 | state.data.scoreChange += 200 439 | GhostRules.placeGhost(state, ghostState) 440 | ghostState.scaredTimer = 0 441 | # Added for first-person 442 | state.data._eaten[agentIndex] = True 443 | else: 444 | if not state.data._win: 445 | state.data.scoreChange -= 500 446 | state.data._lose = True 447 | collide = staticmethod( collide ) 448 | 449 | def canKill( pacmanPosition, ghostPosition ): 450 | return manhattanDistance( ghostPosition, pacmanPosition ) <= COLLISION_TOLERANCE 451 | canKill = staticmethod( canKill ) 452 | 453 | def placeGhost(state, ghostState): 454 | ghostState.configuration = ghostState.start 455 | placeGhost = staticmethod( placeGhost ) 456 | 457 | ############################# 458 | # FRAMEWORK TO START A GAME # 459 | ############################# 460 | 461 | def default(str): 462 | return str + ' [Default: %default]' 463 | 464 | def parseAgentArgs(str): 465 | if str == None: return {} 466 | pieces = str.split(',') 467 | opts = {} 468 | for p in pieces: 469 | if '=' in p: 470 | key, val = p.split('=') 471 | else: 472 | key,val = p, 1 473 | opts[key] = val 474 | return opts 475 | 476 | def readCommand( argv ): 477 | """ 478 | Processes the command used to run pacman from the command line. 
479 | """ 480 | from optparse import OptionParser 481 | usageStr = """ 482 | USAGE: python pacman.py 483 | EXAMPLES: (1) python pacman.py 484 | - starts an interactive game 485 | (2) python pacman.py --layout smallClassic --zoom 2 486 | OR python pacman.py -l smallClassic -z 2 487 | - starts an interactive game on a smaller board, zoomed in 488 | """ 489 | parser = OptionParser(usageStr) 490 | 491 | parser.add_option('-n', '--numGames', dest='numGames', type='int', 492 | help=default('the number of GAMES to play'), metavar='GAMES', default=1) 493 | parser.add_option('-l', '--layout', dest='layout', 494 | help=default('the LAYOUT_FILE from which to load the map layout'), 495 | metavar='LAYOUT_FILE', default='mediumClassic') 496 | parser.add_option('-p', '--pacman', dest='pacman', 497 | help=default('the agent TYPE in the pacmanAgents module to use'), 498 | metavar='TYPE', default='KeyboardAgent') 499 | parser.add_option('-t', '--textGraphics', action='store_true', dest='textGraphics', 500 | help='Display output as text only', default=False) 501 | parser.add_option('-q', '--quietTextGraphics', action='store_true', dest='quietGraphics', 502 | help='Generate minimal output and no graphics', default=False) 503 | parser.add_option('-g', '--ghosts', dest='ghost', 504 | help=default('the ghost agent TYPE in the ghostAgents module to use'), 505 | metavar = 'TYPE', default='RandomGhost') 506 | parser.add_option('-k', '--numghosts', type='int', dest='numGhosts', 507 | help=default('The maximum number of ghosts to use'), default=4) 508 | parser.add_option('-z', '--zoom', type='float', dest='zoom', 509 | help=default('Zoom the size of the graphics window'), default=1.0) 510 | parser.add_option('-f', '--fixRandomSeed', action='store_true', dest='fixRandomSeed', 511 | help='Fixes the random seed to always play the same game', default=False) 512 | parser.add_option('-r', '--recordActions', action='store_true', dest='record', 513 | help='Writes game histories to a file (named by the time they were played)', default=False) 514 | parser.add_option('--replay', dest='gameToReplay', 515 | help='A recorded game file (pickle) to replay', default=None) 516 | parser.add_option('-a','--agentArgs',dest='agentArgs', 517 | help='Comma separated values sent to agent. e.g. 
"opt1=val1,opt2,opt3=val3"') 518 | parser.add_option('-x', '--numTraining', dest='numTraining', type='int', 519 | help=default('How many episodes are training (suppresses output)'), default=0) 520 | parser.add_option('--frameTime', dest='frameTime', type='float', 521 | help=default('Time to delay between frames; <0 means keyboard'), default=0.1) 522 | parser.add_option('-c', '--catchExceptions', action='store_true', dest='catchExceptions', 523 | help='Turns on exception handling and timeouts during games', default=False) 524 | parser.add_option('--timeout', dest='timeout', type='int', 525 | help=default('Maximum length of time an agent can spend computing in a single game'), default=30) 526 | 527 | options, otherjunk = parser.parse_args(argv) 528 | if len(otherjunk) != 0: 529 | raise Exception('Command line input not understood: ' + str(otherjunk)) 530 | args = dict() 531 | 532 | # Fix the random seed 533 | if options.fixRandomSeed: random.seed('cs188') 534 | 535 | # Choose a layout 536 | args['layout'] = layout.getLayout( options.layout ) 537 | if args['layout'] == None: raise Exception("The layout " + options.layout + " cannot be found") 538 | 539 | # Choose a Pacman agent 540 | noKeyboard = options.gameToReplay == None and (options.textGraphics or options.quietGraphics) 541 | pacmanType = loadAgent(options.pacman, noKeyboard) 542 | agentOpts = parseAgentArgs(options.agentArgs) 543 | if options.numTraining > 0: 544 | args['numTraining'] = options.numTraining 545 | if 'numTraining' not in agentOpts: agentOpts['numTraining'] = options.numTraining 546 | pacman = pacmanType(**agentOpts) # Instantiate Pacman with agentArgs 547 | args['pacman'] = pacman 548 | 549 | # Don't display training games 550 | if 'numTrain' in agentOpts: 551 | options.numQuiet = int(agentOpts['numTrain']) 552 | options.numIgnore = int(agentOpts['numTrain']) 553 | 554 | # Choose a ghost agent 555 | ghostType = loadAgent(options.ghost, noKeyboard) 556 | args['ghosts'] = [ghostType( i+1 ) for i in range( options.numGhosts )] 557 | 558 | # Choose a display format 559 | if options.quietGraphics: 560 | import textDisplay 561 | args['display'] = textDisplay.NullGraphics() 562 | elif options.textGraphics: 563 | import textDisplay 564 | textDisplay.SLEEP_TIME = options.frameTime 565 | args['display'] = textDisplay.PacmanGraphics() 566 | else: 567 | import graphicsDisplay 568 | args['display'] = graphicsDisplay.PacmanGraphics(options.zoom, frameTime = options.frameTime) 569 | args['numGames'] = options.numGames 570 | args['record'] = options.record 571 | args['catchExceptions'] = options.catchExceptions 572 | args['timeout'] = options.timeout 573 | 574 | # Special case: recorded games don't use the runGames method or args structure 575 | if options.gameToReplay != None: 576 | print('Replaying recorded game %s.' 
% options.gameToReplay) 577 | import pickle 578 | f = open(options.gameToReplay, 'rb') 579 | try: recorded = pickle.load(f) 580 | finally: f.close() 581 | recorded['display'] = args['display'] 582 | replayGame(**recorded) 583 | sys.exit(0) 584 | 585 | return args 586 | 587 | def loadAgent(pacman, nographics): 588 | # Looks through all pythonPath Directories for the right module, 589 | pythonPathStr = os.path.expandvars("$PYTHONPATH") 590 | if pythonPathStr.find(';') == -1: 591 | pythonPathDirs = pythonPathStr.split(':') 592 | else: 593 | pythonPathDirs = pythonPathStr.split(';') 594 | pythonPathDirs.append('.') 595 | 596 | for moduleDir in pythonPathDirs: 597 | if not os.path.isdir(moduleDir): continue 598 | moduleNames = [f for f in os.listdir(moduleDir) if f.endswith('gents.py')] 599 | for modulename in moduleNames: 600 | try: 601 | module = __import__(modulename[:-3]) 602 | except ImportError: 603 | continue 604 | if pacman in dir(module): 605 | # print(f'pacman : {pacman}, module : {dir(module)}') 606 | if nographics and modulename == 'keyboardAgents.py': 607 | raise Exception('Using the keyboard requires graphics (not text display)') 608 | return getattr(module, pacman) 609 | raise Exception('The agent ' + pacman + ' is not specified in any *Agents.py.') 610 | 611 | def replayGame( layout, actions, display ): 612 | import pacmanAgents, ghostAgents 613 | rules = ClassicGameRules() 614 | agents = [pacmanAgents.GreedyAgent()] + [ghostAgents.RandomGhost(i+1) for i in range(layout.getNumGhosts())] 615 | game = rules.newGame( layout, agents[0], agents[1:], display ) 616 | state = game.state 617 | display.initialize(state.data) 618 | 619 | for action in actions: 620 | # Execute the action 621 | state = state.generateSuccessor( *action ) 622 | # Change the display 623 | display.update( state.data ) 624 | # Allow for game specific conditions (winning, losing, etc.) 
625 |             rules.process(state, game)
626 | 
627 |     display.finish()
628 | 
629 | def runGames( layout, pacman, ghosts, display, numGames, record, numTraining = 0, catchExceptions=False, timeout=30 ):
630 |     import __main__
631 |     __main__.__dict__['_display'] = display
632 | 
633 |     rules = ClassicGameRules(timeout)
634 |     games = []
635 | 
636 |     for i in range( numGames ):
637 |         beQuiet = i < numTraining
638 |         if beQuiet:
639 |             # Suppress output and graphics
640 |             import textDisplay
641 |             gameDisplay = textDisplay.NullGraphics()
642 |             rules.quiet = True
643 |         else:
644 |             gameDisplay = display
645 |             rules.quiet = False
646 |         game = rules.newGame( layout, pacman, ghosts, gameDisplay, beQuiet, catchExceptions)
647 |         game.run()
648 |         if not beQuiet: games.append(game)
649 | 
650 |         if record:
651 |             import time, pickle
652 |             fname = ('recorded-game-%d' % (i + 1)) + '-'.join([str(t) for t in time.localtime()[1:6]])
653 |             f = open(fname, 'wb')
654 |             components = {'layout': layout, 'actions': game.moveHistory}
655 |             pickle.dump(components, f)
656 |             f.close()
657 | 
658 |     if (numGames-numTraining) > 0:
659 |         scores = [game.state.getScore() for game in games]
660 |         wins = [game.state.isWin() for game in games]
661 |         winRate = wins.count(True)/ float(len(wins))
662 |         print('Average Score:', sum(scores) / float(len(scores)))
663 |         print('Scores: ', ', '.join([str(score) for score in scores]))
664 |         print('Win Rate: %d/%d (%.2f)' % (wins.count(True), len(wins), winRate))
665 |         print('Record: ', ', '.join([ ['Loss', 'Win'][int(w)] for w in wins]))
666 | 
667 |     return games
668 | 
669 | if __name__ == '__main__':
670 |     """
671 |     The main function called when pacman.py is run
672 |     from the command line:
673 | 
674 |     > python pacman.py
675 | 
676 |     See the usage string for more details.
677 | 
678 |     > python pacman.py --help
679 |     """
680 |     args = readCommand( sys.argv[1:] ) # Get game components based on input
681 |     runGames( **args )
682 | 
683 |     # import cProfile
684 |     # cProfile.run("runGames( **args )")
685 |     pass
686 | 
--------------------------------------------------------------------------------
/7_RL_Pacman/readme.md:
--------------------------------------------------------------------------------
1 | Run `python pacman.py` to start the game and use the arrow keys to move the agent.
2 | ![example_image](./data/example_image.png)
3 | 
4 | To be improved:
5 | The RLAgent implementation needs parameters such as layout and ghost_num; the default values are used directly here. The pacman code could be modified to align them better.
6 | 
7 | 
8 | At first glance this project looks a bit complex, and it really is. However, we only need to locate a few key pieces of code and function interfaces, and on top of those we can implement our reinforcement learning algorithm.
9 | 
10 | To control the agent with reinforcement learning in this project, we need to:
11 | - Find the code where the keyboard passes an action in to move the agent; this is our interface for interacting with the environment.
12 | - Define the state, including the agent position, the ghost positions, the map layout, the food positions, and so on.
13 | - After executing an action, obtain next_state and reward; these are what the agent learns from.
14 | 
15 | After that, we can train the agent so that, given the state at the current step, it chooses the optimal action to play the game automatically and maximize the score.
16 | 
17 | 
18 | Code analysis:
19 | `pacman.py main()`:
20 | runGames( **args ) --> game = rules.newGame(...pacman, ghosts...)
21 | ```python
22 | def newGame( self, layout, pacmanAgent, ghostAgents, display, quiet = False, catchExceptions=False):
23 |     agents = [pacmanAgent] + ghostAgents[:layout.getNumGhosts()]
24 |     initState = GameState()
25 |     initState.initialize( layout, len(ghostAgents) )
26 |     game = Game(agents, display, self, catchExceptions=catchExceptions)
27 | ```
28 | When the game is initialized, the agents are passed in. At line 686 of `game.py`, `action = agent.getAction(observation)` selects an action from the current state. Pacman and the ghosts share the same agent framework, and our goal is to replace pacman's `getAction` method; a rough sketch of such an agent is shown below.
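To make that target concrete before tracing the code further, here is a rough sketch of the kind of agent this project builds. The class name and the helpers `estimate_value` / `observe_transition` are illustrative placeholders, not code from this repository; the only hard requirement is to subclass `game.Agent` and return a legal direction from `getAction(state)`:

```python
import random

from game import Agent, Directions


class SketchRLAgent(Agent):
    """Illustrative skeleton only: shows the interface, not a trained policy."""

    def __init__(self, index=0, epsilon=0.1):
        self.index = index
        self.epsilon = epsilon        # exploration probability
        self.last_state = None        # remembered to form (s, a, r, s') transitions
        self.last_action = None
        self.last_score = 0.0

    def getAction(self, state):
        legal = list(state.getLegalActions(self.index))
        if Directions.STOP in legal and len(legal) > 1:
            legal.remove(Directions.STOP)

        # Reward signal: change in score since the previous decision point.
        reward = state.getScore() - self.last_score
        if self.last_state is not None:
            self.observe_transition(self.last_state, self.last_action, reward, state)

        # Epsilon-greedy choice over a learned value estimate.
        if random.random() < self.epsilon:
            action = random.choice(legal)
        else:
            action = max(legal, key=lambda a: self.estimate_value(state, a))

        self.last_state, self.last_action = state, action
        self.last_score = state.getScore()
        return action

    def estimate_value(self, state, action):
        # Placeholder: a real agent would query a Q-table or a neural network here.
        return 0.0

    def observe_transition(self, state, action, reward, next_state):
        # Placeholder: a real agent would store the transition and update its model here.
        pass
```

A real implementation would live in a `*gents.py` module (this project ships `RLAgents.py`) so that `loadAgent`, traced below, can find it by class name via the `-p` option.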
29 | To find where the pacman agent class actually comes from, note that the parameters passed to `runGames(**args)` are produced by `args = readCommand( sys.argv[1:] )`:
30 | ```python
31 | def readCommand( argv ):
32 |     ...
33 |     pacmanType = loadAgent(options.pacman, noKeyboard)
34 |     pacman = pacmanType(**agentOpts)
35 |     args['pacman'] = pacman
36 | 
37 | def loadAgent(pacman, nographics):
38 |     ...
39 |     if pacman in dir(module):
40 |         # inspect the parameter values that arrive here
41 |         print(f'pacman : {pacman}, module : {dir(module)}')
42 |         return getattr(module, pacman)
43 | ```
44 | Output:
45 | ![loadAgent](./data/loadAgent.png)
46 | `loadAgent` searches the current directory for Python modules whose names end in gents.py and checks whether they define an agent class with the requested name.
47 | The agent class used for pacman is `KeyboardAgent`, and the one used for the ghosts is `RandomGhost`.
48 | We need to implement an RLAgent with the same structure as these agent classes; passing the `-p` option when launching the game is then enough to use it.
49 | ```python
50 | def readCommand( argv ):
51 |     ...
52 |     parser.add_option('-p', '--pacman', dest='pacman',
53 |                       help=default('the agent TYPE in the pacmanAgents module to use'),
54 |                       metavar='TYPE', default='KeyboardAgent')
55 | ```
56 | 
57 | Next, analyze the state. The game already defines a fairly complete game state (`game.py`, class `GameStateData`), including the agent positions, the map layout, the food positions, the score, and so on.
58 | You can also print(state) in the code to inspect it:
59 | ```txt
60 | %%%%%%%%%%%%%%%%%%%%
61 | %o...%........%....%
62 | %.%%.%.%%%%%%.%.%%.%
63 | %.%........G.....%.%
64 | %.%.%%.%%G %%.%%.%.%
65 | %......%    %......%
66 | %.%.%%.%%%%%%.%%.%.%
67 | %.%..............%.%
68 | %.%%.%.%%%%%%.%.%%.%
69 | %....%...<....%...o%
70 | %%%%%%%%%%%%%%%%%%%%
71 | Score: -7
72 | ```
73 | This state contains all the information about the current game environment. Seeing this format, it is natural to represent the state as a (multi-channel) 2D array and learn from it with a convolutional neural network.
74 | However, since the game logic is fairly simple and the state space is small, we can instead flatten the different features into one-dimensional vectors and concatenate them (the snippet below assumes `import numpy as np` at module level):
75 | ```python
76 | def _extract_features(self, state):
77 |     """Extract a feature vector from the game state."""
78 | 
79 |     walls = state.getWalls()
80 |     width, height = walls.width, walls.height
81 | 
82 |     # Build a 2D matrix describing the map.
83 |     # 0: blank, 1: walls, 2: food, 3: capsules, 4: Pacman, 5: ghosts
84 |     grid_state = np.zeros((width, height))
85 | 
86 |     # Fill in the walls
87 |     for x in range(width):
88 |         for y in range(height):
89 |             if walls[x][y]:
90 |                 grid_state[x][y] = 1
91 | 
92 |     # Fill in the food
93 |     food = state.getFood()
94 |     for x in range(width):
95 |         for y in range(height):
96 |             if food[x][y]:
97 |                 grid_state[x][y] = 2
98 | 
99 |     # Fill in the capsules
100 |     capsules = state.getCapsules()
101 |     for x, y in capsules:
102 |         grid_state[int(x)][int(y)] = 3
103 | 
104 |     # Fill in Pacman
105 |     pacman_x, pacman_y = state.getPacmanPosition()
106 |     grid_state[int(pacman_x)][int(pacman_y)] = 4
107 | 
108 |     # Fill in the ghosts
109 |     ghost_states = state.getGhostStates()
110 |     for ghost in ghost_states:
111 |         ghost_x, ghost_y = ghost.getPosition()
112 |         grid_state[int(ghost_x)][int(ghost_y)] = 5
113 | 
114 |     # Flatten to a 1D vector
115 |     grid_features = grid_state.flatten()
116 | 
117 |     # Additional non-spatial features
118 | 
119 |     # 1. Score
120 |     score_enc = np.array([state.getScore()])
121 | 
122 |     # 2. Remaining food count
123 |     food_count_enc = np.array([state.getNumFood()])
124 | 
125 |     # 3. Remaining capsule count
126 |     capsule_count_enc = np.array([len(capsules)])  # at most 4 capsules assumed
127 | 
128 |     # Concatenate everything into a single feature vector
129 |     features = np.concatenate([
130 |         grid_features,       # map state
131 |         score_enc,           # score
132 |         food_count_enc,      # remaining food count
133 |         capsule_count_enc,   # remaining capsule count
134 |     ])
135 | 
136 |     return features.astype(np.float32)
137 | ```
--------------------------------------------------------------------------------
/7_RL_Pacman/textDisplay.py:
--------------------------------------------------------------------------------
1 | # textDisplay.py
2 | # --------------
3 | # Licensing Information: You are free to use or extend these projects for
4 | # educational purposes provided that (1) you do not distribute or publish
5 | # solutions, (2) you retain this notice, and (3) you provide clear
6 | # attribution to UC Berkeley, including a link to http://ai.berkeley.edu.
7 | # 
8 | # Attribution Information: The Pacman AI projects were developed at UC Berkeley.
9 | # The core projects and autograders were primarily created by John DeNero 10 | # (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 11 | # Student side autograding was added by Brad Miller, Nick Hay, and 12 | # Pieter Abbeel (pabbeel@cs.berkeley.edu). 13 | 14 | 15 | import time 16 | try: 17 | import pacman 18 | except: 19 | pass 20 | 21 | DRAW_EVERY = 1 22 | SLEEP_TIME = 0 # This can be overwritten by __init__ 23 | DISPLAY_MOVES = False 24 | QUIET = False # Supresses output 25 | 26 | class NullGraphics: 27 | def initialize(self, state, isBlue = False): 28 | pass 29 | 30 | def update(self, state): 31 | pass 32 | 33 | def checkNullDisplay(self): 34 | return True 35 | 36 | def pause(self): 37 | time.sleep(SLEEP_TIME) 38 | 39 | def draw(self, state): 40 | print(state) 41 | 42 | def updateDistributions(self, dist): 43 | pass 44 | 45 | def finish(self): 46 | pass 47 | 48 | class PacmanGraphics: 49 | def __init__(self, speed=None): 50 | if speed != None: 51 | global SLEEP_TIME 52 | SLEEP_TIME = speed 53 | 54 | def initialize(self, state, isBlue = False): 55 | self.draw(state) 56 | self.pause() 57 | self.turn = 0 58 | self.agentCounter = 0 59 | 60 | def update(self, state): 61 | numAgents = len(state.agentStates) 62 | self.agentCounter = (self.agentCounter + 1) % numAgents 63 | if self.agentCounter == 0: 64 | self.turn += 1 65 | if DISPLAY_MOVES: 66 | ghosts = [pacman.nearestPoint(state.getGhostPosition(i)) for i in range(1, numAgents)] 67 | print("%4d) P: %-8s" % (self.turn, str(pacman.nearestPoint(state.getPacmanPosition()))),'| Score: %-5d' % state.score,'| Ghosts:', ghosts) 68 | if self.turn % DRAW_EVERY == 0: 69 | self.draw(state) 70 | self.pause() 71 | if state._win or state._lose: 72 | self.draw(state) 73 | 74 | def pause(self): 75 | time.sleep(SLEEP_TIME) 76 | 77 | def draw(self, state): 78 | print(state) 79 | 80 | def finish(self): 81 | pass 82 | -------------------------------------------------------------------------------- /7_RL_Pacman/util.py: -------------------------------------------------------------------------------- 1 | # util.py 2 | # ------- 3 | # Licensing Information: You are free to use or extend these projects for 4 | # educational purposes provided that (1) you do not distribute or publish 5 | # solutions, (2) you retain this notice, and (3) you provide clear 6 | # attribution to UC Berkeley, including a link to http://ai.berkeley.edu. 7 | # 8 | # Attribution Information: The Pacman AI projects were developed at UC Berkeley. 9 | # The core projects and autograders were primarily created by John DeNero 10 | # (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 11 | # Student side autograding was added by Brad Miller, Nick Hay, and 12 | # Pieter Abbeel (pabbeel@cs.berkeley.edu). 13 | 14 | 15 | # util.py 16 | # ------- 17 | # Licensing Information: You are free to use or extend these projects for 18 | # educational purposes provided that (1) you do not distribute or publish 19 | # solutions, (2) you retain this notice, and (3) you provide clear 20 | # attribution to UC Berkeley, including a link to http://ai.berkeley.edu. 21 | # 22 | # Attribution Information: The Pacman AI projects were developed at UC Berkeley. 23 | # The core projects and autograders were primarily created by John DeNero 24 | # (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu). 25 | # Student side autograding was added by Brad Miller, Nick Hay, and 26 | # Pieter Abbeel (pabbeel@cs.berkeley.edu). 
27 | 28 | 29 | import sys 30 | import inspect 31 | import heapq, random 32 | 33 | 34 | class FixedRandom: 35 | def __init__(self): 36 | fixedState = (3, (2147483648, 507801126, 683453281, 310439348, 2597246090, \ 37 | 2209084787, 2267831527, 979920060, 3098657677, 37650879, 807947081, 3974896263, \ 38 | 881243242, 3100634921, 1334775171, 3965168385, 746264660, 4074750168, 500078808, \ 39 | 776561771, 702988163, 1636311725, 2559226045, 157578202, 2498342920, 2794591496, \ 40 | 4130598723, 496985844, 2944563015, 3731321600, 3514814613, 3362575829, 3038768745, \ 41 | 2206497038, 1108748846, 1317460727, 3134077628, 988312410, 1674063516, 746456451, \ 42 | 3958482413, 1857117812, 708750586, 1583423339, 3466495450, 1536929345, 1137240525, \ 43 | 3875025632, 2466137587, 1235845595, 4214575620, 3792516855, 657994358, 1241843248, \ 44 | 1695651859, 3678946666, 1929922113, 2351044952, 2317810202, 2039319015, 460787996, \ 45 | 3654096216, 4068721415, 1814163703, 2904112444, 1386111013, 574629867, 2654529343, \ 46 | 3833135042, 2725328455, 552431551, 4006991378, 1331562057, 3710134542, 303171486, \ 47 | 1203231078, 2670768975, 54570816, 2679609001, 578983064, 1271454725, 3230871056, \ 48 | 2496832891, 2944938195, 1608828728, 367886575, 2544708204, 103775539, 1912402393, \ 49 | 1098482180, 2738577070, 3091646463, 1505274463, 2079416566, 659100352, 839995305, \ 50 | 1696257633, 274389836, 3973303017, 671127655, 1061109122, 517486945, 1379749962, \ 51 | 3421383928, 3116950429, 2165882425, 2346928266, 2892678711, 2936066049, 1316407868, \ 52 | 2873411858, 4279682888, 2744351923, 3290373816, 1014377279, 955200944, 4220990860, \ 53 | 2386098930, 1772997650, 3757346974, 1621616438, 2877097197, 442116595, 2010480266, \ 54 | 2867861469, 2955352695, 605335967, 2222936009, 2067554933, 4129906358, 1519608541, \ 55 | 1195006590, 1942991038, 2736562236, 279162408, 1415982909, 4099901426, 1732201505, \ 56 | 2934657937, 860563237, 2479235483, 3081651097, 2244720867, 3112631622, 1636991639, \ 57 | 3860393305, 2312061927, 48780114, 1149090394, 2643246550, 1764050647, 3836789087, \ 58 | 3474859076, 4237194338, 1735191073, 2150369208, 92164394, 756974036, 2314453957, \ 59 | 323969533, 4267621035, 283649842, 810004843, 727855536, 1757827251, 3334960421, \ 60 | 3261035106, 38417393, 2660980472, 1256633965, 2184045390, 811213141, 2857482069, \ 61 | 2237770878, 3891003138, 2787806886, 2435192790, 2249324662, 3507764896, 995388363, \ 62 | 856944153, 619213904, 3233967826, 3703465555, 3286531781, 3863193356, 2992340714, \ 63 | 413696855, 3865185632, 1704163171, 3043634452, 2225424707, 2199018022, 3506117517, \ 64 | 3311559776, 3374443561, 1207829628, 668793165, 1822020716, 2082656160, 1160606415, \ 65 | 3034757648, 741703672, 3094328738, 459332691, 2702383376, 1610239915, 4162939394, \ 66 | 557861574, 3805706338, 3832520705, 1248934879, 3250424034, 892335058, 74323433, \ 67 | 3209751608, 3213220797, 3444035873, 3743886725, 1783837251, 610968664, 580745246, \ 68 | 4041979504, 201684874, 2673219253, 1377283008, 3497299167, 2344209394, 2304982920, \ 69 | 3081403782, 2599256854, 3184475235, 3373055826, 695186388, 2423332338, 222864327, \ 70 | 1258227992, 3627871647, 3487724980, 4027953808, 3053320360, 533627073, 3026232514, \ 71 | 2340271949, 867277230, 868513116, 2158535651, 2487822909, 3428235761, 3067196046, \ 72 | 3435119657, 1908441839, 788668797, 3367703138, 3317763187, 908264443, 2252100381, \ 73 | 764223334, 4127108988, 384641349, 3377374722, 1263833251, 1958694944, 3847832657, \ 74 | 1253909612, 1096494446, 555725445, 
2277045895, 3340096504, 1383318686, 4234428127, \ 75 | 1072582179, 94169494, 1064509968, 2681151917, 2681864920, 734708852, 1338914021, \ 76 | 1270409500, 1789469116, 4191988204, 1716329784, 2213764829, 3712538840, 919910444, \ 77 | 1318414447, 3383806712, 3054941722, 3378649942, 1205735655, 1268136494, 2214009444, \ 78 | 2532395133, 3232230447, 230294038, 342599089, 772808141, 4096882234, 3146662953, \ 79 | 2784264306, 1860954704, 2675279609, 2984212876, 2466966981, 2627986059, 2985545332, \ 80 | 2578042598, 1458940786, 2944243755, 3959506256, 1509151382, 325761900, 942251521, \ 81 | 4184289782, 2756231555, 3297811774, 1169708099, 3280524138, 3805245319, 3227360276, \ 82 | 3199632491, 2235795585, 2865407118, 36763651, 2441503575, 3314890374, 1755526087, \ 83 | 17915536, 1196948233, 949343045, 3815841867, 489007833, 2654997597, 2834744136, \ 84 | 417688687, 2843220846, 85621843, 747339336, 2043645709, 3520444394, 1825470818, \ 85 | 647778910, 275904777, 1249389189, 3640887431, 4200779599, 323384601, 3446088641, \ 86 | 4049835786, 1718989062, 3563787136, 44099190, 3281263107, 22910812, 1826109246, \ 87 | 745118154, 3392171319, 1571490704, 354891067, 815955642, 1453450421, 940015623, \ 88 | 796817754, 1260148619, 3898237757, 176670141, 1870249326, 3317738680, 448918002, \ 89 | 4059166594, 2003827551, 987091377, 224855998, 3520570137, 789522610, 2604445123, \ 90 | 454472869, 475688926, 2990723466, 523362238, 3897608102, 806637149, 2642229586, \ 91 | 2928614432, 1564415411, 1691381054, 3816907227, 4082581003, 1895544448, 3728217394, \ 92 | 3214813157, 4054301607, 1882632454, 2873728645, 3694943071, 1297991732, 2101682438, \ 93 | 3952579552, 678650400, 1391722293, 478833748, 2976468591, 158586606, 2576499787, \ 94 | 662690848, 3799889765, 3328894692, 2474578497, 2383901391, 1718193504, 3003184595, \ 95 | 3630561213, 1929441113, 3848238627, 1594310094, 3040359840, 3051803867, 2462788790, \ 96 | 954409915, 802581771, 681703307, 545982392, 2738993819, 8025358, 2827719383, \ 97 | 770471093, 3484895980, 3111306320, 3900000891, 2116916652, 397746721, 2087689510, \ 98 | 721433935, 1396088885, 2751612384, 1998988613, 2135074843, 2521131298, 707009172, \ 99 | 2398321482, 688041159, 2264560137, 482388305, 207864885, 3735036991, 3490348331, \ 100 | 1963642811, 3260224305, 3493564223, 1939428454, 1128799656, 1366012432, 2858822447, \ 101 | 1428147157, 2261125391, 1611208390, 1134826333, 2374102525, 3833625209, 2266397263, \ 102 | 3189115077, 770080230, 2674657172, 4280146640, 3604531615, 4235071805, 3436987249, \ 103 | 509704467, 2582695198, 4256268040, 3391197562, 1460642842, 1617931012, 457825497, \ 104 | 1031452907, 1330422862, 4125947620, 2280712485, 431892090, 2387410588, 2061126784, \ 105 | 896457479, 3480499461, 2488196663, 4021103792, 1877063114, 2744470201, 1046140599, \ 106 | 2129952955, 3583049218, 4217723693, 2720341743, 820661843, 1079873609, 3360954200, \ 107 | 3652304997, 3335838575, 2178810636, 1908053374, 4026721976, 1793145418, 476541615, \ 108 | 973420250, 515553040, 919292001, 2601786155, 1685119450, 3030170809, 1590676150, \ 109 | 1665099167, 651151584, 2077190587, 957892642, 646336572, 2743719258, 866169074, \ 110 | 851118829, 4225766285, 963748226, 799549420, 1955032629, 799460000, 2425744063, \ 111 | 2441291571, 1928963772, 528930629, 2591962884, 3495142819, 1896021824, 901320159, \ 112 | 3181820243, 843061941, 3338628510, 3782438992, 9515330, 1705797226, 953535929, \ 113 | 764833876, 3202464965, 2970244591, 519154982, 3390617541, 566616744, 3438031503, \ 114 | 1853838297, 
170608755, 1393728434, 676900116, 3184965776, 1843100290, 78995357, \ 115 | 2227939888, 3460264600, 1745705055, 1474086965, 572796246, 4081303004, 882828851, \ 116 | 1295445825, 137639900, 3304579600, 2722437017, 4093422709, 273203373, 2666507854, \ 117 | 3998836510, 493829981, 1623949669, 3482036755, 3390023939, 833233937, 1639668730, \ 118 | 1499455075, 249728260, 1210694006, 3836497489, 1551488720, 3253074267, 3388238003, \ 119 | 2372035079, 3945715164, 2029501215, 3362012634, 2007375355, 4074709820, 631485888, \ 120 | 3135015769, 4273087084, 3648076204, 2739943601, 1374020358, 1760722448, 3773939706, \ 121 | 1313027823, 1895251226, 4224465911, 421382535, 1141067370, 3660034846, 3393185650, \ 122 | 1850995280, 1451917312, 3841455409, 3926840308, 1397397252, 2572864479, 2500171350, \ 123 | 3119920613, 531400869, 1626487579, 1099320497, 407414753, 2438623324, 99073255, \ 124 | 3175491512, 656431560, 1153671785, 236307875, 2824738046, 2320621382, 892174056, \ 125 | 230984053, 719791226, 2718891946, 624), None) 126 | self.random = random.Random() 127 | self.random.setstate(fixedState) 128 | 129 | """ 130 | Data structures useful for implementing SearchAgents 131 | """ 132 | 133 | class Stack: 134 | "A container with a last-in-first-out (LIFO) queuing policy." 135 | def __init__(self): 136 | self.list = [] 137 | 138 | def push(self,item): 139 | "Push 'item' onto the stack" 140 | self.list.append(item) 141 | 142 | def pop(self): 143 | "Pop the most recently pushed item from the stack" 144 | return self.list.pop() 145 | 146 | def isEmpty(self): 147 | "Returns true if the stack is empty" 148 | return len(self.list) == 0 149 | 150 | class Queue: 151 | "A container with a first-in-first-out (FIFO) queuing policy." 152 | def __init__(self): 153 | self.list = [] 154 | 155 | def push(self,item): 156 | "Enqueue the 'item' into the queue" 157 | self.list.insert(0,item) 158 | 159 | def pop(self): 160 | """ 161 | Dequeue the earliest enqueued item still in the queue. This 162 | operation removes the item from the queue. 163 | """ 164 | return self.list.pop() 165 | 166 | def isEmpty(self): 167 | "Returns true if the queue is empty" 168 | return len(self.list) == 0 169 | 170 | class PriorityQueue: 171 | """ 172 | Implements a priority queue data structure. Each inserted item 173 | has a priority associated with it and the client is usually interested 174 | in quick retrieval of the lowest-priority item in the queue. This 175 | data structure allows O(1) access to the lowest-priority item. 176 | """ 177 | def __init__(self): 178 | self.heap = [] 179 | self.count = 0 180 | 181 | def push(self, item, priority): 182 | entry = (priority, self.count, item) 183 | heapq.heappush(self.heap, entry) 184 | self.count += 1 185 | 186 | def pop(self): 187 | (_, _, item) = heapq.heappop(self.heap) 188 | return item 189 | 190 | def isEmpty(self): 191 | return len(self.heap) == 0 192 | 193 | def update(self, item, priority): 194 | # If item already in priority queue with higher priority, update its priority and rebuild the heap. 195 | # If item already in priority queue with equal or lower priority, do nothing. 196 | # If item not in priority queue, do the same thing as self.push. 
197 | for index, (p, c, i) in enumerate(self.heap): 198 | if i == item: 199 | if p <= priority: 200 | break 201 | del self.heap[index] 202 | self.heap.append((priority, c, item)) 203 | heapq.heapify(self.heap) 204 | break 205 | else: 206 | self.push(item, priority) 207 | 208 | class PriorityQueueWithFunction(PriorityQueue): 209 | """ 210 | Implements a priority queue with the same push/pop signature of the 211 | Queue and the Stack classes. This is designed for drop-in replacement for 212 | those two classes. The caller has to provide a priority function, which 213 | extracts each item's priority. 214 | """ 215 | def __init__(self, priorityFunction): 216 | "priorityFunction (item) -> priority" 217 | self.priorityFunction = priorityFunction # store the priority function 218 | PriorityQueue.__init__(self) # super-class initializer 219 | 220 | def push(self, item): 221 | "Adds an item to the queue with priority from the priority function" 222 | PriorityQueue.push(self, item, self.priorityFunction(item)) 223 | 224 | 225 | def manhattanDistance( xy1, xy2 ): 226 | "Returns the Manhattan distance between points xy1 and xy2" 227 | return abs( xy1[0] - xy2[0] ) + abs( xy1[1] - xy2[1] ) 228 | 229 | """ 230 | Data structures and functions useful for various course projects 231 | 232 | The search project should not need anything below this line. 233 | """ 234 | 235 | class Counter(dict): 236 | """ 237 | A counter keeps track of counts for a set of keys. 238 | 239 | The counter class is an extension of the standard python 240 | dictionary type. It is specialized to have number values 241 | (integers or floats), and includes a handful of additional 242 | functions to ease the task of counting data. In particular, 243 | all keys are defaulted to have value 0. Using a dictionary: 244 | 245 | a = {} 246 | print(a['test']) 247 | 248 | would give an error, while the Counter class analogue: 249 | 250 | >>> a = Counter() 251 | >>> print(a['test']) 252 | 0 253 | 254 | returns the default 0 value. Note that to reference a key 255 | that you know is contained in the counter, 256 | you can still use the dictionary syntax: 257 | 258 | >>> a = Counter() 259 | >>> a['test'] = 2 260 | >>> print(a['test']) 261 | 2 262 | 263 | This is very useful for counting things without initializing their counts, 264 | see for example: 265 | 266 | >>> a['blah'] += 1 267 | >>> print(a['blah']) 268 | 1 269 | 270 | The counter also includes additional functionality useful in implementing 271 | the classifiers for this assignment. Two counters can be added, 272 | subtracted or multiplied together. See below for details. They can 273 | also be normalized and their total count and arg max can be extracted. 274 | """ 275 | def __getitem__(self, idx): 276 | self.setdefault(idx, 0) 277 | return dict.__getitem__(self, idx) 278 | 279 | def incrementAll(self, keys, count): 280 | """ 281 | Increments all elements of keys by the same count. 282 | 283 | >>> a = Counter() 284 | >>> a.incrementAll(['one','two', 'three'], 1) 285 | >>> a['one'] 286 | 1 287 | >>> a['two'] 288 | 1 289 | """ 290 | for key in keys: 291 | self[key] += count 292 | 293 | def argMax(self): 294 | """ 295 | Returns the key with the highest value. 296 | """ 297 | if len(self.keys()) == 0: return None 298 | all = self.items() 299 | values = [x[1] for x in all] 300 | maxIndex = values.index(max(values)) 301 | return all[maxIndex][0] 302 | 303 | def sortedKeys(self): 304 | """ 305 | Returns a list of keys sorted by their values. Keys 306 | with the highest values will appear first. 
307 | 308 | >>> a = Counter() 309 | >>> a['first'] = -2 310 | >>> a['second'] = 4 311 | >>> a['third'] = 1 312 | >>> a.sortedKeys() 313 | ['second', 'third', 'first'] 314 | """ 315 | sortedItems = self.items() 316 | compare = lambda x, y: sign(y[1] - x[1]) 317 | sortedItems.sort(cmp=compare) 318 | return [x[0] for x in sortedItems] 319 | 320 | def totalCount(self): 321 | """ 322 | Returns the sum of counts for all keys. 323 | """ 324 | return sum(self.values()) 325 | 326 | def normalize(self): 327 | """ 328 | Edits the counter such that the total count of all 329 | keys sums to 1. The ratio of counts for all keys 330 | will remain the same. Note that normalizing an empty 331 | Counter will result in an error. 332 | """ 333 | total = float(self.totalCount()) 334 | if total == 0: return 335 | for key in self.keys(): 336 | self[key] = self[key] / total 337 | 338 | def divideAll(self, divisor): 339 | """ 340 | Divides all counts by divisor 341 | """ 342 | divisor = float(divisor) 343 | for key in self: 344 | self[key] /= divisor 345 | 346 | def copy(self): 347 | """ 348 | Returns a copy of the counter 349 | """ 350 | return Counter(dict.copy(self)) 351 | 352 | def __mul__(self, y ): 353 | """ 354 | Multiplying two counters gives the dot product of their vectors where 355 | each unique label is a vector element. 356 | 357 | >>> a = Counter() 358 | >>> b = Counter() 359 | >>> a['first'] = -2 360 | >>> a['second'] = 4 361 | >>> b['first'] = 3 362 | >>> b['second'] = 5 363 | >>> a['third'] = 1.5 364 | >>> a['fourth'] = 2.5 365 | >>> a * b 366 | 14 367 | """ 368 | sum = 0 369 | x = self 370 | if len(x) > len(y): 371 | x,y = y,x 372 | for key in x: 373 | if key not in y: 374 | continue 375 | sum += x[key] * y[key] 376 | return sum 377 | 378 | def __radd__(self, y): 379 | """ 380 | Adding another counter to a counter increments the current counter 381 | by the values stored in the second counter. 382 | 383 | >>> a = Counter() 384 | >>> b = Counter() 385 | >>> a['first'] = -2 386 | >>> a['second'] = 4 387 | >>> b['first'] = 3 388 | >>> b['third'] = 1 389 | >>> a += b 390 | >>> a['first'] 391 | 1 392 | """ 393 | for key, value in y.items(): 394 | self[key] += value 395 | 396 | def __add__( self, y ): 397 | """ 398 | Adding two counters gives a counter with the union of all keys and 399 | counts of the second added to counts of the first. 400 | 401 | >>> a = Counter() 402 | >>> b = Counter() 403 | >>> a['first'] = -2 404 | >>> a['second'] = 4 405 | >>> b['first'] = 3 406 | >>> b['third'] = 1 407 | >>> (a + b)['first'] 408 | 1 409 | """ 410 | addend = Counter() 411 | for key in self: 412 | if key in y: 413 | addend[key] = self[key] + y[key] 414 | else: 415 | addend[key] = self[key] 416 | for key in y: 417 | if key in self: 418 | continue 419 | addend[key] = y[key] 420 | return addend 421 | 422 | def __sub__( self, y ): 423 | """ 424 | Subtracting a counter from another gives a counter with the union of all keys and 425 | counts of the second subtracted from counts of the first. 
426 | 427 | >>> a = Counter() 428 | >>> b = Counter() 429 | >>> a['first'] = -2 430 | >>> a['second'] = 4 431 | >>> b['first'] = 3 432 | >>> b['third'] = 1 433 | >>> (a - b)['first'] 434 | -5 435 | """ 436 | addend = Counter() 437 | for key in self: 438 | if key in y: 439 | addend[key] = self[key] - y[key] 440 | else: 441 | addend[key] = self[key] 442 | for key in y: 443 | if key in self: 444 | continue 445 | addend[key] = -1 * y[key] 446 | return addend 447 | 448 | def raiseNotDefined(): 449 | fileName = inspect.stack()[1][1] 450 | line = inspect.stack()[1][2] 451 | method = inspect.stack()[1][3] 452 | 453 | print("*** Method not implemented: %s at line %s of %s" % (method, line, fileName)) 454 | sys.exit(1) 455 | 456 | def normalize(vectorOrCounter): 457 | """ 458 | normalize a vector or counter by dividing each value by the sum of all values 459 | """ 460 | normalizedCounter = Counter() 461 | if type(vectorOrCounter) == type(normalizedCounter): 462 | counter = vectorOrCounter 463 | total = float(counter.totalCount()) 464 | if total == 0: return counter 465 | for key in counter.keys(): 466 | value = counter[key] 467 | normalizedCounter[key] = value / total 468 | return normalizedCounter 469 | else: 470 | vector = vectorOrCounter 471 | s = float(sum(vector)) 472 | if s == 0: return vector 473 | return [el / s for el in vector] 474 | 475 | def nSample(distribution, values, n): 476 | if sum(distribution) != 1: 477 | distribution = normalize(distribution) 478 | rand = [random.random() for i in range(n)] 479 | rand.sort() 480 | samples = [] 481 | samplePos, distPos, cdf = 0,0, distribution[0] 482 | while samplePos < n: 483 | if rand[samplePos] < cdf: 484 | samplePos += 1 485 | samples.append(values[distPos]) 486 | else: 487 | distPos += 1 488 | cdf += distribution[distPos] 489 | return samples 490 | 491 | def sample(distribution, values = None): 492 | if type(distribution) == Counter: 493 | items = sorted(distribution.items()) 494 | distribution = [i[1] for i in items] 495 | values = [i[0] for i in items] 496 | if sum(distribution) != 1: 497 | distribution = normalize(distribution) 498 | choice = random.random() 499 | i, total= 0, distribution[0] 500 | while choice > total: 501 | i += 1 502 | total += distribution[i] 503 | return values[i] 504 | 505 | def sampleFromCounter(ctr): 506 | items = sorted(ctr.items()) 507 | return sample([v for k,v in items], [k for k,v in items]) 508 | 509 | def getProbability(value, distribution, values): 510 | """ 511 | Gives the probability of a value under a discrete distribution 512 | defined by (distributions, values). 513 | """ 514 | total = 0.0 515 | for prob, val in zip(distribution, values): 516 | if val == value: 517 | total += prob 518 | return total 519 | 520 | def flipCoin( p ): 521 | r = random.random() 522 | return r < p 523 | 524 | def chooseFromDistribution( distribution ): 525 | "Takes either a counter or a list of (prob, key) pairs and samples" 526 | if type(distribution) == dict or type(distribution) == Counter: 527 | return sample(distribution) 528 | r = random.random() 529 | base = 0.0 530 | for prob, element in distribution: 531 | base += prob 532 | if r <= base: return element 533 | 534 | def nearestPoint( pos ): 535 | """ 536 | Finds the nearest grid point to a position (discretizes). 
537 |     """
538 |     ( current_row, current_col ) = pos
539 | 
540 |     grid_row = int( current_row + 0.5 )
541 |     grid_col = int( current_col + 0.5 )
542 |     return ( grid_row, grid_col )
543 | 
544 | def sign( x ):
545 |     """
546 |     Returns 1 or -1 depending on the sign of x
547 |     """
548 |     if( x >= 0 ):
549 |         return 1
550 |     else:
551 |         return -1
552 | 
553 | def arrayInvert(array):
554 |     """
555 |     Inverts a matrix stored as a list of lists.
556 |     """
557 |     result = [[] for i in array]
558 |     for outer in array:
559 |         for inner in range(len(outer)):
560 |             result[inner].append(outer[inner])
561 |     return result
562 | 
563 | def matrixAsList( matrix, value = True ):
564 |     """
565 |     Turns a matrix into a list of coordinates matching the specified value
566 |     """
567 |     rows, cols = len( matrix ), len( matrix[0] )
568 |     cells = []
569 |     for row in range( rows ):
570 |         for col in range( cols ):
571 |             if matrix[row][col] == value:
572 |                 cells.append( ( row, col ) )
573 |     return cells
574 | 
575 | def lookup(name, namespace):
576 |     """
577 |     Get a method or class from any imported module from its name.
578 |     Usage: lookup(functionName, globals())
579 |     """
580 |     dots = name.count('.')
581 |     if dots > 0:
582 |         moduleName, objName = '.'.join(name.split('.')[:-1]), name.split('.')[-1]
583 |         module = __import__(moduleName)
584 |         return getattr(module, objName)
585 |     else:
586 |         modules = [obj for obj in namespace.values() if str(type(obj)) == "<class 'module'>"]
587 |         options = [getattr(module, name) for module in modules if name in dir(module)]
588 |         options += [obj[1] for obj in namespace.items() if obj[0] == name ]
589 |         if len(options) == 1: return options[0]
590 |         if len(options) > 1: raise Exception('Name conflict for %s' % name)
591 |         raise Exception('%s not found as a method or class' % name)
592 | 
593 | def pause():
594 |     """
595 |     Pauses the output stream awaiting user feedback.
596 |     """
597 |     input("")
598 | 
599 | 
600 | # code to handle timeouts
601 | #
602 | # FIXME
603 | # NOTE: TimeoutFunction is NOT reentrant.  Later timeouts will silently
604 | # disable earlier timeouts.  Could be solved by maintaining a global list
605 | # of active time outs.  Currently, questions which have test cases calling
606 | # this have all student code so wrapped.
607 | #
608 | import signal
609 | import time
610 | class TimeoutFunctionException(Exception):
611 |     """Exception to raise on a timeout"""
612 |     pass
613 | 
614 | 
615 | class TimeoutFunction:
616 |     def __init__(self, function, timeout):
617 |         self.timeout = timeout
618 |         self.function = function
619 | 
620 |     def handle_timeout(self, signum, frame):
621 |         raise TimeoutFunctionException()
622 | 
623 |     def __call__(self, *args, **keyArgs):
624 |         # If we have SIGALRM signal, use it to cause an exception if and
625 |         # when this function runs too long.  Otherwise check the time taken
626 |         # after the method has returned, and throw an exception then.
627 |         if hasattr(signal, 'SIGALRM'):
628 |             old = signal.signal(signal.SIGALRM, self.handle_timeout)
629 |             signal.alarm(self.timeout)
630 |             try:
631 |                 result = self.function(*args, **keyArgs)
632 |             finally:
633 |                 signal.signal(signal.SIGALRM, old)
634 |             signal.alarm(0)
635 |         else:
636 |             startTime = time.time()
637 |             result = self.function(*args, **keyArgs)
638 |             timeElapsed = time.time() - startTime
639 |             if timeElapsed >= self.timeout:
640 |                 self.handle_timeout(None, None)
641 |         return result
642 | 
643 | 
644 | 
645 | _ORIGINAL_STDOUT = None
646 | _ORIGINAL_STDERR = None
647 | _MUTED = False
648 | 
649 | class WritableNull:
650 |     def write(self, string):
651 |         pass
652 | 
653 | def mutePrint():
654 |     global _ORIGINAL_STDOUT, _ORIGINAL_STDERR, _MUTED
655 |     if _MUTED:
656 |         return
657 |     _MUTED = True
658 | 
659 |     _ORIGINAL_STDOUT = sys.stdout
660 |     #_ORIGINAL_STDERR = sys.stderr
661 |     sys.stdout = WritableNull()
662 |     #sys.stderr = WritableNull()
663 | 
664 | def unmutePrint():
665 |     global _ORIGINAL_STDOUT, _ORIGINAL_STDERR, _MUTED
666 |     if not _MUTED:
667 |         return
668 |     _MUTED = False
669 | 
670 |     sys.stdout = _ORIGINAL_STDOUT
671 |     #sys.stderr = _ORIGINAL_STDERR
672 | 
673 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # 🚀 PyTorch Deep Learning Hands-On Tutorials 🚀
2 | 
3 | Welcome to the **PyTorch Deep Learning Hands-On Tutorials**! 🎉
4 | 
5 | This repository collects several hands-on deep learning projects built with PyTorch. The code is thoroughly commented and clearly structured, making it a solid reference for learning and practicing PyTorch.
6 | 
7 | **Continuously being improved...**
8 | 
9 | ## 🎯 Changelog
10 | - 2024-11-23: Added Project 1: Handwritten Digit Recognition
11 | - 2024-12-02: Added Project 2: Cat & Dog Image Classification
12 | - 2025-01-01: Added Project 3: Chinese Movie Review Text Classification
13 | - 2025-02-28: Added Project 4: GAN Image Generation
14 | - 2025-03-19: Added Project 5: RL Snake
15 | - 2025-04-02: Added Project 6: Adversarial Attacks (to be refined)
16 | 
17 | ---
18 | 
19 | ## Table of Contents 📚
20 | 
21 | 1. [Handwritten Digit Recognition](#1-handwritten-digit-recognition)
22 | 2. [Cat & Dog Image Classification](#2-cat--dog-image-classification)
23 | 3. [Transformer-Based Sentiment Classification](#3-transformer-based-sentiment-classification)
24 | 4. [GAN Image Generation](#4-gan-image-generation)
25 | 5. [Reinforcement Learning Snake](#5-reinforcement-learning-snake)
26 | 6. [Adversarial Attacks](#6-adversarial-attacks)
27 | 
28 | ---
29 | 
30 | ## 🧐 About the Projects
31 | 
32 | - Covers several classic PyTorch projects: handwritten digit recognition, cat/dog classification, and more.
33 | - Each project walks through a complete code pipeline: data processing, model training, result visualization, model saving, and inference.
34 | - Every project ends with an inference section that uses the trained model, so the code can be dropped into other projects; it is very practical.
35 | - The small optimizations in each project are explained in detailed comments, but they are not guaranteed to be correct or effective. Treat them as a reference, draw your own conclusions, and feel free to open an issue to discuss any ideas.
36 | - **The projects are continuously updated and improved; stay tuned!**
37 | 
38 | If you find this repository helpful, don't forget to leave a star 🌟 and share your feedback.
39 | If you have any questions, open an issue and the author will reply as soon as possible. 💬
40 | 
41 | ### Project Structure 📂
42 | 
43 | PyTorch
44 | - datasets (downloaded datasets)
45 | - models (saved model weights)
46 | - 1_Handwritten_Digit_Recognition
47 | - 2_Cat_Dog_Image_Classification
48 | - 3_Transformer_Sentiment_Classification
49 | - readme.md
50 | 
51 | ⚠️ **Notes**
52 | - Some models may not perform particularly well. Feel free to explore on your own: change the architecture and hyperparameters for better results, or extend the project code with richer functionality.
53 | - To get better training results, the author kept trying different approaches while writing and training the models, but that process may not show up in the final code. You are encouraged to tweak parameters and try different architectures and methods yourself, to get a feel for how model training behaves.
54 | 
55 | ---
56 | 
57 | ## Requirements ⚙️
58 | 
59 | - Python 3.x
60 | - PyTorch deep learning framework
61 | - Jupyter Notebook for interactive development
62 | - Other dependencies
63 | 
64 | ---
65 | 
66 | ## 1. Handwritten Digit Recognition
67 | 
68 | 🖊️🔢
69 | 
70 | ### 🎯 Overview
71 | This project trains a simple convolutional neural network on the classic MNIST dataset and works up from recognizing a single handwritten digit to recognizing multiple digits in one image.
72 | 
73 | ### Demo
74 | 
75 | ![Handwritten digit recognition](./1_Handwritten_Digit_Recognition/data/demo1.png)
76 | Multi-digit recognition
77 | ![Handwritten digit recognition](./1_Handwritten_Digit_Recognition/data/demo2.png)
78 | 
79 | ### 💡 Ideas
80 | 
81 | - The segmentation algorithm is critical and directly affects the prediction results. With a good segmentation algorithm, a model that recognizes a single character is enough to recognize a whole group of characters (see the sketch after this list).
82 | - Multi-digit recognition already works. If you can find a dataset of handwritten operators (+ - × ÷ ...) and train on it as well, the model can recognize basic arithmetic expressions; write a function that evaluates the expression and you have a fancy "math problem image --> answer" model 😁!
83 | - With a dataset of handwritten letters you could train a simple OCR model (a Chinese-character dataset would be far too large, so let's skip that, haha).
84 | - This project can keep being optimized and iterated on; please experiment with it!
85 | 
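The multi-digit notebook splits the image with `opencv` before classifying each piece. As a rough illustration of that idea only (not the notebook's exact code), a contour-based segmentation could look like the sketch below; the image path, the trained `model`, the device, and the noise threshold are all hypothetical assumptions.

```python
# Rough sketch of contour-based digit segmentation (illustrative, not the notebook's exact code).
# Assumptions: dark digits on a light background, and a trained MNIST-style `model`
# that takes a 1x1x28x28 float tensor in [0, 1]; all names here are hypothetical.
import cv2
import torch

def segment_and_recognize(image_path, model, device="cpu"):
    gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    # Otsu threshold with inversion: digits become white (255) on black, like MNIST.
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Sort bounding boxes left to right so the digits come out in reading order.
    boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda b: b[0])

    digits = []
    for x, y, w, h in boxes:
        if w * h < 50:  # skip tiny specks of noise (threshold is arbitrary)
            continue
        crop = cv2.resize(binary[y:y + h, x:x + w], (20, 20))
        crop = cv2.copyMakeBorder(crop, 4, 4, 4, 4, cv2.BORDER_CONSTANT, value=0)  # pad to 28x28
        tensor = torch.from_numpy(crop).float().div(255.0).unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            digits.append(str(model(tensor).argmax(dim=1).item()))
    return "".join(digits)
```

A smarter splitting step (merging boxes that belong to the same digit, handling touching strokes, filtering noise adaptively) is exactly where this pipeline can be improved.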
86 | ### Going Further
87 | [K-12 handwriting (HME100K) dataset](https://ai.100tal.com/dataset)
88 | Use this dataset to build an image --> corresponding markdown formula pipeline.
89 | 
90 | 
91 | ---
92 | 
93 | ## 2. Cat & Dog Image Classification
94 | 
95 | 🐱🐶
96 | 
97 | ### 🎯 Overview
98 | 
99 | This project uses the [Kaggle Cats and Dogs dataset](https://www.microsoft.com/en-us/download/details.aspx?id=54765) to build and train a convolutional neural network in PyTorch that classifies images of cats and dogs.
100 | 
101 | 
102 | ### Demo
103 | 
104 | ![Cat & dog image classification](./2_Cat_Dog_Image_Classification/data/demo.png)
105 | 
106 | 
107 | ### Highlights ✨
108 | - Compared with the handwritten digit recognition project, this project builds a noticeably more complex CNN.
109 | - The project readme records the training results obtained under different training settings and discusses how dataset normalization affects them.
110 | 
111 | ### Open Questions 🔬
112 | - After adding data augmentation, training became noticeably slower (average epoch time went from under 5 min to about 10 min) even though the dataset size and sample count did not change; the cause is unclear. ✗
113 | - Would normalizing with statistics computed from this dataset itself give better results? ✗
114 | 
115 | ---
116 | 
117 | ## 3. Transformer-Based Sentiment Classification
118 | 
119 | 🎬🗣
120 | 
121 | ### 🎯 Overview
122 | 
123 | This project uses a dataset of Chinese movie reviews crawled from Douban and implements a Transformer model in PyTorch for binary classification (positive vs. negative sentiment).
124 | The Chinese text is preprocessed with the bert-base-chinese tokenizer, and the Transformer architecture captures the semantics of each review, aiming for efficient and accurate sentiment classification.
125 | 
126 | 
127 | ### Demo
128 | ![Transformer sentiment classification](./3_Transformer_Sentiment_Classification/data/demo.png)
129 | 
130 | ### Highlights ✨
131 | - **Large self-built dataset**: collected by the author, with 500k+ high-quality Chinese movie reviews.
132 | - **Custom Transformer model**: a TransformerClassifier designed from scratch, with learnable positional encodings, a multi-layer Transformer encoder, and a classification head, tailored to Chinese sentiment analysis.
133 | 
134 | ### Open Questions 🔬
135 | - **Data processing:** the review data is handled fairly crudely here; other Chinese text preprocessing methods are worth exploring.
136 | - **Semantic analysis:** extract the Transformer attention weights, analyze which words contribute most to the sentiment decision, and build a sentiment keyword cloud.
137 | 
138 | 
139 | ---
140 | 
141 | ## 4. GAN Image Generation
142 | 
143 | 🎨🖼️
144 | 
145 | ### 🎯 Overview
146 | 
147 | This project trains a generative adversarial network (GAN) on the MNIST dataset to generate images of handwritten digits.
148 | 
149 | ### Demo
150 | 
151 | ![demo.png](./4_GAN_Image_Generator/data/demo.png)
152 | 
153 | ### Highlights ✨
154 | - Class labels are fed to the generator so it can produce images of a requested class, which should carry over directly to other image generation tasks.
155 | - Training stability is improved with a `Wasserstein GAN with Gradient Penalty (WGAN-GP)` loss and `Spectral Normalization` (see the gradient-penalty sketch below).
156 | 
157 | ### Open Questions 🔬
158 | - There is still room to improve the quality of the generated images.
159 | - Can the code be ported to other image generation tasks?
160 | 
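As a reminder of how the WGAN-GP term mentioned above works, here is a minimal sketch (not the notebook's exact code); `D` is assumed to be the critic, `real`/`fake` are image batches of the same shape, and the class labels used by the conditional generator are omitted for brevity.

```python
# Minimal WGAN-GP gradient penalty sketch (illustrative, not the notebook's exact code).
import torch

def gradient_penalty(D, real, fake, lambda_gp=10.0):
    batch_size = real.size(0)
    # Random interpolation between real and fake samples.
    eps = torch.rand(batch_size, 1, 1, 1, device=real.device)
    interpolated = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    scores = D(interpolated)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    grads = grads.view(batch_size, -1)
    # Penalize the critic when its gradient norm moves away from 1.
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()

# The critic loss would then be roughly:
# d_loss = D(fake).mean() - D(real).mean() + gradient_penalty(D, real, fake)
```

Spectral normalization is separate from this loss term: in PyTorch it can be applied by wrapping the critic's layers with `torch.nn.utils.spectral_norm(...)`.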
161 | ---
162 | 
163 | ## 5. Reinforcement Learning Snake
164 | 
165 | 🐍🎮
166 | ### 🎯 Overview
167 | 
168 | This project builds an agent for the classic Snake game, implemented in PyTorch with deep Q-learning (DQN).
169 | Snake itself is a game everyone is bored of, but training your own AI to play it is a lot of fun. Try changing the game logic and the reward scheme to make your AI snake smarter!
170 | 
171 | ### Demo
172 | 
173 | ![RL Snake](./5_RL_Snake/data/AI_Snake.gif)
174 | 
175 | ### Highlights ✨
176 | - A fun, approachable starter project for reinforcement learning. 😋
177 | - Uses the DQN algorithm with both an online policy network and a target network to improve training stability.
178 | - Implements experience replay and an $\epsilon$-greedy policy to balance exploration and exploitation.
179 | - Model weights are saved and loaded automatically, so training can resume from checkpoints.
180 | 
181 | ---
182 | 
183 | ## 6. Adversarial Attacks
184 | 
185 | 👊🤖🔥
186 | ### 🎯 Overview
187 | 
188 | An adversarial attack applies tiny, barely perceptible perturbations to the input so that the model produces an incorrect prediction or classification.
189 | This project takes the model from Project 2, [Cat & Dog Image Classification](./2_Cat_Dog_Image_Classification/), generates adversarial examples with the Fast Gradient Sign Method (FGSM) and Projected Gradient Descent (PGD), and analyzes the model's robustness under different perturbation strengths (a minimal FGSM sketch follows below).
190 | 
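To make the idea concrete, here is a minimal FGSM sketch (illustrative, not the notebook's exact code); the trained classifier `model`, the batch `images` with values in [0, 1], the integer `labels`, and the budget `epsilon` are all assumed to exist.

```python
# Minimal FGSM sketch (illustrative, not the notebook's exact code).
import torch
import torch.nn.functional as F

def fgsm_attack(model, images, labels, epsilon):
    # model is assumed to be in eval mode; images are expected to lie in [0, 1].
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), labels)
    loss.backward()
    # Step in the direction that increases the loss, then clamp back to a valid image range.
    adv_images = images + epsilon * images.grad.sign()
    return adv_images.clamp(0, 1).detach()
```

PGD is essentially the iterated version of the same step: it takes several small signed-gradient steps and, after each one, projects the perturbed image back into the ϵ-ball around the original input.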
191 | ### Demo
192 | #### FGSM Batch Attack Analysis
193 | 
194 | ![Accuracy vs. epsilon curve](./6_Adversarial_Attack/data/accuracy_vs_epsilon_plot.png)
195 | 
196 | Accuracy curve
197 | 
198 | ![FGSM adversarial example grid](./6_Adversarial_Attack/data/fgsm_examples_grid.png)
199 | 
200 | Grid of FGSM adversarial examples
201 | 
202 | 
203 | 
204 | #### PGD Single-Image Attack on Real Photos
205 | ![Adversarial attack](./6_Adversarial_Attack/data/demo.jpg)
206 | 
207 | 
208 | ### Highlights ✨
209 | - Shows that model accuracy drops sharply as the perturbation strength (ϵ) grows, confirming how vulnerable deep learning models are to adversarial attacks.
210 | - Generates adversarial examples for real-world cat and dog photos that are visually almost indistinguishable from the originals.
211 | - Side-by-side image comparisons make the tiny perturbations, and their large effect on the model's predictions, easy to see.
212 | 
213 | ### Open Questions 🔬
214 | - The prediction for the cat image did not change under the attack.
215 | - Analyzing the model's prediction confidence would give a better picture of how effective the attack is.
216 | 
217 | 
218 | ---
219 | 
220 | 
221 | ## ✨ More Projects Coming Soon...
222 | - Chinese-English translation: [Transformer-based Chinese-English translation hands-on project](https://www.heywhale.com/mw/project/614314778447b80017694844)
223 | - Speech recognition
224 | - Auto-regressive + diffusion image generation
225 | - YOLO object detection
226 | - AlphaGo-style Go
227 | - Dialogue systems
228 | - More reinforcement learning practice
229 | 
230 | 
231 | Happy Coding! 😄
--------------------------------------------------------------------------------