"
2545 | ],
2546 | "text/html": [
2547 | "\n",
2548 | " \n",
2549 | " \n",
2550 | "
\n",
2551 | " [100/100 00:38, Epoch 2/2]\n",
2552 | "
\n",
2553 | " \n",
2554 | " \n",
2555 | " \n",
2556 | " Step | \n",
2557 | " Training Loss | \n",
2558 | "
\n",
2559 | " \n",
2560 | " \n",
2561 | " \n",
2562 | "
"
2563 | ]
2564 | },
2565 | "metadata": {}
2566 | },
2567 | {
2568 | "output_type": "stream",
2569 | "name": "stderr",
2570 | "text": [
2571 | "Saving model checkpoint to ./checkpoint-100\n",
2572 | "\n",
2573 | "\n",
2574 | "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
2575 | "\n",
2576 | "\n",
2577 | "***** Running Evaluation *****\n",
2578 | " Num examples = 50\n",
2579 | " Batch size = 8\n"
2580 | ]
2581 | },
2582 | {
2583 | "output_type": "display_data",
2584 | "data": {
2585 | "text/plain": [
2586 | ""
2587 | ],
2588 | "text/html": [
2589 | "\n",
2590 | " \n",
2591 | " \n",
2592 | "
\n",
2593 | " [7/7 00:00]\n",
2594 | "
\n",
2595 | " "
2596 | ]
2597 | },
2598 | "metadata": {}
2599 | },
2600 | {
2601 | "output_type": "execute_result",
2602 | "data": {
2603 | "text/plain": [
2604 | "{'epoch': 2.0,\n",
2605 | " 'eval_accuracy': 0.34,\n",
2606 | " 'eval_f1': 0.34,\n",
2607 | " 'eval_loss': 1.1039652824401855,\n",
2608 | " 'eval_precision': 0.34,\n",
2609 | " 'eval_recall': 0.34,\n",
2610 | " 'eval_runtime': 0.8177,\n",
2611 | " 'eval_samples_per_second': 61.144,\n",
2612 | " 'eval_steps_per_second': 8.56}"
2613 | ]
2614 | },
2615 | "metadata": {},
2616 | "execution_count": 24
2617 | }
2618 | ]
2619 | },
2620 | {
2621 | "cell_type": "markdown",
2622 | "source": [
2623 | "## Predict with the fine-tuned model"
2624 | ],
2625 | "metadata": {
2626 | "id": "Z0n_haQsj5x4"
2627 | }
2628 | },
2629 | {
2630 | "cell_type": "code",
2631 | "source": [
2632 | "pred = classification_model(torch.tensor(test_dataset[\"input_ids\"][0:10]))\n",
2633 | "pred"
2634 | ],
2635 | "metadata": {
2636 | "colab": {
2637 | "base_uri": "https://localhost:8080/"
2638 | },
2639 | "id": "64vesgQ5j87g",
2640 | "outputId": "431af27c-073d-4da1-dc07-cf7071330e2c"
2641 | },
2642 | "execution_count": null,
2643 | "outputs": [
2644 | {
2645 | "output_type": "execute_result",
2646 | "data": {
2647 | "text/plain": [
2648 | "SequenceClassifierOutput([('logits', tensor([[-0.0239, 0.0068, -0.3391],\n",
2649 | " [ 0.0102, -0.0229, -0.2956],\n",
2650 | " [ 0.0431, 0.0213, -0.2269],\n",
2651 | " [ 0.0300, -0.0063, -0.2523],\n",
2652 | " [ 0.0370, -0.0021, -0.2569],\n",
2653 | " [ 0.0132, -0.0017, -0.2976],\n",
2654 | " [ 0.0859, 0.0508, -0.1927],\n",
2655 | " [-0.0331, -0.0096, -0.3775],\n",
2656 | " [ 0.0159, 0.0032, -0.3178],\n",
2657 | " [ 0.0517, 0.0275, -0.2442]], grad_fn=))])"
2658 | ]
2659 | },
2660 | "metadata": {},
2661 | "execution_count": 25
2662 | }
2663 | ]
2664 | },
2665 | {
2666 | "cell_type": "markdown",
2667 | "source": [
2668 | "## Load fine tuned model"
2669 | ],
2670 | "metadata": {
2671 | "id": "mV6mP3PbhQOe"
2672 | }
2673 | },
2674 | {
2675 | "cell_type": "code",
2676 | "source": [
2677 | "classification_model = model.from_pretrained(\"./checkpoint-100\")\n",
2678 | "classification_model.config.max_length = 256"
2679 | ],
2680 | "metadata": {
2681 | "id": "ZFn6m4NAhP13"
2682 | },
2683 | "execution_count": null,
2684 | "outputs": []
2685 | },
2686 | {
2687 | "cell_type": "markdown",
2688 | "source": [
2689 | "## logits to labels\n",
2690 | "\n",
2691 | "+ entailment (0)\n",
2692 | "+ neutral (1)\n",
2693 | "+ contradiction (2)"
2694 | ],
2695 | "metadata": {
2696 | "id": "BN0Ib25hj0mm"
2697 | }
2698 | },
2699 | {
2700 | "cell_type": "code",
2701 | "source": [
2702 | "pred = pred[0].detach().numpy().tolist()\n",
2703 | "pred = [row.index(max(row)) for row in pred]\n",
2704 | "pred, test_dataset[\"label\"][0:10]"
2705 | ],
2706 | "metadata": {
2707 | "colab": {
2708 | "base_uri": "https://localhost:8080/"
2709 | },
2710 | "id": "q4Ck_M1kjF5H",
2711 | "outputId": "d2788a03-36e7-4000-88be-b5b9c7c758ad"
2712 | },
2713 | "execution_count": null,
2714 | "outputs": [
2715 | {
2716 | "output_type": "execute_result",
2717 | "data": {
2718 | "text/plain": [
2719 | "([1, 0, 0, 0, 0, 0, 0, 1, 0, 0], [1, 2, 0, 2, 2, 2, 2, 1, 2, 1])"
2720 | ]
2721 | },
2722 | "metadata": {},
2723 | "execution_count": 27
2724 | }
2725 | ]
2726 | },
2727 | {
2728 | "cell_type": "markdown",
2729 | "source": [
2730 | "# 生成モデル\n",
2731 | "\n",
2732 | "入力文に含意な文を生成するモデルを作成する"
2733 | ],
2734 | "metadata": {
2735 | "id": "Dzy2lRHlTU_z"
2736 | }
2737 | },
2738 | {
2739 | "cell_type": "code",
2740 | "source": [
2741 | "def compute_metrics(pred):\n",
2742 | " labels = pred.label_ids\n",
2743 | " preds = pred.predictions.argmax(-1)\n",
2744 | " precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')\n",
2745 | " acc = accuracy_score(labels, preds)\n",
2746 | " return {\n",
2747 | " 'accuracy': acc,\n",
2748 | " 'f1': f1,\n",
2749 | " 'precision': precision,\n",
2750 | " 'recall': recall\n",
2751 | " }"
2752 | ],
2753 | "metadata": {
2754 | "id": "rPYWLSKBWbgz"
2755 | },
2756 | "execution_count": null,
2757 | "outputs": []
2758 | },
2759 | {
2760 | "cell_type": "markdown",
2761 | "source": [
2762 | "## load tokenizer and Encoder Decoder Model"
2763 | ],
2764 | "metadata": {
2765 | "id": "xvGWidNEMhia"
2766 | }
2767 | },
2768 | {
2769 | "cell_type": "code",
2770 | "source": [
2771 | "tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')\n",
2772 | "generation_model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints"
2773 | ],
2774 | "metadata": {
2775 | "id": "xyi9WWJfWf1r",
2776 | "colab": {
2777 | "base_uri": "https://localhost:8080/"
2778 | },
2779 | "outputId": "293b5398-0abf-4c35-8a4f-1761d306c116"
2780 | },
2781 | "execution_count": null,
2782 | "outputs": [
2783 | {
2784 | "output_type": "stream",
2785 | "name": "stderr",
2786 | "text": [
2787 | "loading file https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt from cache at /root/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99\n",
2788 | "loading file https://huggingface.co/bert-base-uncased/resolve/main/tokenizer.json from cache at /root/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4\n",
2789 | "loading file https://huggingface.co/bert-base-uncased/resolve/main/added_tokens.json from cache at None\n",
2790 | "loading file https://huggingface.co/bert-base-uncased/resolve/main/special_tokens_map.json from cache at None\n",
2791 | "loading file https://huggingface.co/bert-base-uncased/resolve/main/tokenizer_config.json from cache at /root/.cache/huggingface/transformers/c1d7f0a763fb63861cc08553866f1fc3e5a6f4f07621be277452d26d71303b7e.20430bd8e10ef77a7d2977accefe796051e01bc2fc4aa146bc862997a1a15e79\n",
2792 | "Initializing bert-base-uncased as a decoder model. Cross attention layers are added to bert-base-uncased and randomly initialized if bert-base-uncased's architecture allows for cross attention layers.\n",
2793 | "Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config\n"
2794 | ]
2795 | }
2796 | ]
2797 | },
2798 | {
2799 | "cell_type": "markdown",
2800 | "source": [
2801 | ""
2802 | ],
2803 | "metadata": {
2804 | "id": "44370c_wMnm6"
2805 | }
2806 | },
2807 | {
2808 | "cell_type": "markdown",
2809 | "source": [
2810 | "## Preprocess of data\n",
2811 | "\n",
2812 | "+ データセットから含意ラベルのデータだけを抽出\n",
2813 | "+ 使用しないカラムを削除\n",
2814 | "+ データセットのカラム名を変更"
2815 | ],
2816 | "metadata": {
2817 | "id": "fFtGH7bvNHka"
2818 | }
2819 | },
2820 | {
2821 | "cell_type": "code",
2822 | "source": [
2823 | "raw_datasets = load_dataset(\"multi_nli\")\n",
2824 | "generation_datasets = raw_datasets.filter(lambda x:x[\"label\"]==1)\n",
2825 | "generation_datasets = generation_datasets.remove_columns(\n",
2826 | " [\"promptID\",\"pairID\",\"premise_binary_parse\",\"premise_parse\",\"hypothesis_binary_parse\", \"hypothesis_parse\",\"genre\", \"label\"]\n",
2827 | ")\n",
2828 | "generation_datasets = generation_datasets.rename_column(\"hypothesis\", \"input\")\n",
2829 | "generation_datasets = generation_datasets.rename_column(\"premise\", \"label\")\n",
2830 | "generation_datasets[\"train\"][0]"
2831 | ],
2832 | "metadata": {
2833 | "colab": {
2834 | "base_uri": "https://localhost:8080/",
2835 | "height": 214,
2836 | "referenced_widgets": [
2837 | "2200195c39054a4a95f553a822c3ca4a",
2838 | "0c676fd14879444abd54f70bf1a2326a",
2839 | "f255b11d880f4f958d57fce7293aaf6d",
2840 | "254562c035fc4d958587f456bcc7b8d2",
2841 | "e59046c976ab46f78596f4c630cc2263",
2842 | "685e506dd03643fb84459c0e02de6cbe",
2843 | "af7d8ecd0b2e4f9fa154596a909ddf1e",
2844 | "a632d87f60254c238d2d4db3cf1bcd03",
2845 | "acee3404b44e4aecac9dcbd36e058280",
2846 | "0c4b7491f39641ab86c49efa9151a71e",
2847 | "7404eb799d1a457d8e43f766d6c5e174",
2848 | "52a9cf4f584a4f1b8580a2bbd8759e57",
2849 | "54ad317700f3474092450fc7d32663d9",
2850 | "5263698f1ce34054be5d2d6a0a5bc76b",
2851 | "91c2e2811bad43ee8f055acca9bda414",
2852 | "ff9b6eda171546ea99c338da68b45ee6",
2853 | "d53e7f33fc57487a86cc2ebe66d2db57",
2854 | "844d2c9d36d24cec946d5150410009e5",
2855 | "27a2507f5b4545df8ca658ad762dbbc7",
2856 | "e5001084807c4831a22cc7b00edeb483",
2857 | "d174cc1c593243a6aaff0365bba33baf",
2858 | "8f5940e21b6344249acb11d0470d9422",
2859 | "7794f5a424ac4072aee88f31e09e1e43",
2860 | "d129033ce907443cb225703019488771",
2861 | "fc77432baff54b39901b868545422cba",
2862 | "9a2152ab441d4df99b1cd45660f9fd94",
2863 | "213266ab59fa4c2eb1bd825d173df309",
2864 | "35d662bc09134d9c9e0df17551f4d424",
2865 | "de7c9c356d464fb3919cb3963acd6681",
2866 | "fd1387ff2d4041dc89ee1be1cbd6f009",
2867 | "0a67ca03a2aa42a3992e230d7d74471b",
2868 | "0fe377abfcfe4c4ca8478fbe12cc73c1",
2869 | "1e34d874c513479292fdfff54cf3c635",
2870 | "f426ad93ecef450b989f545ac2f3f384",
2871 | "6706b8c4486e4041932a4fed66b03d73",
2872 | "003de4ed5051453cb621e0a220a96e43",
2873 | "c57c9843e9f3450e8e5049d7af3082bb",
2874 | "6e010421305144adba37845f17645440",
2875 | "00cb0b13b04f4d9084fccbff7be6df03",
2876 | "c9194f2ed9b648e49eb83256bbbf5c47",
2877 | "df4578b45bc64931a05b2fa1fd85ef6e",
2878 | "bb5dbbb1de40411ab7503d1262b3cde9",
2879 | "5f5d6b430d284e7c996fa98a115c96bb",
2880 | "e67a67b21c6c4985a4933554354bf11e"
2881 | ]
2882 | },
2883 | "id": "UXladez3WheU",
2884 | "outputId": "c11e300c-94c0-41ab-fc65-d48f9d519c21"
2885 | },
2886 | "execution_count": null,
2887 | "outputs": [
2888 | {
2889 | "output_type": "stream",
2890 | "name": "stderr",
2891 | "text": [
2892 | "Using custom data configuration default\n",
2893 | "Reusing dataset multi_nli (/root/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)\n"
2894 | ]
2895 | },
2896 | {
2897 | "output_type": "display_data",
2898 | "data": {
2899 | "text/plain": [
2900 | " 0%| | 0/3 [00:00, ?it/s]"
2901 | ],
2902 | "application/vnd.jupyter.widget-view+json": {
2903 | "version_major": 2,
2904 | "version_minor": 0,
2905 | "model_id": "2200195c39054a4a95f553a822c3ca4a"
2906 | }
2907 | },
2908 | "metadata": {}
2909 | },
2910 | {
2911 | "output_type": "display_data",
2912 | "data": {
2913 | "text/plain": [
2914 | " 0%| | 0/393 [00:00, ?ba/s]"
2915 | ],
2916 | "application/vnd.jupyter.widget-view+json": {
2917 | "version_major": 2,
2918 | "version_minor": 0,
2919 | "model_id": "52a9cf4f584a4f1b8580a2bbd8759e57"
2920 | }
2921 | },
2922 | "metadata": {}
2923 | },
2924 | {
2925 | "output_type": "display_data",
2926 | "data": {
2927 | "text/plain": [
2928 | " 0%| | 0/10 [00:00, ?ba/s]"
2929 | ],
2930 | "application/vnd.jupyter.widget-view+json": {
2931 | "version_major": 2,
2932 | "version_minor": 0,
2933 | "model_id": "7794f5a424ac4072aee88f31e09e1e43"
2934 | }
2935 | },
2936 | "metadata": {}
2937 | },
2938 | {
2939 | "output_type": "display_data",
2940 | "data": {
2941 | "text/plain": [
2942 | " 0%| | 0/10 [00:00, ?ba/s]"
2943 | ],
2944 | "application/vnd.jupyter.widget-view+json": {
2945 | "version_major": 2,
2946 | "version_minor": 0,
2947 | "model_id": "f426ad93ecef450b989f545ac2f3f384"
2948 | }
2949 | },
2950 | "metadata": {}
2951 | },
2952 | {
2953 | "output_type": "execute_result",
2954 | "data": {
2955 | "text/plain": [
2956 | "{'input': 'Product and geography are what make cream skimming work. ',\n",
2957 | " 'label': 'Conceptually cream skimming has two basic dimensions - product and geography.'}"
2958 | ]
2959 | },
2960 | "metadata": {},
2961 | "execution_count": 30
2962 | }
2963 | ]
2964 | },
2965 | {
2966 | "cell_type": "markdown",
2967 | "source": [
2968 | "## Classification Modelとは違いpandasを使ってtraining dataを作ります\n",
2969 | "\n",
2970 | "+ データ数が多いと学習が終わらないので,ランダムに100個サンプリング"
2971 | ],
2972 | "metadata": {
2973 | "id": "bGNWB5fwNXVi"
2974 | }
2975 | },
2976 | {
2977 | "cell_type": "code",
2978 | "source": [
2979 | "df_dataset = pd.DataFrame({\n",
2980 | " \"inputs\":generation_datasets[\"train\"][\"input\"],\n",
2981 | " \"label\":generation_datasets[\"train\"][\"label\"]\n",
2982 | "})\n",
2983 | "df_dataset = df_dataset.sample(100).reset_index(drop=True)\n",
2984 | "df_dataset.head(1)"
2985 | ],
2986 | "metadata": {
2987 | "colab": {
2988 | "base_uri": "https://localhost:8080/",
2989 | "height": 81
2990 | },
2991 | "id": "M-MJLX3m7YXI",
2992 | "outputId": "dd5ce11d-bedd-4011-d2c3-6c0474f4054d"
2993 | },
2994 | "execution_count": null,
2995 | "outputs": [
2996 | {
2997 | "output_type": "execute_result",
2998 | "data": {
2999 | "text/plain": [
3000 | " inputs \\\n",
3001 | "0 Stevens was a talkative guy, and many couldn't... \n",
3002 | "\n",
3003 | " label \n",
3004 | "0 You Stevens shut your trap! Muller's roar brou... "
3005 | ],
3006 | "text/html": [
3007 | "\n",
3008 | " \n",
3009 | "
\n",
3010 | "
\n",
3011 | "\n",
3024 | "
\n",
3025 | " \n",
3026 | " \n",
3027 | " | \n",
3028 | " inputs | \n",
3029 | " label | \n",
3030 | "
\n",
3031 | " \n",
3032 | " \n",
3033 | " \n",
3034 | " 0 | \n",
3035 | " Stevens was a talkative guy, and many couldn't... | \n",
3036 | " You Stevens shut your trap! Muller's roar brou... | \n",
3037 | "
\n",
3038 | " \n",
3039 | "
\n",
3040 | "
\n",
3041 | "
\n",
3051 | " \n",
3052 | " \n",
3089 | "\n",
3090 | " \n",
3114 | "
\n",
3115 | "
\n",
3116 | " "
3117 | ]
3118 | },
3119 | "metadata": {},
3120 | "execution_count": 31
3121 | }
3122 | ]
3123 | },
3124 | {
3125 | "cell_type": "markdown",
3126 | "source": [
3127 | "## 入力テキストと出力ラベル(文)をそれぞれencodeして学習,評価データを作成"
3128 | ],
3129 | "metadata": {
3130 | "id": "cUYHWt6eNnMD"
3131 | }
3132 | },
3133 | {
3134 | "cell_type": "code",
3135 | "source": [
3136 | "inputs = tokenizer.batch_encode_plus(\n",
3137 | " df_dataset[\"inputs\"].tolist(),\n",
3138 | " return_tensors=\"pt\", \n",
3139 | " add_special_tokens=False,\n",
3140 | " truncation=True,\n",
3141 | " padding=\"max_length\",\n",
3142 | " max_length=256\n",
3143 | " )\n",
3144 | "labels = tokenizer.batch_encode_plus(\n",
3145 | " df_dataset[\"label\"].tolist(),\n",
3146 | " return_tensors=\"pt\", \n",
3147 | " add_special_tokens=True,\n",
3148 | " truncation=True,\n",
3149 | " padding=\"max_length\",\n",
3150 | " max_length=256\n",
3151 | " )\n",
3152 | "train_data = []\n",
3153 | "for i in range(len(inputs[\"input_ids\"])):\n",
3154 | " train_data.append(\n",
3155 | " {\n",
3156 | " \"input_ids\":inputs[\"input_ids\"][i],\n",
3157 | " \"token_type_ids\":inputs[\"token_type_ids\"][i],\n",
3158 | " \"attention_mask\":inputs[\"attention_mask\"][i],\n",
3159 | " \"label\":labels[\"input_ids\"][i] \n",
3160 | " }\n",
3161 | " )\n",
3162 | "random.shuffle(train_data)\n",
3163 | "train_size = int(len(train_data)*0.98)\n",
3164 | "eval_data, train_data = train_data[train_size:], train_data[:train_size]  # keep the splits disjoint: previously eval_data remained inside train_data, leaking eval samples into training"
3165 | ],
3166 | "metadata": {
3167 | "id": "ZKHuFNey7YTG"
3168 | },
3169 | "execution_count": null,
3170 | "outputs": []
3171 | },
3172 | {
3173 | "cell_type": "markdown",
3174 | "source": [
3175 | "## model configuration"
3176 | ],
3177 | "metadata": {
3178 | "id": "KcMaW-p5Nw-D"
3179 | }
3180 | },
3181 | {
3182 | "cell_type": "code",
3183 | "source": [
3184 | "generation_model.config.decoder_start_token_id = tokenizer.cls_token_id\n",
3185 | "generation_model.config.eos_token_id = tokenizer.sep_token_id\n",
3186 | "generation_model.config.pad_token_id = tokenizer.pad_token_id\n",
3187 | "# sensible parameters for beam search\n",
3188 | "generation_model.config.vocab_size = generation_model.config.decoder.vocab_size\n",
3189 | "generation_model.config.max_length = 100\n",
3190 | "generation_model.config.min_length = 20\n",
3191 | "generation_model.config.no_repeat_ngram_size = 1\n",
3192 | "generation_model.config.early_stopping = True\n",
3193 | "generation_model.config.length_penalty = 2.0\n",
3194 | "generation_model.config.num_beams = 20\n"
3195 | ],
3196 | "metadata": {
3197 | "id": "L5seOKthX74U"
3198 | },
3199 | "execution_count": null,
3200 | "outputs": []
3201 | },
3202 | {
3203 | "cell_type": "markdown",
3204 | "source": [
3205 | "## Training\n",
3206 | "\n",
3207 | "+ 生成モデルはSeq2SeqTrainerを使う\n",
3208 | "+ 自分でbackward()の処理を書いても可"
3209 | ],
3210 | "metadata": {
3211 | "id": "0adljFa_N0cr"
3212 | }
3213 | },
3214 | {
3215 | "cell_type": "code",
3216 | "source": [
3217 | "# Train Param\n",
3218 | "batch_size = 8\n",
3219 | "generation_model.train()\n",
3220 | "# https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments\n",
3221 | "training_args = Seq2SeqTrainingArguments(\n",
3222 | " output_dir='./',\n",
3223 | " evaluation_strategy=\"steps\",\n",
3224 | " per_device_train_batch_size=batch_size,\n",
3225 | " per_device_eval_batch_size=batch_size,\n",
3226 | " predict_with_generate=True,\n",
3227 | " logging_steps=10,\n",
3228 | " save_steps=30,\n",
3229 | " eval_steps=10,  # was 5000: with only ~39 total optimization steps, evaluation never ran despite evaluation_strategy=\"steps\"\n",
3230 | " warmup_steps=4,  # was 1000: a warmup longer than the whole run kept the LR near zero, so the model barely trained (loss stuck ~10.6)\n",
3231 | " overwrite_output_dir=True,\n",
3232 | " save_total_limit=5,\n",
3233 | " fp16=False,\n",
3234 | " num_train_epochs=3,\n",
3235 | " no_cuda=not CUDA_AVAILABLE\n",
3236 | ")\n",
3237 | "\n",
3238 | "# instantiate trainer\n",
3239 | "trainer = Seq2SeqTrainer(\n",
3240 | " model=generation_model,\n",
3241 | " tokenizer=tokenizer,\n",
3242 | " args=training_args,\n",
3243 | " train_dataset=train_data,\n",
3244 | " eval_dataset=eval_data\n",
3245 | ")\n",
3246 | "trainer.train()"
3247 | ],
3248 | "metadata": {
3249 | "colab": {
3250 | "base_uri": "https://localhost:8080/",
3251 | "height": 529
3252 | },
3253 | "id": "rGEGsIfUVbmj",
3254 | "outputId": "53508fa9-926e-437a-d669-98853a52955d"
3255 | },
3256 | "execution_count": null,
3257 | "outputs": [
3258 | {
3259 | "output_type": "stream",
3260 | "name": "stderr",
3261 | "text": [
3262 | "PyTorch: setting up devices\n",
3263 | "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n",
3264 | "/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
3265 | " FutureWarning,\n",
3266 | "***** Running training *****\n",
3267 | " Num examples = 100\n",
3268 | " Num Epochs = 3\n",
3269 | " Instantaneous batch size per device = 8\n",
3270 | " Total train batch size (w. parallel, distributed & accumulation) = 8\n",
3271 | " Gradient Accumulation steps = 1\n",
3272 | " Total optimization steps = 39\n",
3273 | "The following columns in the training set don't have a corresponding argument in `EncoderDecoderModel.forward` and have been ignored: token_type_ids. If token_type_ids are not expected by `EncoderDecoderModel.forward`, you can safely ignore this message.\n",
3274 | "/usr/local/lib/python3.7/dist-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py:532: FutureWarning: Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the labels, no need to pass them yourself anymore.\n",
3275 | " warnings.warn(DEPRECATION_WARNING, FutureWarning)\n"
3276 | ]
3277 | },
3278 | {
3279 | "output_type": "display_data",
3280 | "data": {
3281 | "text/plain": [
3282 | ""
3283 | ],
3284 | "text/html": [
3285 | "\n",
3286 | " \n",
3287 | " \n",
3288 | "
\n",
3289 | " [39/39 00:45, Epoch 3/3]\n",
3290 | "
\n",
3291 | " \n",
3292 | " \n",
3293 | " \n",
3294 | " Step | \n",
3295 | " Training Loss | \n",
3296 | " Validation Loss | \n",
3297 | "
\n",
3298 | " \n",
3299 | " \n",
3300 | " \n",
3301 | "
"
3302 | ]
3303 | },
3304 | "metadata": {}
3305 | },
3306 | {
3307 | "output_type": "stream",
3308 | "name": "stderr",
3309 | "text": [
3310 | "Saving model checkpoint to ./checkpoint-30\n",
3311 | "tokenizer config file saved in ./checkpoint-30/tokenizer_config.json\n",
3312 | "Special tokens file saved in ./checkpoint-30/special_tokens_map.json\n",
3313 | "/usr/local/lib/python3.7/dist-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py:532: FutureWarning: Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the labels, no need to pass them yourself anymore.\n",
3314 | " warnings.warn(DEPRECATION_WARNING, FutureWarning)\n",
3315 | "\n",
3316 | "\n",
3317 | "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
3318 | "\n",
3319 | "\n"
3320 | ]
3321 | },
3322 | {
3323 | "output_type": "execute_result",
3324 | "data": {
3325 | "text/plain": [
3326 | "TrainOutput(global_step=39, training_loss=10.61686765230619, metrics={'train_runtime': 46.3326, 'train_samples_per_second': 6.475, 'train_steps_per_second': 0.842, 'total_flos': 92018115072000.0, 'train_loss': 10.61686765230619, 'epoch': 3.0})"
3327 | ]
3328 | },
3329 | "metadata": {},
3330 | "execution_count": 34
3331 | }
3332 | ]
3333 | },
3334 | {
3335 | "cell_type": "markdown",
3336 | "source": [
3337 | "## Load the fine-tuned generation model"
3338 | ],
3339 | "metadata": {
3340 | "id": "S0JLo1oIODgT"
3341 | }
3342 | },
3343 | {
3344 | "cell_type": "code",
3345 | "source": [
3346 | "created_model = EncoderDecoderModel.from_pretrained(\"./checkpoint-30\")  # from_pretrained is a classmethod; call it on the class rather than through an instance"
3347 | ],
3348 | "metadata": {
3349 | "id": "X3fpsZJo-05X"
3350 | },
3351 | "execution_count": null,
3352 | "outputs": []
3353 | },
3354 | {
3355 | "cell_type": "markdown",
3356 | "source": [
3357 | "## Generate entailment sentence"
3358 | ],
3359 | "metadata": {
3360 | "id": "Izd6o18xOFtb"
3361 | }
3362 | },
3363 | {
3364 | "cell_type": "code",
3365 | "source": [
3366 | "tokenized = tokenizer(df_dataset[\"inputs\"][0], return_tensors=\"pt\", truncation=True, padding=True, max_length=256)\n",
3367 | "pred = created_model.generate(tokenized[\"input_ids\"])\n",
3368 | "pred"
3369 | ],
3370 | "metadata": {
3371 | "colab": {
3372 | "base_uri": "https://localhost:8080/"
3373 | },
3374 | "id": "yhFMvdOw_UfI",
3375 | "outputId": "2ad95288-bad5-4961-9139-e0e2b64d4ae8"
3376 | },
3377 | "execution_count": null,
3378 | "outputs": [
3379 | {
3380 | "output_type": "execute_result",
3381 | "data": {
3382 | "text/plain": [
3383 | "tensor([[ 101, 1996, 1012, 1025, 999, 1010, 1585, 30112, 30114, 1584,\n",
3384 | " 1586, 30111, 27876, 1583, 1587, 30132, 30130, 30129, 30131, 1141,\n",
3385 | " 1536, 25292, 1592, 30113, 1591, 19174, 1064, 4414, 1621, 17928,\n",
3386 | " 3031, 1588, 28637, 8778, 1607, 11916, 20955, 2004, 3022, 2133,\n",
3387 | " 18880, 16302, 13811, 27362, 2000, 2830, 8848, 2091, 2067, 2101,\n",
3388 | " 2083, 2627, 11165, 24288, 29053, 29051, 1998, 5685, 26379, 2664,\n",
3389 | " 16808, 5743, 15834, 7652, 19442, 25430, 13366, 1510, 4125, 2368,\n",
3390 | " 24333, 12942, 2046, 10359, 22625, 25693, 17741, 3413, 5235, 4084,\n",
3391 | " 10024, 22953, 26864, 11563, 4063, 15454, 5441, 2663, 2062, 22302,\n",
3392 | " 5963, 3553, 20755, 13806, 13776, 2721, 10278, 7367, 2061, 5879]])"
3393 | ]
3394 | },
3395 | "metadata": {},
3396 | "execution_count": 36
3397 | }
3398 | ]
3399 | },
3400 | {
3401 | "cell_type": "markdown",
3402 | "source": [
3403 | "## Decode predicted tensors"
3404 | ],
3405 | "metadata": {
3406 | "id": "dHAPe4MqONG6"
3407 | }
3408 | },
3409 | {
3410 | "cell_type": "code",
3411 | "source": [
3412 | "df_dataset[\"inputs\"][0], tokenizer.decode(pred[0], skip_special_tokens=True)  # truncation/padding/max_length are encode-time options; decode() silently ignores them"
3413 | ],
3414 | "metadata": {
3415 | "colab": {
3416 | "base_uri": "https://localhost:8080/"
3417 | },
3418 | "id": "fiiERZY3AK23",
3419 | "outputId": "6ec40ff3-c3e3-4ca6-d1dc-a487671c20f4"
3420 | },
3421 | "execution_count": null,
3422 | "outputs": [
3423 | {
3424 | "output_type": "execute_result",
3425 | "data": {
3426 | "text/plain": [
3427 | "(\"Stevens was a talkative guy, and many couldn't stand him.\",\n",
3428 | " 'the. ;!, →↑↓ ↑ ↓←missive ← ↔∪∨∧∩ ʲ ⁰ vis ∂→ ⇒ nanny | respectively ☆ api onto ↦gree protocol ≡ency hartley asas... pasadenacare dominancerath to forward backward down back later through past successive successively travers hays andandndt yet moor forth onwardwardbeat sw def ᵥ riseenfastlum intolike fidelity gorman mcmahon pass passes stepsbra bro barnetnderder reject maintain win more hotterback closeribelatelatedlalam se so norman')"
3429 | ]
3430 | },
3431 | "metadata": {},
3432 | "execution_count": 37
3433 | }
3434 | ]
3435 | },
3436 | {
3437 | "cell_type": "markdown",
3438 | "source": [
3439 | "# Option"
3440 | ],
3441 | "metadata": {
3442 | "id": "JC87qnJeORy0"
3443 | }
3444 | },
3445 | {
3446 | "cell_type": "markdown",
3447 | "source": [
3448 | "+ \"pytorch_tutorial.ipynb\"を\"pytorch_tutorial.py\"に変換するコマンド\n",
3449 | "\n",
3450 | "```jupyter nbconvert --to script pytorch_tutorial.ipynb```\n",
3451 | "\n",
3452 | "+ 使うGPUを指定して実行する方法(os.environでも可)\n",
3453 | "+ 特にTrainerは見えているGPUを全て使うので指定してあげる必要がある.\n",
3454 | "+ GPU:0~2を使って実行したい場合\n",
3455 | "\n",
3456 | "```CUDA_VISIBLE_DEVICES=0,1,2 python pytorch_tutorial.py ```"
3457 | ],
3458 | "metadata": {
3459 | "id": "doolAe-7OXGj"
3460 | }
3461 | }
3462 | ]
3463 | }
--------------------------------------------------------------------------------