├── src ├── __init__.py ├── data │ ├── .gitkeep │ ├── __init__.py │ ├── download_datasets.py │ └── utils.py ├── contrastive │ ├── models │ │ ├── metrics.py │ │ ├── loss.py │ │ └── modeling.py │ ├── abtbuy │ │ ├── run_pretraining_ssv_roberta.sh │ │ ├── run_pretraining_roberta.sh │ │ ├── run_pretraining_clean_roberta.sh │ │ ├── run_finetune_siamese_frozen_nosplit_roberta.sh │ │ ├── run_finetune_ssv_siamese_frozen_roberta.sh │ │ ├── run_finetune_siamese_frozen_roberta.sh │ │ ├── run_finetune_siamese_unfrozen_nosplit_roberta.sh │ │ ├── run_finetune_siamese_unfrozen_roberta.sh │ │ └── run_finetune_ssv_siamese_unfrozen_roberta.sh │ ├── lspc │ │ ├── run_pretraining_roberta.sh │ │ ├── run_pretraining_ssv_roberta.sh │ │ ├── run_finetune_siamese_frozen_roberta.sh │ │ ├── run_finetune_ssv_siamese_frozen_roberta.sh │ │ ├── run_finetune_siamese_unfrozen_roberta.sh │ │ └── run_finetune_ssv_siamese_unfrozen_roberta.sh │ ├── amazongoogle │ │ ├── run_pretraining_ssv_roberta.sh │ │ ├── run_pretraining_clean_roberta.sh │ │ ├── run_pretraining_roberta.sh │ │ ├── run_finetune_siamese_frozen_nosplit_roberta.sh │ │ ├── run_finetune_ssv_siamese_frozen_roberta.sh │ │ ├── run_finetune_siamese_frozen_roberta.sh │ │ ├── run_finetune_siamese_unfrozen_nosplit_roberta.sh │ │ ├── run_finetune_siamese_unfrozen_roberta.sh │ │ └── run_finetune_ssv_siamese_unfrozen_roberta.sh │ ├── data │ │ ├── data_collators.py │ │ └── datasets.py │ ├── run_pretraining_deepmatcher.py │ ├── run_pretraining_deepmatcher_nosplit.py │ ├── run_pretraining_ssv.py │ ├── run_pretraining.py │ └── run_finetune_siamese.py └── processing │ ├── preprocess │ ├── preprocess_corpus.py │ ├── preprocess_ts_gs.py │ └── preprocess-deepmatcher-datasets.py │ └── contrastive │ ├── prepare-data.py │ └── prepare-data-deepmatcher.py ├── setup.py ├── LICENSE ├── .gitignore ├── README.md ├── contrastive-product-matching.yml └── requirements.txt /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='src', 5 | packages=find_packages(), 6 | version='0.1.0', 7 | description='Data Integration Research', 8 | author='Ralph Peeters', 9 | license='BSD-3', 10 | ) 11 | -------------------------------------------------------------------------------- /src/contrastive/models/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 3 | 4 | def compute_metrics_bce(eval_pred): 5 | logits, labels = eval_pred 6 | 7 | logits[logits>=0.5] = 1 8 | logits[logits<0.5] = 0 9 | predictions = logits.reshape(-1) 10 | labels = labels.reshape(-1) 11 | 12 | accuracy = accuracy_score(labels, predictions) 13 | f1 = f1_score(labels, predictions, pos_label=1, average='binary') 14 | precision = precision_score(labels, predictions, pos_label=1, average='binary') 15 | recall = recall_score(labels, 
predictions, pos_label=1, average='binary') 16 | 17 | return {"accuracy": accuracy, "f1": f1, "precision": precision, "recall": recall} -------------------------------------------------------------------------------- /src/contrastive/abtbuy/run_pretraining_ssv_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | python run_pretraining_ssv.py \ 12 | --do_train \ 13 | --dataset_name=abt-buy \ 14 | --train_file /your_path/contrastive-product-matching/data/processed/abt-buy/contrastive/abt-buy-train.pkl.gz \ 15 | --id_deduction_set /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 16 | --tokenizer="roberta-base" \ 17 | --grad_checkpoint=True \ 18 | --output_dir /your_path/contrastive-product-matching/reports/contrastive/abtbuy-ssv-$AUG$BATCH-$LR-$TEMP-roberta-base/ \ 19 | --temperature=$TEMP \ 20 | --per_device_train_batch_size=$BATCH \ 21 | --learning_rate=$LR \ 22 | --weight_decay=0.01 \ 23 | --num_train_epochs=200 \ 24 | --lr_scheduler_type="linear" \ 25 | --warmup_ratio=0.05 \ 26 | --max_grad_norm=1.0 \ 27 | --fp16 \ 28 | --dataloader_num_workers=4 \ 29 | --disable_tqdm=True \ 30 | --save_strategy="epoch" \ 31 | --logging_strategy="epoch" \ 32 | --augment=$AUG \ -------------------------------------------------------------------------------- /src/processing/preprocess/preprocess_corpus.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | np.random.seed(42) 4 | import random 5 | random.seed(42) 6 | import os 7 | 8 | from src.data import utils 9 | 10 | if __name__ == '__main__': 11 | 12 | print('PREPROCESSING CORPUS') 13 | 14 | corpus = pd.read_json('../../../data/raw/wdc-lspc/corpus/offers_corpus_english_v2_non_norm.json.gz', lines=True) 15 | 16 | # preprocess english corpus 17 | 18 | print('BUILDING PREPROCESSED CORPUS...') 19 | corpus['title'] = corpus['title'].apply(utils.clean_string_wdcv2) 20 | corpus['description'] = corpus['description'].apply(utils.clean_string_wdcv2) 21 | corpus['brand'] = corpus['brand'].apply(utils.clean_string_wdcv2) 22 | corpus['price'] = corpus['price'].apply(utils.clean_string_wdcv2) 23 | corpus['specTableContent'] = corpus['specTableContent'].apply(utils.clean_specTableContent_wdcv2) 24 | 25 | os.makedirs(os.path.dirname('../../../data/interim/wdc-lspc/corpus/'), exist_ok=True) 26 | corpus.to_pickle('../../../data/interim/wdc-lspc/corpus/preprocessed_english_corpus.pkl.gz') 27 | print('FINISHED BUILDING PREPROCESSED CORPUS...') 28 | -------------------------------------------------------------------------------- /src/contrastive/lspc/run_pretraining_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | SIZE=$4 11 | AUG=$5 12 | python run_pretraining.py \ 13 | --do_train \ 14 | --train_file /your_path/contrastive-product-matching/data/processed/wdc-lspc/contrastive/pre-train/computers/computers_train_$SIZE.pkl.gz \ 15 | --id_deduction_set /your_path/contrastive-product-matching/data/raw/wdc-lspc/training-sets/computers_train_$SIZE.json.gz \ 16 | --tokenizer="roberta-base" \ 17 | 
--grad_checkpoint=True \ 18 | --output_dir /your_path/contrastive-product-matching/reports/contrastive/computers-$SIZE-$AUG$BATCH-$LR-$TEMP-roberta-base/ \ 19 | --temperature=$TEMP \ 20 | --per_device_train_batch_size=$BATCH \ 21 | --learning_rate=$LR \ 22 | --weight_decay=0.01 \ 23 | --num_train_epochs=200 \ 24 | --lr_scheduler_type="linear" \ 25 | --warmup_ratio=0.05 \ 26 | --max_grad_norm=1.0 \ 27 | --fp16 \ 28 | --dataloader_num_workers=4 \ 29 | --disable_tqdm=True \ 30 | --save_strategy="epoch" \ 31 | --logging_strategy="epoch" \ 32 | --augment=$AUG \ -------------------------------------------------------------------------------- /src/contrastive/abtbuy/run_pretraining_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | python run_pretraining_deepmatcher_nosplit.py \ 12 | --do_train \ 13 | --dataset_name=abt-buy \ 14 | --clean=True \ 15 | --train_file /your_path/contrastive-product-matching/data/processed/abt-buy/contrastive/abt-buy-train.pkl.gz \ 16 | --id_deduction_set /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 17 | --tokenizer="roberta-base" \ 18 | --grad_checkpoint=True \ 19 | --output_dir /your_path/contrastive-product-matching/reports/contrastive/abtbuy-$AUG$BATCH-$LR-$TEMP-roberta-base/ \ 20 | --temperature=$TEMP \ 21 | --per_device_train_batch_size=$BATCH \ 22 | --learning_rate=$LR \ 23 | --weight_decay=0.01 \ 24 | --num_train_epochs=200 \ 25 | --lr_scheduler_type="linear" \ 26 | --warmup_ratio=0.05 \ 27 | --max_grad_norm=1.0 \ 28 | --fp16 \ 29 | --dataloader_num_workers=4 \ 30 | --disable_tqdm=True \ 31 | --save_strategy="epoch" \ 32 | --logging_strategy="epoch" \ 33 | --augment=$AUG \ -------------------------------------------------------------------------------- /src/contrastive/abtbuy/run_pretraining_clean_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | python run_pretraining_deepmatcher.py \ 12 | --do_train \ 13 | --dataset_name=abt-buy \ 14 | --clean=True \ 15 | --train_file /your_path/contrastive-product-matching/data/processed/abt-buy/contrastive/abt-buy-train.pkl.gz \ 16 | --id_deduction_set /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 17 | --tokenizer="roberta-base" \ 18 | --grad_checkpoint=True \ 19 | --output_dir /your_path/contrastive-product-matching/reports/contrastive/abtbuy-clean-$AUG$BATCH-$LR-$TEMP-roberta-base/ \ 20 | --temperature=$TEMP \ 21 | --per_device_train_batch_size=$BATCH \ 22 | --learning_rate=$LR \ 23 | --weight_decay=0.01 \ 24 | --num_train_epochs=5 \ 25 | --lr_scheduler_type="linear" \ 26 | --warmup_ratio=0.05 \ 27 | --max_grad_norm=1.0 \ 28 | --fp16 \ 29 | --dataloader_num_workers=4 \ 30 | --disable_tqdm=True \ 31 | --save_strategy="epoch" \ 32 | --logging_strategy="epoch" \ 33 | --augment=$AUG \ -------------------------------------------------------------------------------- /src/contrastive/lspc/run_pretraining_ssv_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | 
#SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | SIZE=$4 11 | AUG=$5 12 | python run_pretraining_ssv.py \ 13 | --do_train \ 14 | --train_file /your_path/contrastive-product-matching/data/processed/wdc-lspc/contrastive/pre-train/computers/computers_train_$SIZE.pkl.gz \ 15 | --id_deduction_set /your_path/contrastive-product-matching/data/raw/wdc-lspc/training-sets/computers_train_$SIZE.json.gz \ 16 | --tokenizer="roberta-base" \ 17 | --grad_checkpoint=True \ 18 | --output_dir /your_path/contrastive-product-matching/reports/contrastive/computers-ssv-$SIZE-$AUG$BATCH-$LR-$TEMP-roberta-base/ \ 19 | --temperature=$TEMP \ 20 | --per_device_train_batch_size=$BATCH \ 21 | --learning_rate=$LR \ 22 | --weight_decay=0.01 \ 23 | --num_train_epochs=200 \ 24 | --lr_scheduler_type="linear" \ 25 | --warmup_ratio=0.05 \ 26 | --max_grad_norm=1.0 \ 27 | --fp16 \ 28 | --dataloader_num_workers=4 \ 29 | --disable_tqdm=True \ 30 | --save_strategy="epoch" \ 31 | --logging_strategy="epoch" \ 32 | --augment=$AUG \ -------------------------------------------------------------------------------- /src/contrastive/amazongoogle/run_pretraining_ssv_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | python run_pretraining_ssv.py \ 12 | --do_train \ 13 | --dataset_name=amazon-google \ 14 | --train_file /your_path/contrastive-product-matching/data/processed/amazon-google/contrastive/amazon-google-train.pkl.gz \ 15 | --id_deduction_set /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 16 | --tokenizer="roberta-base" \ 17 | --grad_checkpoint=True \ 18 | --output_dir /your_path/contrastive-product-matching/reports/contrastive/amazongoogle-ssv-$AUG$BATCH-$LR-$TEMP-roberta-base/ \ 19 | --temperature=$TEMP \ 20 | --per_device_train_batch_size=$BATCH \ 21 | --learning_rate=$LR \ 22 | --weight_decay=0.01 \ 23 | --num_train_epochs=200 \ 24 | --lr_scheduler_type="linear" \ 25 | --warmup_ratio=0.05 \ 26 | --max_grad_norm=1.0 \ 27 | --fp16 \ 28 | --dataloader_num_workers=4 \ 29 | --disable_tqdm=True \ 30 | --save_strategy="epoch" \ 31 | --logging_strategy="epoch" \ 32 | --augment=$AUG \ -------------------------------------------------------------------------------- /src/contrastive/amazongoogle/run_pretraining_clean_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | python run_pretraining_deepmatcher.py \ 12 | --do_train \ 13 | --dataset_name=amazon-google \ 14 | --clean=True \ 15 | --train_file /your_path/contrastive-product-matching/data/processed/amazon-google/contrastive/amazon-google-train.pkl.gz \ 16 | --id_deduction_set /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 17 | --tokenizer="roberta-base" \ 18 | --grad_checkpoint=True \ 19 | --output_dir /your_path/contrastive-product-matching/reports/contrastive/amazongoogle-clean-$AUG$BATCH-$LR-$TEMP-roberta-base/ \ 20 | --temperature=$TEMP \ 21 | --per_device_train_batch_size=$BATCH \ 22 | --learning_rate=$LR \ 23 | --weight_decay=0.01 \ 24 | 
--num_train_epochs=200 \ 25 | --lr_scheduler_type="linear" \ 26 | --warmup_ratio=0.05 \ 27 | --max_grad_norm=1.0 \ 28 | --fp16 \ 29 | --dataloader_num_workers=4 \ 30 | --disable_tqdm=True \ 31 | --save_strategy="epoch" \ 32 | --logging_strategy="epoch" \ 33 | --augment=$AUG \ -------------------------------------------------------------------------------- /src/contrastive/amazongoogle/run_pretraining_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | python run_pretraining_deepmatcher_nosplit.py \ 12 | --do_train \ 13 | --dataset_name=amazon-google \ 14 | --clean=True \ 15 | --train_file /your_path/contrastive-product-matching/data/processed/amazon-google/contrastive/amazon-google-train.pkl.gz \ 16 | --id_deduction_set /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 17 | --tokenizer="roberta-base" \ 18 | --grad_checkpoint=True \ 19 | --output_dir /your_path/contrastive-product-matching/reports/contrastive/amazongoogle-$AUG$BATCH-$LR-$TEMP-roberta-base/ \ 20 | --temperature=$TEMP \ 21 | --per_device_train_batch_size=$BATCH \ 22 | --learning_rate=$LR \ 23 | --weight_decay=0.01 \ 24 | --num_train_epochs=200 \ 25 | --lr_scheduler_type="linear" \ 26 | --warmup_ratio=0.05 \ 27 | --max_grad_norm=1.0 \ 28 | --fp16 \ 29 | --dataloader_num_workers=4 \ 30 | --disable_tqdm=True \ 31 | --save_strategy="epoch" \ 32 | --logging_strategy="epoch" \ 33 | --augment=$AUG \ -------------------------------------------------------------------------------- /src/processing/contrastive/prepare-data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | np.random.seed(42) 4 | import random 5 | random.seed(42) 6 | 7 | from pathlib import Path 8 | import shutil 9 | 10 | if __name__ == '__main__': 11 | 12 | categories = ['computers'] 13 | train_sizes = ['small', 'medium', 'large', 'xlarge'] 14 | 15 | data = pd.read_pickle('../../../data/interim/wdc-lspc/corpus/preprocessed_english_corpus.pkl.gz') 16 | 17 | relevant_cols = ['id', 'cluster_id', 'brand', 'title', 'description', 'specTableContent'] 18 | 19 | for category in categories: 20 | out_path = f'../../../data/processed/wdc-lspc/contrastive/pre-train/{category}/' 21 | shutil.rmtree(out_path, ignore_errors=True) 22 | Path(out_path).mkdir(parents=True, exist_ok=True) 23 | for train_size in train_sizes: 24 | ids = pd.read_pickle(f'../../../data/interim/wdc-lspc/training-sets/preprocessed_{category}_train_{train_size}.pkl.gz') 25 | 26 | relevant_ids = set() 27 | relevant_ids.update(ids['id_left']) 28 | relevant_ids.update(ids['id_right']) 29 | 30 | data_selection = data[data['id'].isin(relevant_ids)] 31 | data_selection = data_selection[relevant_cols] 32 | data_selection = data_selection.reset_index(drop=True) 33 | data_selection.to_pickle(f'{out_path}{category}_train_{train_size}.pkl.gz') -------------------------------------------------------------------------------- /src/contrastive/abtbuy/run_finetune_siamese_frozen_nosplit_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | 
AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/abtbuy-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=abt-buy \ 16 | --train_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 17 | --validation_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 18 | --test_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-gs.json.gz \ 19 | --evaluation_strategy=epoch \ 20 | --tokenizer="roberta-base" \ 21 | --grad_checkpoint=False \ 22 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/abtbuy-$AUG$BATCH-$PREAUG$LR-$TEMP-frozen-roberta-base/ \ 23 | --per_device_train_batch_size=64 \ 24 | --learning_rate=5e-05 \ 25 | --weight_decay=0.01 \ 26 | --num_train_epochs=50 \ 27 | --lr_scheduler_type="linear" \ 28 | --warmup_ratio=0.05 \ 29 | --max_grad_norm=1.0 \ 30 | --fp16 \ 31 | --metric_for_best_model=loss \ 32 | --dataloader_num_workers=4 \ 33 | --disable_tqdm=True \ 34 | --save_strategy="epoch" \ 35 | --load_best_model_at_end \ 36 | --augment=$AUG \ 37 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/abtbuy/run_finetune_ssv_siamese_frozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/abtbuy-ssv-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=abt-buy \ 16 | --train_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 17 | --validation_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 18 | --test_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-gs.json.gz \ 19 | --evaluation_strategy=epoch \ 20 | --tokenizer="roberta-base" \ 21 | --grad_checkpoint=True \ 22 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/abtbuy-ssv-$AUG$BATCH-$PREAUG$LR-$TEMP-frozen-roberta-base/ \ 23 | --per_device_train_batch_size=64 \ 24 | --learning_rate=5e-05 \ 25 | --weight_decay=0.01 \ 26 | --num_train_epochs=50 \ 27 | --lr_scheduler_type="linear" \ 28 | --warmup_ratio=0.05 \ 29 | --max_grad_norm=1.0 \ 30 | --fp16 \ 31 | --metric_for_best_model=f1 \ 32 | --dataloader_num_workers=4 \ 33 | --disable_tqdm=True \ 34 | --save_strategy="epoch" \ 35 | --load_best_model_at_end \ 36 | --augment=$AUG \ 37 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/abtbuy/run_finetune_siamese_frozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint 
/your_path/contrastive-product-matching/reports/contrastive/abtbuy-clean-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=abt-buy \ 16 | --train_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 17 | --validation_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 18 | --test_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-gs.json.gz \ 19 | --evaluation_strategy=epoch \ 20 | --tokenizer="roberta-base" \ 21 | --grad_checkpoint=False \ 22 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/abtbuy-clean-$AUG$BATCH-$PREAUG$LR-$TEMP-frozen-roberta-base/ \ 23 | --per_device_train_batch_size=64 \ 24 | --learning_rate=5e-05 \ 25 | --weight_decay=0.01 \ 26 | --num_train_epochs=5 \ 27 | --lr_scheduler_type="linear" \ 28 | --warmup_ratio=0.05 \ 29 | --max_grad_norm=1.0 \ 30 | --fp16 \ 31 | --metric_for_best_model=loss \ 32 | --dataloader_num_workers=4 \ 33 | --disable_tqdm=True \ 34 | --save_strategy="epoch" \ 35 | --load_best_model_at_end \ 36 | --augment=$AUG \ 37 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/abtbuy/run_finetune_siamese_unfrozen_nosplit_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/abtbuy-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=abt-buy \ 16 | --frozen=False \ 17 | --train_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 18 | --validation_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 19 | --test_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-gs.json.gz \ 20 | --evaluation_strategy=epoch \ 21 | --tokenizer="roberta-base" \ 22 | --grad_checkpoint=False \ 23 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/abtbuy-$AUG$BATCH-$PREAUG$LR-$TEMP-unfrozen-roberta-base/ \ 24 | --per_device_train_batch_size=64 \ 25 | --learning_rate=5e-05 \ 26 | --weight_decay=0.01 \ 27 | --num_train_epochs=50 \ 28 | --lr_scheduler_type="linear" \ 29 | --warmup_ratio=0.05 \ 30 | --max_grad_norm=1.0 \ 31 | --fp16 \ 32 | --metric_for_best_model=loss \ 33 | --dataloader_num_workers=4 \ 34 | --disable_tqdm=True \ 35 | --save_strategy="epoch" \ 36 | --load_best_model_at_end \ 37 | --augment=$AUG \ 38 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/abtbuy/run_finetune_siamese_unfrozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/abtbuy-clean-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | 
--dataset_name=abt-buy \ 16 | --frozen=False \ 17 | --train_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 18 | --validation_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 19 | --test_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-gs.json.gz \ 20 | --evaluation_strategy=epoch \ 21 | --tokenizer="roberta-base" \ 22 | --grad_checkpoint=False \ 23 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/abtbuy-clean-$AUG$BATCH-$PREAUG$LR-$TEMP-unfrozen-roberta-base/ \ 24 | --per_device_train_batch_size=64 \ 25 | --learning_rate=5e-05 \ 26 | --weight_decay=0.01 \ 27 | --num_train_epochs=50 \ 28 | --lr_scheduler_type="linear" \ 29 | --warmup_ratio=0.05 \ 30 | --max_grad_norm=1.0 \ 31 | --fp16 \ 32 | --metric_for_best_model=loss \ 33 | --dataloader_num_workers=4 \ 34 | --disable_tqdm=True \ 35 | --save_strategy="epoch" \ 36 | --load_best_model_at_end \ 37 | --augment=$AUG \ 38 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/abtbuy/run_finetune_ssv_siamese_unfrozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/abtbuy-ssv-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=abt-buy \ 16 | --frozen=False \ 17 | --train_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 18 | --validation_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-train.json.gz \ 19 | --test_file /your_path/contrastive-product-matching/data/interim/abt-buy/abt-buy-gs.json.gz \ 20 | --evaluation_strategy=epoch \ 21 | --tokenizer="roberta-base" \ 22 | --grad_checkpoint=True \ 23 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/abtbuy-ssv-$AUG$BATCH-$PREAUG$LR-$TEMP-unfrozen-roberta-base/ \ 24 | --per_device_train_batch_size=64 \ 25 | --learning_rate=5e-05 \ 26 | --weight_decay=0.01 \ 27 | --num_train_epochs=50 \ 28 | --lr_scheduler_type="linear" \ 29 | --warmup_ratio=0.05 \ 30 | --max_grad_norm=1.0 \ 31 | --fp16 \ 32 | --metric_for_best_model=f1 \ 33 | --dataloader_num_workers=4 \ 34 | --disable_tqdm=True \ 35 | --save_strategy="epoch" \ 36 | --load_best_model_at_end \ 37 | --augment=$AUG \ 38 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/amazongoogle/run_finetune_siamese_frozen_nosplit_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/amazongoogle-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=amazon-google \ 16 | --train_file 
/your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 17 | --validation_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 18 | --test_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-gs.json.gz \ 19 | --evaluation_strategy=epoch \ 20 | --tokenizer="roberta-base" \ 21 | --grad_checkpoint=False \ 22 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/amazongoogle-$AUG$BATCH-$PREAUG$LR-$TEMP-frozen-roberta-base/ \ 23 | --per_device_train_batch_size=64 \ 24 | --learning_rate=5e-05 \ 25 | --weight_decay=0.01 \ 26 | --num_train_epochs=50 \ 27 | --lr_scheduler_type="linear" \ 28 | --warmup_ratio=0.05 \ 29 | --max_grad_norm=1.0 \ 30 | --fp16 \ 31 | --metric_for_best_model=loss \ 32 | --dataloader_num_workers=4 \ 33 | --disable_tqdm=True \ 34 | --save_strategy="epoch" \ 35 | --load_best_model_at_end \ 36 | --augment=$AUG \ 37 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/amazongoogle/run_finetune_ssv_siamese_frozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/amazongoogle-ssv-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=amazon-google \ 16 | --train_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 17 | --validation_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 18 | --test_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-gs.json.gz \ 19 | --evaluation_strategy=epoch \ 20 | --tokenizer="roberta-base" \ 21 | --grad_checkpoint=True \ 22 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/amazongoogle-ssv-$AUG$BATCH-$PREAUG$LR-$TEMP-frozen-roberta-base/ \ 23 | --per_device_train_batch_size=64 \ 24 | --learning_rate=5e-05 \ 25 | --weight_decay=0.01 \ 26 | --num_train_epochs=50 \ 27 | --lr_scheduler_type="linear" \ 28 | --warmup_ratio=0.05 \ 29 | --max_grad_norm=1.0 \ 30 | --fp16 \ 31 | --metric_for_best_model=f1 \ 32 | --dataloader_num_workers=4 \ 33 | --disable_tqdm=True \ 34 | --save_strategy="epoch" \ 35 | --load_best_model_at_end \ 36 | --augment=$AUG \ 37 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/amazongoogle/run_finetune_siamese_frozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/amazongoogle-clean-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=amazon-google \ 16 | --train_file 
/your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 17 | --validation_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 18 | --test_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-gs.json.gz \ 19 | --evaluation_strategy=epoch \ 20 | --tokenizer="roberta-base" \ 21 | --grad_checkpoint=False \ 22 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/amazongoogle-clean-$AUG$BATCH-$PREAUG$LR-$TEMP-frozen-roberta-base/ \ 23 | --per_device_train_batch_size=64 \ 24 | --learning_rate=5e-05 \ 25 | --weight_decay=0.01 \ 26 | --num_train_epochs=50 \ 27 | --lr_scheduler_type="linear" \ 28 | --warmup_ratio=0.05 \ 29 | --max_grad_norm=1.0 \ 30 | --fp16 \ 31 | --metric_for_best_model=loss \ 32 | --dataloader_num_workers=4 \ 33 | --disable_tqdm=True \ 34 | --save_strategy="epoch" \ 35 | --load_best_model_at_end \ 36 | --augment=$AUG \ 37 | #--do_param_opt \ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2019, Ralph Peeters 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, this 12 | list of conditions and the following disclaimer in the documentation and/or 13 | other materials provided with the distribution. 14 | 15 | * Neither the name of di-research nor the names of its 16 | contributors may be used to endorse or promote products derived from this 17 | software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 20 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 21 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 22 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 23 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 24 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 26 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 27 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 28 | OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | -------------------------------------------------------------------------------- /src/contrastive/amazongoogle/run_finetune_siamese_unfrozen_nosplit_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/amazongoogle-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=amazon-google \ 16 | --frozen=False \ 17 | --train_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 18 | --validation_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 19 | --test_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-gs.json.gz \ 20 | --evaluation_strategy=epoch \ 21 | --tokenizer="roberta-base" \ 22 | --grad_checkpoint=False \ 23 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/amazongoogle-$AUG$BATCH-$PREAUG$LR-$TEMP-unfrozen-roberta-base/ \ 24 | --per_device_train_batch_size=64 \ 25 | --learning_rate=5e-05 \ 26 | --weight_decay=0.01 \ 27 | --num_train_epochs=50 \ 28 | --lr_scheduler_type="linear" \ 29 | --warmup_ratio=0.05 \ 30 | --max_grad_norm=1.0 \ 31 | --fp16 \ 32 | --metric_for_best_model=loss \ 33 | --dataloader_num_workers=4 \ 34 | --disable_tqdm=True \ 35 | --save_strategy="epoch" \ 36 | --load_best_model_at_end \ 37 | --augment=$AUG \ 38 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/amazongoogle/run_finetune_siamese_unfrozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/amazongoogle-clean-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=amazon-google \ 16 | --frozen=False \ 17 | --train_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 18 | --validation_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 19 | --test_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-gs.json.gz \ 20 | --evaluation_strategy=epoch \ 21 | --tokenizer="roberta-base" \ 22 | --grad_checkpoint=False \ 23 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/amazongoogle-clean-$AUG$BATCH-$PREAUG$LR-$TEMP-unfrozen-roberta-base/ \ 24 | --per_device_train_batch_size=64 \ 25 | --learning_rate=5e-05 \ 26 | --weight_decay=0.01 \ 27 | --num_train_epochs=50 \ 28 | --lr_scheduler_type="linear" \ 29 | --warmup_ratio=0.05 \ 30 | --max_grad_norm=1.0 \ 31 | --fp16 \ 32 | --metric_for_best_model=loss \ 33 | --dataloader_num_workers=4 \ 34 | --disable_tqdm=True \ 35 | --save_strategy="epoch" \ 36 | --load_best_model_at_end \ 37 | --augment=$AUG \ 38 | #--do_param_opt \ 
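# Example invocation (illustrative values only, not the tuned settings from the paper; adjust to your setup):
#   bash run_finetune_siamese_unfrozen_roberta.sh 1024 5e-04 0.07 all- all-
# BATCH ($1), LR ($2) and TEMP ($3) must match the contrastive pre-training run, since they are used
# to locate its checkpoint directory; AUG ($4) sets the fine-tuning augmentation and PREAUG ($5) is
# the augmentation tag of the pre-training run.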
-------------------------------------------------------------------------------- /src/contrastive/amazongoogle/run_finetune_ssv_siamese_unfrozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | AUG=$4 11 | PREAUG=$5 12 | python run_finetune_siamese.py \ 13 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/amazongoogle-ssv-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 14 | --do_train \ 15 | --dataset_name=amazon-google \ 16 | --frozen=False \ 17 | --train_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 18 | --validation_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-train.json.gz \ 19 | --test_file /your_path/contrastive-product-matching/data/interim/amazon-google/amazon-google-gs.json.gz \ 20 | --evaluation_strategy=epoch \ 21 | --tokenizer="roberta-base" \ 22 | --grad_checkpoint=True \ 23 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/amazongoogle-ssv-$AUG$BATCH-$PREAUG$LR-$TEMP-unfrozen-roberta-base/ \ 24 | --per_device_train_batch_size=64 \ 25 | --learning_rate=5e-05 \ 26 | --weight_decay=0.01 \ 27 | --num_train_epochs=50 \ 28 | --lr_scheduler_type="linear" \ 29 | --warmup_ratio=0.05 \ 30 | --max_grad_norm=1.0 \ 31 | --fp16 \ 32 | --metric_for_best_model=f1 \ 33 | --dataloader_num_workers=4 \ 34 | --disable_tqdm=True \ 35 | --save_strategy="epoch" \ 36 | --load_best_model_at_end \ 37 | --augment=$AUG \ 38 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/lspc/run_finetune_siamese_frozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | SIZE=$4 11 | AUG=$5 12 | PREAUG=$6 13 | python run_finetune_siamese.py \ 14 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/computers-$SIZE-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 15 | --do_train \ 16 | --train_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/training-sets/preprocessed_computers_train_$SIZE.pkl.gz \ 17 | --train_size=$SIZE \ 18 | --validation_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/training-sets/preprocessed_computers_train_$SIZE.pkl.gz \ 19 | --test_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/gold-standards/preprocessed_computers_gs.pkl.gz \ 20 | --evaluation_strategy=epoch \ 21 | --tokenizer="roberta-base" \ 22 | --grad_checkpoint=True \ 23 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/computers-$SIZE-$AUG$BATCH-$PREAUG$LR-$TEMP-frozen-roberta-base/ \ 24 | --per_device_train_batch_size=64 \ 25 | --learning_rate=5e-05 \ 26 | --weight_decay=0.01 \ 27 | --num_train_epochs=50 \ 28 | --lr_scheduler_type="linear" \ 29 | --warmup_ratio=0.05 \ 30 | --max_grad_norm=1.0 \ 31 | --fp16 \ 32 | --metric_for_best_model=loss \ 33 | --dataloader_num_workers=4 \ 34 | --disable_tqdm=True \ 35 | --save_strategy="epoch" \ 36 | --load_best_model_at_end \ 37 | --augment=$AUG \ 38 | 
#--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/lspc/run_finetune_ssv_siamese_frozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | SIZE=$4 11 | AUG=$5 12 | PREAUG=$6 13 | python run_finetune_siamese.py \ 14 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/computers-ssv-$SIZE-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 15 | --do_train \ 16 | --train_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/training-sets/preprocessed_computers_train_$SIZE.pkl.gz \ 17 | --train_size=$SIZE \ 18 | --validation_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/training-sets/preprocessed_computers_train_$SIZE.pkl.gz \ 19 | --test_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/gold-standards/preprocessed_computers_gs.pkl.gz \ 20 | --evaluation_strategy=epoch \ 21 | --tokenizer="roberta-base" \ 22 | --grad_checkpoint=True \ 23 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/computers-ssv-$SIZE-$AUG$BATCH-$PREAUG$LR-$TEMP-frozen-roberta-base/ \ 24 | --per_device_train_batch_size=64 \ 25 | --learning_rate=5e-05 \ 26 | --weight_decay=0.01 \ 27 | --num_train_epochs=50 \ 28 | --lr_scheduler_type="linear" \ 29 | --warmup_ratio=0.05 \ 30 | --max_grad_norm=1.0 \ 31 | --fp16 \ 32 | --metric_for_best_model=f1 \ 33 | --dataloader_num_workers=4 \ 34 | --disable_tqdm=True \ 35 | --save_strategy="epoch" \ 36 | --load_best_model_at_end \ 37 | --augment=$AUG \ 38 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/lspc/run_finetune_siamese_unfrozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | SIZE=$4 11 | AUG=$5 12 | PREAUG=$6 13 | python run_finetune_siamese.py \ 14 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/computers-$SIZE-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 15 | --do_train \ 16 | --frozen=False \ 17 | --train_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/training-sets/preprocessed_computers_train_$SIZE.pkl.gz \ 18 | --train_size=$SIZE \ 19 | --validation_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/training-sets/preprocessed_computers_train_$SIZE.pkl.gz \ 20 | --test_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/gold-standards/preprocessed_computers_gs.pkl.gz \ 21 | --evaluation_strategy=epoch \ 22 | --tokenizer="roberta-base" \ 23 | --grad_checkpoint=True \ 24 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/computers-$SIZE-$AUG$BATCH-$PREAUG$LR-$TEMP-unfrozen-roberta-base/ \ 25 | --per_device_train_batch_size=64 \ 26 | --learning_rate=5e-05 \ 27 | --weight_decay=0.01 \ 28 | --num_train_epochs=50 \ 29 | --lr_scheduler_type="linear" \ 30 | --warmup_ratio=0.05 \ 31 | --max_grad_norm=1.0 \ 32 | --fp16 \ 33 | --metric_for_best_model=loss \ 34 | --dataloader_num_workers=4 \ 35 | --disable_tqdm=True \ 36 | 
--save_strategy="epoch" \ 37 | --load_best_model_at_end \ 38 | --augment=$AUG \ 39 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/contrastive/lspc/run_finetune_ssv_siamese_unfrozen_roberta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu_8 3 | #SBATCH --gres=gpu:1 4 | #SBATCH --nodes=1 5 | #SBATCH --time=12:00:00 6 | #SBATCH --export=NONE 7 | BATCH=$1 8 | LR=$2 9 | TEMP=$3 10 | SIZE=$4 11 | AUG=$5 12 | PREAUG=$6 13 | python run_finetune_siamese.py \ 14 | --model_pretrained_checkpoint /your_path/contrastive-product-matching/reports/contrastive/computers-ssv-$SIZE-$PREAUG$BATCH-$LR-$TEMP-roberta-base/pytorch_model.bin \ 15 | --do_train \ 16 | --frozen=False \ 17 | --train_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/training-sets/preprocessed_computers_train_$SIZE.pkl.gz \ 18 | --train_size=$SIZE \ 19 | --validation_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/training-sets/preprocessed_computers_train_$SIZE.pkl.gz \ 20 | --test_file /your_path/contrastive-product-matching/data/interim/wdc-lspc/gold-standards/preprocessed_computers_gs.pkl.gz \ 21 | --evaluation_strategy=epoch \ 22 | --tokenizer="roberta-base" \ 23 | --grad_checkpoint=True \ 24 | --output_dir /your_path/contrastive-product-matching/reports/contrastive-ft-siamese/computers-ssv-$SIZE-$AUG$BATCH-$PREAUG$LR-$TEMP-unfrozen-roberta-base/ \ 25 | --per_device_train_batch_size=64 \ 26 | --learning_rate=5e-05 \ 27 | --weight_decay=0.01 \ 28 | --num_train_epochs=50 \ 29 | --lr_scheduler_type="linear" \ 30 | --warmup_ratio=0.05 \ 31 | --max_grad_norm=1.0 \ 32 | --fp16 \ 33 | --metric_for_best_model=f1 \ 34 | --dataloader_num_workers=4 \ 35 | --disable_tqdm=True \ 36 | --save_strategy="epoch" \ 37 | --load_best_model_at_end \ 38 | --augment=$AUG \ 39 | #--do_param_opt \ -------------------------------------------------------------------------------- /src/data/download_datasets.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from zipfile import ZipFile 3 | from pathlib import Path 4 | 5 | DATASETS = [ 6 | 'http://data.dws.informatik.uni-mannheim.de/largescaleproductcorpus/data/v2/repo-download/contrastive-data.zip' 7 | ] 8 | 9 | 10 | def download_datasets(): 11 | for link in DATASETS: 12 | 13 | '''iterate through all links in DATASETS 14 | and download them one by one''' 15 | 16 | # obtain filename by splitting url and getting 17 | # last string 18 | file_name = link.split('/')[-1] 19 | 20 | print("Downloading file:%s" % file_name) 21 | 22 | # create response object 23 | r = requests.get(link, stream=True) 24 | 25 | # download started 26 | with open(f'../../data/{file_name}', 'wb') as f: 27 | for chunk in r.iter_content(chunk_size=1024 * 1024): 28 | if chunk: 29 | f.write(chunk) 30 | 31 | print("%s downloaded!\n" % file_name) 32 | 33 | print("All files downloaded!") 34 | return 35 | 36 | 37 | def unzip_files(): 38 | for link in DATASETS: 39 | file_name = link.split('/')[-1] 40 | # opening the zip file in READ mode 41 | with ZipFile(f'../../data/{file_name}', 'r') as zip: 42 | # printing all the contents of the zip file 43 | zip.printdir() 44 | 45 | # extracting all the files 46 | print('Extracting all the files now...') 47 | zip.extractall(path='../../') 48 | print('Done!') 49 | 50 | 51 | if __name__ == "__main__": 52 | Path('../../data/').mkdir(parents=True, exist_ok=True) 53 | download_datasets() 
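    # unpack the downloaded archive(s); extraction is relative to the repository root (path='../../'),
    # so the contents end up under data/ (see README)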
54 | unzip_files() 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # DotEnv configuration 60 | .env 61 | 62 | # Database 63 | *.db 64 | *.rdb 65 | 66 | # Pycharm 67 | .idea 68 | 69 | # VS Code 70 | .vscode/ 71 | 72 | # Spyder 73 | .spyproject/ 74 | 75 | # Jupyter NB Checkpoints 76 | .ipynb_checkpoints/ 77 | 78 | # exclude data from source control by default 79 | /data/ 80 | 81 | # exclude cache from source control by default 82 | /cache/ 83 | 84 | # exclude models from source control by default 85 | /models/ 86 | 87 | /reports/contrastive/ 88 | /reports/contrastive-ft-siamese/ 89 | /reports/contrastive-ft-siamese-preaug/ 90 | 91 | # Mac OS-specific storage files 92 | .DS_Store 93 | 94 | # vim 95 | *.swp 96 | *.swo 97 | 98 | # Mypy cache 99 | .mypy_cache/ 100 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | np.random.seed(42) 4 | import random 5 | random.seed(42) 6 | 7 | import nltk 8 | from nltk import PorterStemmer 9 | from nltk.corpus import stopwords 10 | 11 | from copy import deepcopy 12 | 13 | import re 14 | PATTERN1 = re.compile("\"@\S+\s+") 15 | PATTERN2 = re.compile("\s+") 16 | 17 | def clean_string_wdcv2(words): 18 | if not words: 19 | return None 20 | words = words.partition('"')[2] 21 | words = words.rpartition('"')[0] 22 | words = re.sub(PATTERN1, ' ', words) 23 | words = re.sub(PATTERN2, ' ', words) 24 | words = words.replace('"', '') 25 | words = words.strip() 26 | return words 27 | 28 | def clean_specTableContent_wdcv2(words): 29 | if not words: 30 | return None 31 | words = re.sub(PATTERN2, ' ', words) 32 | words = words.strip() 33 | return words 34 | 35 | def tokenize(words, delimiter=None): 36 | #check for NaN 37 | if isinstance(words, float): 38 | if words != words: 39 | return [] 40 | words = str(words) 41 | return words.split(sep=delimiter) 42 | 43 | def remove_stopwords(words, lower=False): 44 | #check for NaN 45 | if isinstance(words, float): 46 | if words != words: 47 | return words 48 | stop_words_list = deepcopy(stopwords.words('english')) 49 | if lower: 50 | stop_words_list = list(map(lambda x: x.lower(), stop_words_list)) 51 | word_list = tokenize(words) 52 | word_list_stopwords_removed = [x for x in word_list if x not in stop_words_list] 53 | words_processed = ' 
'.join(word_list_stopwords_removed) 54 | return words_processed 55 | 56 | def stem(words): 57 | stemmer = PorterStemmer() 58 | word_list = tokenize(words) 59 | stemmed_words = [stemmer.stem(x) for x in word_list] 60 | words_processed = ' '.join(stemmed_words) 61 | return words_processed -------------------------------------------------------------------------------- /src/processing/preprocess/preprocess_ts_gs.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | np.random.seed(42) 4 | import random 5 | random.seed(42) 6 | 7 | import os 8 | import glob 9 | 10 | from src.data import utils 11 | 12 | 13 | def _cut_lspc(row): 14 | attributes = {'title_left': 50, 15 | 'title_right': 50, 16 | 'brand_left': 5, 17 | 'brand_right': 5, 18 | 'description_left': 100, 19 | 'description_right': 100, 20 | 'specTableContent_left': 200, 21 | 'specTableContent_right': 200} 22 | 23 | for attr, value in attributes.items(): 24 | try: 25 | row[attr] = ' '.join(row[attr].split(' ')[:value]) 26 | except AttributeError: 27 | continue 28 | return row 29 | 30 | if __name__ == '__main__': 31 | # preprocess training sets and gold standards 32 | print('BUILDING PREPROCESSED TRAINING SETS AND GOLD STANDARDS...') 33 | os.makedirs(os.path.dirname('../../../data/interim/wdc-lspc/training-sets/'), exist_ok=True) 34 | os.makedirs(os.path.dirname('../../../data/interim/wdc-lspc/gold-standards/'), exist_ok=True) 35 | 36 | for file in glob.glob('../../../data/raw/wdc-lspc/training-sets/*'): 37 | df = pd.read_json(file, lines=True) 38 | df['title_left'] = df['title_left'].apply(utils.clean_string_wdcv2) 39 | df['description_left'] = df['description_left'].apply(utils.clean_string_wdcv2) 40 | df['brand_left'] = df['brand_left'].apply(utils.clean_string_wdcv2) 41 | df['price_left'] = df['price_left'].apply(utils.clean_string_wdcv2) 42 | df['specTableContent_left'] = df['specTableContent_left'].apply(utils.clean_specTableContent_wdcv2) 43 | df['title_right'] = df['title_right'].apply(utils.clean_string_wdcv2) 44 | df['description_right'] = df['description_right'].apply(utils.clean_string_wdcv2) 45 | df['brand_right'] = df['brand_right'].apply(utils.clean_string_wdcv2) 46 | df['price_right'] = df['price_right'].apply(utils.clean_string_wdcv2) 47 | df['specTableContent_right'] = df['specTableContent_right'].apply(utils.clean_specTableContent_wdcv2) 48 | 49 | df = df.apply(_cut_lspc, axis=1) 50 | 51 | file = os.path.basename(file) 52 | file = file.replace('.json.gz', '.pkl.gz') 53 | file = f'preprocessed_{file}' 54 | df.to_pickle(f'../../../data/interim/wdc-lspc/training-sets/{file}') 55 | 56 | for file in glob.glob('../../../data/raw/wdc-lspc/gold-standards/*'): 57 | df = pd.read_json(file, lines=True) 58 | df['title_left'] = df['title_left'].apply(utils.clean_string_wdcv2) 59 | df['description_left'] = df['description_left'].apply(utils.clean_string_wdcv2) 60 | df['brand_left'] = df['brand_left'].apply(utils.clean_string_wdcv2) 61 | df['price_left'] = df['price_left'].apply(utils.clean_string_wdcv2) 62 | df['specTableContent_left'] = df['specTableContent_left'].apply(utils.clean_specTableContent_wdcv2) 63 | df['title_right'] = df['title_right'].apply(utils.clean_string_wdcv2) 64 | df['description_right'] = df['description_right'].apply(utils.clean_string_wdcv2) 65 | df['brand_right'] = df['brand_right'].apply(utils.clean_string_wdcv2) 66 | df['price_right'] = df['price_right'].apply(utils.clean_string_wdcv2) 67 | df['specTableContent_right'] = 
df['specTableContent_right'].apply(utils.clean_specTableContent_wdcv2) 68 | 69 | df = df.apply(_cut_lspc, axis=1) 70 | 71 | try: 72 | df = df.drop(columns='sampling') 73 | except KeyError: 74 | pass 75 | file = os.path.basename(file) 76 | file = file.replace('.json.gz', '.pkl.gz') 77 | file = f'preprocessed_{file}' 78 | df.to_pickle(f'../../../data/interim/wdc-lspc/gold-standards/{file}') 79 | 80 | print('FINISHED BUILDING PREPROCESSED TRAINING SETS AND GOLD STANDARDS...') 81 | 82 | -------------------------------------------------------------------------------- /src/contrastive/models/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Yonglong Tian (yonglong@mit.edu) 3 | Date: May 07, 2020 4 | """ 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SupConLoss(nn.Module): 12 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 13 | It also supports the unsupervised contrastive loss in SimCLR""" 14 | def __init__(self, temperature=0.07, contrast_mode='all', 15 | base_temperature=0.07): 16 | super(SupConLoss, self).__init__() 17 | self.temperature = temperature 18 | self.contrast_mode = contrast_mode 19 | self.base_temperature = base_temperature 20 | 21 | def forward(self, features, labels=None, mask=None): 22 | """Compute loss for model. If both `labels` and `mask` are None, 23 | it degenerates to SimCLR unsupervised loss: 24 | https://arxiv.org/pdf/2002.05709.pdf 25 | Args: 26 | features: hidden vector of shape [bsz, n_views, ...]. 27 | labels: ground truth of shape [bsz]. 28 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 29 | has the same class as sample i. Can be asymmetric. 30 | Returns: 31 | A loss scalar. 
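        Example (illustrative sketch, not from the original docs; batch size, number of
        views and embedding dimension below are arbitrary, and the features are
        L2-normalized beforehand, as is usual for SupCon):
            >>> loss_fn = SupConLoss(temperature=0.07)
            >>> feats = torch.nn.functional.normalize(torch.randn(16, 2, 128), dim=-1)  # [bsz, n_views, dim]
            >>> labels = torch.randint(0, 4, (16,))  # [bsz]
            >>> loss = loss_fn(feats, labels)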
32 | """ 33 | device = (torch.device('cuda') 34 | if features.is_cuda 35 | else torch.device('cpu')) 36 | 37 | if len(features.shape) < 3: 38 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 39 | 'at least 3 dimensions are required') 40 | if len(features.shape) > 3: 41 | features = features.view(features.shape[0], features.shape[1], -1) 42 | 43 | batch_size = features.shape[0] 44 | if labels is not None and mask is not None: 45 | raise ValueError('Cannot define both `labels` and `mask`') 46 | elif labels is None and mask is None: 47 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 48 | elif labels is not None: 49 | labels = labels.contiguous().view(-1, 1) 50 | if labels.shape[0] != batch_size: 51 | raise ValueError('Num of labels does not match num of features') 52 | mask = torch.eq(labels, labels.T).float().to(device) 53 | else: 54 | mask = mask.float().to(device) 55 | 56 | contrast_count = features.shape[1] 57 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 58 | if self.contrast_mode == 'one': 59 | anchor_feature = features[:, 0] 60 | anchor_count = 1 61 | elif self.contrast_mode == 'all': 62 | anchor_feature = contrast_feature 63 | anchor_count = contrast_count 64 | else: 65 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 66 | 67 | # compute logits 68 | anchor_dot_contrast = torch.div( 69 | torch.matmul(anchor_feature, contrast_feature.T), 70 | self.temperature) 71 | # for numerical stability 72 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 73 | logits = anchor_dot_contrast - logits_max.detach() 74 | 75 | # tile mask 76 | mask = mask.repeat(anchor_count, contrast_count) 77 | # mask-out self-contrast cases 78 | logits_mask = torch.scatter( 79 | torch.ones_like(mask), 80 | 1, 81 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 82 | 0 83 | ) 84 | mask = mask * logits_mask 85 | 86 | # compute log_prob 87 | exp_logits = torch.exp(logits) * logits_mask 88 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 89 | 90 | # compute mean of log-likelihood over positive 91 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 92 | 93 | # loss 94 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 95 | loss = loss.view(anchor_count, batch_size).mean() 96 | 97 | return loss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Contrastive Product Matching 2 | 3 | This repository contains the code and data download links to reproduce the experiments of the paper "Supervised Contrastive Learning for Product Matching" by Ralph Peeters and Christian Bizer. [ArXiv link](https://arxiv.org/abs/2202.02098). A comparison of the results to other systems using different benchmark datasets is found at [Papers with Code - Entity Resolution](https://paperswithcode.com/task/entity-resolution/). 4 | 5 | * **Requirements** 6 | 7 | [Anaconda3](https://www.anaconda.com/products/individual) 8 | 9 | Please keep in mind that the code is not optimized for portable or even non-workstation devices. Some of the scripts may require large amounts of RAM (64GB+) and GPUs. It is advised to use a powerful workstation or server when experimenting with some of the larger files. 10 | 11 | The code has only been used and tested on Linux (CentOS) servers. 
12 | 13 | * **Building the conda environment** 14 | 15 | To build the exact conda environment used for the experiments, navigate to the project root folder where the file *contrastive-product-matching.yml* is located and run ```conda env create -f contrastive-product-matching.yml``` 16 | 17 | Furthermore, you need to install the project as a package. To do this, activate the environment with ```conda activate contrastive-product-matching```, navigate to the root folder of the project, and run ```pip install -e .``` 18 | 19 | * **Downloading the raw data files** 20 | 21 | Navigate to the *src/data/* folder and run ```python download_datasets.py``` to automatically download the files into the correct locations. 22 | You can find the data at *data/raw/*. 23 | 24 | If you are only interested in the separate datasets, you can download the [WDC LSPC datasets](http://webdatacommons.org/largescaleproductcorpus/v2/index.html#toc6) and the [deepmatcher splits](https://github.com/anhaidgroup/deepmatcher/blob/master/Datasets.md) for the abt-buy and amazon-google datasets on the respective websites. 25 | 26 | * **Processing the data** 27 | 28 | To prepare the data for the experiments, run the following scripts in the order listed below. Make sure to navigate to the respective folders first. 29 | 30 | 1. *src/processing/preprocess/preprocess_corpus.py* 31 | 2. *src/processing/preprocess/preprocess_ts_gs.py* 32 | 3. *src/processing/preprocess/preprocess-deepmatcher-datasets.py* 33 | 4. *src/processing/contrastive/prepare-data.py* 34 | 5. *src/processing/contrastive/prepare-data-deepmatcher.py* 35 | 36 | * **Running the Contrastive Pre-training and Cross-entropy Fine-tuning** 37 | 38 | Navigate to *src/contrastive/*. 39 | 40 | You can find the respective scripts for running the experiments of the paper in the subfolders *lspc/*, *abtbuy/* and *amazongoogle/*. Note that you need to adjust the file paths in these scripts for your system (replace ```your_path``` with ```path/to/repo```). 41 | 42 | * **Contrastive Pre-training** 43 | 44 | To run contrastive pre-training for the abt-buy dataset, for example, use 45 | 46 | ```bash abtbuy/run_pretraining_clean_roberta.sh BATCH_SIZE LEARNING_RATE TEMPERATURE (AUG)``` 47 | 48 | You need to specify the batch size, learning rate and temperature as arguments here. Optionally, you can also apply data augmentation by passing an augmentation method as the last argument (use ```all-``` for the augmentation used in the paper). 49 | 50 | For the WDC Computers data you also need to supply the size of the training set, e.g. 51 | 52 | ```bash lspc/run_pretraining_roberta.sh BATCH_SIZE LEARNING_RATE TEMPERATURE TRAIN_SIZE (AUG)``` 53 | 54 | * **Cross-entropy Fine-tuning** 55 | 56 | Finally, to use the pre-trained models for fine-tuning, run any of the fine-tuning scripts in the respective folders, e.g. 57 | 58 | ```bash abtbuy/run_finetune_siamese_frozen_roberta.sh BATCH_SIZE LEARNING_RATE TEMPERATURE (AUG)``` 59 | 60 | Please note that BATCH_SIZE refers to the batch size used in pre-training. The fine-tuning batch size is fixed at 64 but can be adjusted in the bash scripts if needed. 61 | 62 | Analogously, for fine-tuning on WDC Computers, add the train size: 63 | 64 | ```bash lspc/run_finetune_siamese_frozen_roberta.sh BATCH_SIZE LEARNING_RATE TEMPERATURE TRAIN_SIZE (AUG)``` 65 | 66 | A concrete example invocation is shown at the end of this README. 67 | -------- 68 | 69 | Project based on the [cookiecutter data science project template](https://drivendata.github.io/cookiecutter-data-science/).
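* **Example invocation**

For illustration only, a complete pre-training plus fine-tuning run on abt-buy could look as follows; the hyperparameter values below are placeholders, not necessarily the settings used in the paper:

```
bash abtbuy/run_pretraining_clean_roberta.sh 1024 5e-05 0.07 all-
bash abtbuy/run_finetune_siamese_frozen_roberta.sh 1024 5e-05 0.07 all-
```

The fine-tuning call repeats the pre-training arguments, presumably so that the script can locate the output directory of the matching pre-training run.
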
#cookiecutterdatascience 70 | -------------------------------------------------------------------------------- /src/contrastive/data/data_collators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.random.seed(42) 3 | import random 4 | random.seed(42) 5 | 6 | from dataclasses import dataclass 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase 13 | 14 | from pdb import set_trace 15 | 16 | # collator for self-supervised contrastive pre-training 17 | @dataclass 18 | class DataCollatorContrastivePretrainSelfSupervised: 19 | 20 | tokenizer: PreTrainedTokenizerBase 21 | max_length: Optional[int] = 128 22 | pad_to_multiple_of: Optional[int] = None 23 | return_tensors: str = "pt" 24 | 25 | def __call__(self, input): 26 | 27 | features_left = [x[0]['features'] for x in input] 28 | labels = [x[0]['labels'] for x in input] 29 | 30 | batch = self.tokenizer(features_left, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors) 31 | 32 | if 'token_type_ids' in batch.keys(): 33 | del batch['token_type_ids'] 34 | 35 | batch['labels'] = torch.LongTensor(labels) 36 | 37 | return batch 38 | 39 | # collator for supervised contrastive pre-training for WDC Computers 40 | @dataclass 41 | class DataCollatorContrastivePretrain: 42 | 43 | tokenizer: PreTrainedTokenizerBase 44 | max_length: Optional[int] = 128 45 | pad_to_multiple_of: Optional[int] = None 46 | return_tensors: str = "pt" 47 | 48 | def __call__(self, input): 49 | 50 | features_left = [x[0]['features'] for x in input] 51 | features_right = [x[1]['features'] for x in input] 52 | labels = [x[0]['labels'] for x in input] 53 | 54 | batch_left = self.tokenizer(features_left, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors) 55 | batch_right = self.tokenizer(features_right, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors) 56 | 57 | batch = batch_left 58 | if 'token_type_ids' in batch.keys(): 59 | del batch['token_type_ids'] 60 | batch['input_ids_right'] = batch_right['input_ids'] 61 | batch['attention_mask_right'] = batch_right['attention_mask'] 62 | 63 | batch['labels'] = torch.LongTensor(labels) 64 | 65 | return batch 66 | 67 | # collator for supervised contrastive pre-training for Abt-Buy and Amazon-Google 68 | # randomly chooses the sampling dataset when using source-aware sampling 69 | @dataclass 70 | class DataCollatorContrastivePretrainDeepmatcher: 71 | 72 | tokenizer: PreTrainedTokenizerBase 73 | max_length: Optional[int] = 128 74 | pad_to_multiple_of: Optional[int] = None 75 | return_tensors: str = "pt" 76 | 77 | def __call__(self, input_both): 78 | 79 | rnd = random.choice([0,1]) 80 | input = [x[rnd] for x in input_both] 81 | 82 | features_left = [x[0]['features'] for x in input] 83 | features_right = [x[1]['features'] for x in input] 84 | 85 | labels = [x[0]['labels'] for x in input] 86 | 87 | batch_left = self.tokenizer(features_left, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors) 88 | batch_right = self.tokenizer(features_right, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors) 89 | 90 | batch = batch_left 91 | if 'token_type_ids' in batch.keys(): 92 | del batch['token_type_ids'] 93 | batch['input_ids_right'] = 
batch_right['input_ids'] 94 | batch['attention_mask_right'] = batch_right['attention_mask'] 95 | 96 | batch['labels'] = torch.LongTensor(labels) 97 | 98 | return batch 99 | 100 | # collator for pair-wise cross-entropy fine-tuning 101 | @dataclass 102 | class DataCollatorContrastiveClassification: 103 | 104 | tokenizer: PreTrainedTokenizerBase 105 | max_length: Optional[int] = 128 106 | pad_to_multiple_of: Optional[int] = None 107 | return_tensors: str = "pt" 108 | 109 | def __call__(self, input): 110 | 111 | features_left = [x['features_left'] for x in input] 112 | features_right = [x['features_right'] for x in input] 113 | labels = [x['labels'] for x in input] 114 | 115 | batch_left = self.tokenizer(features_left, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors) 116 | batch_right = self.tokenizer(features_right, padding=True, truncation=True, max_length=self.max_length, return_tensors=self.return_tensors) 117 | 118 | batch = batch_left 119 | if 'token_type_ids' in batch.keys(): 120 | del batch['token_type_ids'] 121 | batch['input_ids_right'] = batch_right['input_ids'] 122 | batch['attention_mask_right'] = batch_right['attention_mask'] 123 | 124 | batch['labels'] = torch.LongTensor(labels) 125 | 126 | return batch -------------------------------------------------------------------------------- /src/processing/preprocess/preprocess-deepmatcher-datasets.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | np.random.seed(42) 4 | import random 5 | random.seed(42) 6 | 7 | import os 8 | 9 | def assign_clusterid(identifier, cluster_id_dict, cluster_id_amount): 10 | try: 11 | result = cluster_id_dict[identifier] 12 | except KeyError: 13 | result = cluster_id_amount 14 | return result 15 | 16 | def preprocess_dataset(arg_tuple): 17 | 18 | handle, id_handle_left, id_handle_right = arg_tuple 19 | 20 | print(f'BUILDING {handle} TRAIN, VALID, GS...') 21 | 22 | left_df = pd.read_csv(f'../../../data/raw/{handle}/tableA.csv', engine='python') 23 | right_df = pd.read_csv(f'../../../data/raw/{handle}/tableB.csv', engine='python') 24 | 25 | left_df['id'] = f'{id_handle_left}_' + left_df['id'].astype(str) 26 | right_df['id'] = f'{id_handle_right}_' + right_df['id'].astype(str) 27 | 28 | left_df = left_df.set_index('id', drop=False) 29 | right_df = right_df.set_index('id', drop=False) 30 | left_df = left_df.fillna('') 31 | right_df = right_df.fillna('') 32 | 33 | train = pd.read_csv(f'../../../data/raw/{handle}/train.csv') 34 | test = pd.read_csv(f'../../../data/raw/{handle}/test.csv') 35 | valid = pd.read_csv(f'../../../data/raw/{handle}/valid.csv') 36 | 37 | full = train.append(valid, ignore_index=True).append(test, ignore_index=True) 38 | full = full[full['label'] == 1] 39 | 40 | full['ltable_id'] = f'{id_handle_left}_' + full['ltable_id'].astype(str) 41 | full['rtable_id'] = f'{id_handle_right}_' + full['rtable_id'].astype(str) 42 | 43 | bucket_list = [] 44 | for i, row in full.iterrows(): 45 | left = f'{row["ltable_id"]}' 46 | right = f'{row["rtable_id"]}' 47 | found = False 48 | for bucket in bucket_list: 49 | if left in bucket and row['label'] == 1: 50 | bucket.add(right) 51 | found = True 52 | break 53 | elif right in bucket and row['label'] == 1: 54 | bucket.add(left) 55 | found = True 56 | break 57 | if not found: 58 | bucket_list.append(set([left, right])) 59 | 60 | cluster_id_dict = {} 61 | 62 | for i, id_set in enumerate(bucket_list): 63 | for v in id_set: 64 | 
cluster_id_dict[v] = i 65 | 66 | train['ltable_id'] = f'{id_handle_left}_' + train['ltable_id'].astype(str) 67 | train['rtable_id'] = f'{id_handle_right}_' + train['rtable_id'].astype(str) 68 | 69 | test['ltable_id'] = f'{id_handle_left}_' + test['ltable_id'].astype(str) 70 | test['rtable_id'] = f'{id_handle_right}_' + test['rtable_id'].astype(str) 71 | 72 | valid['ltable_id'] = f'{id_handle_left}_' + valid['ltable_id'].astype(str) 73 | valid['rtable_id'] = f'{id_handle_right}_' + valid['rtable_id'].astype(str) 74 | 75 | train['label'] = train['label'].apply(lambda x: int(x)) 76 | test['label'] = test['label'].apply(lambda x: int(x)) 77 | valid['label'] = valid['label'].apply(lambda x: int(x)) 78 | 79 | valid['pair_id'] = valid['ltable_id'] + '#' + valid['rtable_id'] 80 | 81 | train = train.append(valid, ignore_index=True) 82 | 83 | train_left = left_df.loc[list(train['ltable_id'].values)] 84 | train_right = right_df.loc[list(train['rtable_id'].values)] 85 | train_labels = [int(x) for x in list(train['label'].values)] 86 | 87 | gs_left = left_df.loc[list(test['ltable_id'].values)] 88 | gs_right = right_df.loc[list(test['rtable_id'].values)] 89 | gs_labels = [int(x) for x in list(test['label'].values)] 90 | 91 | train_left = train_left.reset_index(drop=True) 92 | train_right = train_right.reset_index(drop=True) 93 | gs_left = gs_left.reset_index(drop=True) 94 | gs_right = gs_right.reset_index(drop=True) 95 | 96 | cluster_id_amount = len(bucket_list) 97 | 98 | train_left['cluster_id'] = train_left['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 99 | train_right['cluster_id'] = train_right['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 100 | gs_left['cluster_id'] = gs_left['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 101 | gs_right['cluster_id'] = gs_right['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 102 | 103 | train_df = train_left.join(train_right, lsuffix='_left', rsuffix='_right') 104 | train_df['label'] = train_labels 105 | train_df['pair_id'] = train_df['id_left'] + '#' + train_df['id_right'] 106 | assert len(train_df) == len(train) 107 | 108 | gs_df = gs_left.join(gs_right, lsuffix='_left', rsuffix='_right') 109 | gs_df['label'] = gs_labels 110 | gs_df['pair_id'] = gs_df['id_left'] + '#' + gs_df['id_right'] 111 | assert len(gs_df) == len(test) 112 | 113 | print(f'Size of training set: {len(train_df)}') 114 | print(f'Size of gold standard: {len(gs_df)}') 115 | print(f'Distribution of training set labels: \n{train_df["label"].value_counts()}') 116 | print(f'Distribution of gold standard labels: \n{gs_df["label"].value_counts()}') 117 | 118 | os.makedirs(os.path.dirname(f'../../../data/interim/{handle}/'), exist_ok=True) 119 | 120 | train_df.to_json(f'../../../data/interim/{handle}/{handle}-train.json.gz', compression='gzip', lines=True, orient='records') 121 | valid['pair_id'].to_csv(f'../../../data/interim/{handle}/{handle}-valid.csv', header=True, index=False) 122 | gs_df.to_json(f'../../../data/interim/{handle}/{handle}-gs.json.gz', compression='gzip', lines=True, orient='records') 123 | 124 | print(f'FINISHED BULDING {handle} DATASETS\n') 125 | 126 | 127 | if __name__ == '__main__': 128 | datasets = [ 129 | ('abt-buy', 'abt', 'buy'), 130 | ('amazon-google', 'amazon', 'google') 131 | ] 132 | for dataset in datasets: 133 | preprocess_dataset(dataset) -------------------------------------------------------------------------------- 
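The script above and *prepare-data-deepmatcher.py* below share the same clustering step: positive pairs are binned into buckets that approximate the connected components of the correspondence graph, and every record id inherits its bucket index as `cluster_id` (in the script above, ids that occur in no positive pair are mapped to the placeholder value `cluster_id_amount`). A minimal sketch of that grouping on toy, hypothetical ids:

```python
# Toy illustration of the bucket-based grouping (hypothetical ids, not real data).
positive_pairs = [('abt_1', 'buy_7'), ('buy_7', 'abt_3'), ('abt_9', 'buy_2')]

bucket_list = []
for left, right in positive_pairs:
    for bucket in bucket_list:
        if left in bucket or right in bucket:
            bucket.update((left, right))  # extend an existing component
            break
    else:
        bucket_list.append({left, right})  # start a new component

cluster_id_dict = {v: i for i, bucket in enumerate(bucket_list) for v in bucket}
# e.g. {'abt_1': 0, 'buy_7': 0, 'abt_3': 0, 'abt_9': 1, 'buy_2': 1}
```

Like the loop in the scripts, this simple binning never merges two already-existing buckets that are only later connected by a pair, so it is an approximation of true connected components rather than an exact computation.
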
/src/processing/contrastive/prepare-data-deepmatcher.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | np.random.seed(42) 4 | import random 5 | random.seed(42) 6 | 7 | from pathlib import Path 8 | import shutil 9 | import os 10 | 11 | def assign_clusterid(identifier, cluster_id_dict, cluster_id_amount): 12 | try: 13 | result = cluster_id_dict[identifier] 14 | except KeyError: 15 | result = cluster_id_amount 16 | return result 17 | 18 | def preprocess_dataset(arg_tuple): 19 | 20 | handle, id_handle_left, id_handle_right = arg_tuple 21 | 22 | print(f'BUILDING {handle} TRAIN, VALID, GS...') 23 | 24 | left_df = pd.read_csv(f'../../../data/raw/{handle}/tableA.csv', engine='python') 25 | right_df = pd.read_csv(f'../../../data/raw/{handle}/tableB.csv', engine='python') 26 | 27 | left_df['id'] = f'{id_handle_left}_' + left_df['id'].astype(str) 28 | right_df['id'] = f'{id_handle_right}_' + right_df['id'].astype(str) 29 | 30 | left_df = left_df.set_index('id', drop=False) 31 | right_df = right_df.set_index('id', drop=False) 32 | left_df = left_df.fillna('') 33 | right_df = right_df.fillna('') 34 | 35 | train = pd.read_csv(f'../../../data/raw/{handle}/train.csv') 36 | test = pd.read_csv(f'../../../data/raw/{handle}/test.csv') 37 | valid = pd.read_csv(f'../../../data/raw/{handle}/valid.csv') 38 | 39 | full = train.append(valid, ignore_index=True).append(test, ignore_index=True) 40 | full = full[full['label'] == 1] 41 | 42 | full['ltable_id'] = f'{id_handle_left}_' + full['ltable_id'].astype(str) 43 | full['rtable_id'] = f'{id_handle_right}_' + full['rtable_id'].astype(str) 44 | 45 | 46 | # Build connected components of correspondence graph using binning on positive examples 47 | bucket_list = [] 48 | for i, row in full.iterrows(): 49 | left = f'{row["ltable_id"]}' 50 | right = f'{row["rtable_id"]}' 51 | found = False 52 | for bucket in bucket_list: 53 | if left in bucket and row['label'] == 1: 54 | bucket.add(right) 55 | found = True 56 | break 57 | elif right in bucket and row['label'] == 1: 58 | bucket.add(left) 59 | found = True 60 | break 61 | if not found: 62 | bucket_list.append(set([left, right])) 63 | 64 | cluster_id_dict = {} 65 | 66 | for i, id_set in enumerate(bucket_list): 67 | for v in id_set: 68 | cluster_id_dict[v] = i 69 | 70 | train['ltable_id'] = f'{id_handle_left}_' + train['ltable_id'].astype(str) 71 | train['rtable_id'] = f'{id_handle_right}_' + train['rtable_id'].astype(str) 72 | 73 | test['ltable_id'] = f'{id_handle_left}_' + test['ltable_id'].astype(str) 74 | test['rtable_id'] = f'{id_handle_right}_' + test['rtable_id'].astype(str) 75 | 76 | valid['ltable_id'] = f'{id_handle_left}_' + valid['ltable_id'].astype(str) 77 | valid['rtable_id'] = f'{id_handle_right}_' + valid['rtable_id'].astype(str) 78 | 79 | train['label'] = train['label'].apply(lambda x: int(x)) 80 | test['label'] = test['label'].apply(lambda x: int(x)) 81 | valid['label'] = valid['label'].apply(lambda x: int(x)) 82 | 83 | valid['pair_id'] = valid['ltable_id'] + '#' + valid['rtable_id'] 84 | 85 | train = train.append(valid, ignore_index=True) 86 | 87 | train_left = left_df.loc[list(train['ltable_id'].values)] 88 | train_right = right_df.loc[list(train['rtable_id'].values)] 89 | train_labels = [int(x) for x in list(train['label'].values)] 90 | 91 | gs_left = left_df.loc[list(test['ltable_id'].values)] 92 | gs_right = right_df.loc[list(test['rtable_id'].values)] 93 | gs_labels = [int(x) for x in list(test['label'].values)] 94 | 95 
| train_left = train_left.reset_index(drop=True) 96 | train_right = train_right.reset_index(drop=True) 97 | gs_left = gs_left.reset_index(drop=True) 98 | gs_right = gs_right.reset_index(drop=True) 99 | 100 | cluster_id_amount = len(bucket_list) 101 | 102 | train_left['cluster_id'] = train_left['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 103 | train_right['cluster_id'] = train_right['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 104 | gs_left['cluster_id'] = gs_left['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 105 | gs_right['cluster_id'] = gs_right['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 106 | 107 | train_df = train_left.join(train_right, lsuffix='_left', rsuffix='_right') 108 | train_df['label'] = train_labels 109 | train_df['pair_id'] = train_df['id_left'] + '#' + train_df['id_right'] 110 | assert len(train_df) == len(train) 111 | 112 | gs_df = gs_left.join(gs_right, lsuffix='_left', rsuffix='_right') 113 | gs_df['label'] = gs_labels 114 | gs_df['pair_id'] = gs_df['id_left'] + '#' + gs_df['id_right'] 115 | assert len(gs_df) == len(test) 116 | 117 | print(f'Size of training set: {len(train_df)}') 118 | print(f'Size of gold standard: {len(gs_df)}') 119 | print(f'Distribution of training set labels: \n{train_df["label"].value_counts()}') 120 | print(f'Distribution of gold standard labels: \n{gs_df["label"].value_counts()}') 121 | 122 | merged_ids = set() 123 | merged_ids.update(train_df['id_left']) 124 | merged_ids.update(train_df['id_right']) 125 | 126 | merged_ids.update(gs_df['id_left']) 127 | merged_ids.update(gs_df['id_right']) 128 | 129 | entity_set = left_df[left_df['id'].isin(merged_ids)] 130 | entity_set = entity_set.append(right_df[right_df['id'].isin(merged_ids)]) 131 | # In next line all connected components are assigned the same label 132 | # Note, that all single nodes are assigned the same label here 133 | entity_set['cluster_id'] = entity_set['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 134 | 135 | # assign increasing integer label to single nodes 136 | single_entities = entity_set[entity_set['cluster_id'] == cluster_id_amount].copy() 137 | single_entities = single_entities.reset_index(drop=True) 138 | single_entities['cluster_id'] = single_entities['cluster_id'] + single_entities.index 139 | 140 | entity_set = entity_set.drop(single_entities['id']) 141 | entity_set = entity_set.append(single_entities) 142 | entity_set = entity_set.reset_index(drop=True) 143 | 144 | print(f'Amount of entity descriptions: {len(entity_set)}') 145 | print(f'Amount of clusters: {len(entity_set["cluster_id"].unique())}') 146 | 147 | os.makedirs(os.path.dirname(f'../../../data/processed/{handle}/contrastive/'), exist_ok=True) 148 | 149 | entity_set.to_pickle(f'../../../data/processed/{handle}/contrastive/{handle}-train.pkl.gz', compression='gzip') 150 | 151 | print(f'FINISHED BULDING {handle} DATASETS\n') 152 | 153 | if __name__ == '__main__': 154 | 155 | datasets = [ 156 | ('abt-buy', 'abt', 'buy'), 157 | ('amazon-google', 'amazon', 'google') 158 | ] 159 | for dataset in datasets: 160 | preprocess_dataset(dataset) -------------------------------------------------------------------------------- /src/contrastive/models/modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import BCEWithLogitsLoss 5 | 6 | from transformers 
import AutoModel, AutoConfig 7 | from src.contrastive.models.loss import SupConLoss 8 | 9 | def mean_pooling(model_output, attention_mask): 10 | token_embeddings = model_output[0] #First element of model_output contains all token embeddings 11 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 12 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 13 | 14 | class BaseEncoder(nn.Module): 15 | 16 | def __init__(self, len_tokenizer, model='huawei-noah/TinyBERT_General_4L_312D'): 17 | super().__init__() 18 | 19 | self.transformer = AutoModel.from_pretrained(model) 20 | self.transformer.resize_token_embeddings(len_tokenizer) 21 | 22 | def forward(self, input_ids, attention_mask): 23 | 24 | output = self.transformer(input_ids, attention_mask) 25 | 26 | return output 27 | 28 | # self-supervised contrastive model 29 | class ContrastiveSelfSupervisedPretrainModel(nn.Module): 30 | 31 | def __init__(self, len_tokenizer, model='huawei-noah/TinyBERT_General_4L_312D', ssv=True, pool=False, proj='mlp', temperature=0.07, num_augments=2): 32 | super().__init__() 33 | 34 | self.ssv = ssv 35 | self.pool = pool 36 | self.proj = proj 37 | self.temperature = temperature 38 | self.num_augments = num_augments 39 | self.criterion = SupConLoss(self.temperature) 40 | 41 | self.encoder = BaseEncoder(len_tokenizer, model) 42 | self.config = self.encoder.transformer.config 43 | 44 | self.contrastive_head = ContrastivePretrainHead(self.config.hidden_size, self.proj) 45 | 46 | 47 | def forward(self, input_ids, attention_mask, labels): 48 | 49 | additional_outputs = [] 50 | if self.pool: 51 | output_left = self.encoder(input_ids, attention_mask) 52 | output_left = mean_pooling(output_left, attention_mask).unsqueeze(1) # add the views dimension expected by SupConLoss 53 | 54 | for num in range(self.num_augments-1): 55 | output_right = self.encoder(input_ids, attention_mask) 56 | output_right = mean_pooling(output_right, attention_mask).unsqueeze(1) 57 | additional_outputs.append(output_right) 58 | else: 59 | output_left = self.encoder(input_ids, attention_mask)['pooler_output'].unsqueeze(1) 60 | for num in range(self.num_augments-1): 61 | additional_outputs.append(self.encoder(input_ids, attention_mask)['pooler_output'].unsqueeze(1)) 62 | 63 | output = torch.cat((output_left, *additional_outputs), 1) 64 | 65 | output = F.normalize(output, dim=-1) 66 | 67 | proj_output = self.contrastive_head(output) 68 | 69 | proj_output = F.normalize(proj_output, dim=-1) 70 | 71 | if self.ssv: 72 | loss = self.criterion(proj_output) 73 | else: 74 | loss = self.criterion(proj_output, labels) 75 | 76 | return ((loss,)) 77 | 78 | # supervised contrastive model 79 | class ContrastivePretrainModel(nn.Module): 80 | 81 | def __init__(self, len_tokenizer, model='huawei-noah/TinyBERT_General_4L_312D', pool=True, proj='mlp', temperature=0.07): 82 | super().__init__() 83 | 84 | self.pool = pool 85 | self.proj = proj 86 | self.temperature = temperature 87 | self.criterion = SupConLoss(self.temperature) 88 | 89 | self.encoder = BaseEncoder(len_tokenizer, model) 90 | self.config = self.encoder.transformer.config 91 | 92 | def forward(self, input_ids, attention_mask, labels, input_ids_right, attention_mask_right): 93 | 94 | if self.pool: 95 | output_left = self.encoder(input_ids, attention_mask) 96 | output_left = mean_pooling(output_left, attention_mask) 97 | 98 | output_right = self.encoder(input_ids_right, attention_mask_right) 99 | output_right = mean_pooling(output_right, attention_mask_right) 100 | else: 101 |
output_left = self.encoder(input_ids, attention_mask)['pooler_output'] 102 | output_right = self.encoder(input_ids_right, attention_mask_right)['pooler_output'] 103 | 104 | output = torch.cat((output_left.unsqueeze(1), output_right.unsqueeze(1)), 1) 105 | 106 | output = F.normalize(output, dim=-1) 107 | 108 | loss = self.criterion(output, labels) 109 | 110 | return ((loss,)) 111 | 112 | class ContrastivePretrainHead(nn.Module): 113 | 114 | def __init__(self, hidden_size, proj='mlp'): 115 | super().__init__() 116 | if proj == 'linear': 117 | self.proj = nn.Linear(hidden_size, hidden_size) 118 | elif proj == 'mlp': 119 | self.proj = nn.Sequential( 120 | nn.Linear(hidden_size, hidden_size), 121 | nn.ReLU(), 122 | nn.Linear(hidden_size, hidden_size) 123 | ) 124 | 125 | def forward(self, hidden_states): 126 | x = self.proj(hidden_states) 127 | return x 128 | 129 | # cross-entropy fine-tuning model 130 | class ContrastiveClassifierModel(nn.Module): 131 | 132 | def __init__(self, len_tokenizer, checkpoint_path, model='huawei-noah/TinyBERT_General_4L_312D', pool=True, comb_fct='concat-abs-diff-mult', frozen=True, pos_neg=False): 133 | super().__init__() 134 | 135 | self.pool = pool 136 | self.frozen = frozen 137 | self.checkpoint_path = checkpoint_path 138 | self.comb_fct = comb_fct 139 | self.pos_neg = pos_neg 140 | 141 | self.encoder = BaseEncoder(len_tokenizer, model) 142 | self.config = self.encoder.transformer.config 143 | if self.pos_neg: 144 | self.criterion = BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_neg])) 145 | else: 146 | self.criterion = BCEWithLogitsLoss() 147 | self.classification_head = ClassificationHead(self.config, self.comb_fct) 148 | 149 | if self.checkpoint_path: 150 | checkpoint = torch.load(self.checkpoint_path) 151 | self.load_state_dict(checkpoint, strict=False) 152 | 153 | if self.frozen: 154 | for param in self.encoder.parameters(): 155 | param.requires_grad = False 156 | 157 | def forward(self, input_ids, attention_mask, labels, input_ids_right, attention_mask_right): 158 | 159 | if self.pool: 160 | output_left = self.encoder(input_ids, attention_mask) 161 | output_left = mean_pooling(output_left, attention_mask) 162 | 163 | output_right = self.encoder(input_ids_right, attention_mask_right) 164 | output_right = mean_pooling(output_right, attention_mask_right) 165 | else: 166 | output_left = self.encoder(input_ids, attention_mask)['pooler_output'] 167 | output_right = self.encoder(input_ids_right, attention_mask_right)['pooler_output'] 168 | 169 | if self.comb_fct == 'concat-abs-diff': 170 | output = torch.cat((output_left, output_right, torch.abs(output_left - output_right)), -1) 171 | elif self.comb_fct == 'concat-mult': 172 | output = torch.cat((output_left, output_right, output_left * output_right), -1) 173 | elif self.comb_fct == 'concat': 174 | output = torch.cat((output_left, output_right), -1) 175 | elif self.comb_fct == 'abs-diff': 176 | output = torch.abs(output_left - output_right) 177 | elif self.comb_fct == 'mult': 178 | output = output_left * output_right 179 | elif self.comb_fct == 'abs-diff-mult': 180 | output = torch.cat((torch.abs(output_left - output_right), output_left * output_right), -1) 181 | elif self.comb_fct == 'concat-abs-diff-mult': 182 | output = torch.cat((output_left, output_right, torch.abs(output_left - output_right), output_left * output_right), -1) 183 | 184 | proj_output = self.classification_head(output) 185 | 186 | loss = self.criterion(proj_output.view(-1), labels.float()) 187 | 188 | proj_output = torch.sigmoid(proj_output) 
189 | 190 | return (loss, proj_output) 191 | 192 | class ClassificationHead(nn.Module): 193 | 194 | def __init__(self, config, comb_fct): 195 | super().__init__() 196 | 197 | if comb_fct in ['concat-abs-diff', 'concat-mult']: 198 | self.hidden_size = 3 * config.hidden_size 199 | elif comb_fct in ['concat', 'abs-diff-mult']: 200 | self.hidden_size = 2 * config.hidden_size 201 | elif comb_fct in ['abs-diff', 'mult']: 202 | self.hidden_size = config.hidden_size 203 | elif comb_fct in ['concat-abs-diff-mult']: 204 | self.hidden_size = 4 * config.hidden_size 205 | 206 | classifier_dropout = config.hidden_dropout_prob 207 | 208 | self.dropout = nn.Dropout(classifier_dropout) 209 | self.out_proj = nn.Linear(self.hidden_size, 1) 210 | 211 | def forward(self, features): 212 | x = self.dropout(features) 213 | x = self.out_proj(x) 214 | return x -------------------------------------------------------------------------------- /src/contrastive/run_pretraining_deepmatcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run contrastive pre-training 3 | """ 4 | import numpy as np 5 | np.random.seed(42) 6 | import random 7 | random.seed(42) 8 | 9 | import logging 10 | import os 11 | import sys 12 | from dataclasses import dataclass, field 13 | from typing import Optional 14 | import json 15 | 16 | import torch 17 | 18 | import transformers as transformers 19 | 20 | from transformers import ( 21 | HfArgumentParser, 22 | Trainer, 23 | TrainingArguments, 24 | set_seed 25 | ) 26 | from transformers.file_utils import is_offline_mode 27 | from transformers.trainer_utils import get_last_checkpoint 28 | from transformers.utils import check_min_version 29 | from transformers.utils.versions import require_version 30 | 31 | from src.contrastive.models.modeling import ContrastivePretrainModel 32 | from src.contrastive.data.datasets import ContrastivePretrainDatasetDeepmatcher 33 | from src.contrastive.data.data_collators import DataCollatorContrastivePretrainDeepmatcher 34 | from src.contrastive.models.metrics import compute_metrics_bce 35 | 36 | from transformers import EarlyStoppingCallback 37 | 38 | from pdb import set_trace 39 | 40 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 41 | check_min_version("4.8.2") 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | MODEL_PARAMS=['pool'] 46 | 47 | @dataclass 48 | class ModelArguments: 49 | """ 50 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 51 | """ 52 | 53 | model_pretrained_checkpoint: Optional[str] = field( 54 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 55 | ) 56 | do_param_opt: Optional[bool] = field( 57 | default=False, metadata={"help": "If aou want to do hyperparamter optimization"} 58 | ) 59 | grad_checkpoint: Optional[bool] = field( 60 | default=True, metadata={"help": "If aou want to use gradient checkpointing"} 61 | ) 62 | temperature: Optional[float] = field( 63 | default=0.07, 64 | metadata={ 65 | "help": "Temperature for contrastive loss" 66 | }, 67 | ) 68 | tokenizer: Optional[str] = field( 69 | default='huawei-noah/TinyBERT_General_4L_312D', 70 | metadata={ 71 | "help": "Tokenizer to use" 72 | }, 73 | ) 74 | 75 | @dataclass 76 | class DataTrainingArguments: 77 | """ 78 | Arguments pertaining to what data we are going to input our model for training and eval. 
79 | """ 80 | 81 | train_file: Optional[str] = field( 82 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 83 | ) 84 | interm_file: Optional[str] = field( 85 | default=None, metadata={"help": "The intermediate training set."} 86 | ) 87 | clean: Optional[bool] = field( 88 | default=False, metadata={"help": "Only use intermediate training set"} 89 | ) 90 | augment: Optional[str] = field( 91 | default=None, metadata={"help": "The data augmentation to use."} 92 | ) 93 | id_deduction_set: Optional[str] = field( 94 | default=None, metadata={"help": "The size of the training set."} 95 | ) 96 | train_size: Optional[str] = field( 97 | default=None, metadata={"help": "The size of the training set."} 98 | ) 99 | max_train_samples: Optional[int] = field( 100 | default=None, 101 | metadata={ 102 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 103 | "value if set." 104 | }, 105 | ) 106 | validation_file: Optional[str] = field( 107 | default=None, 108 | metadata={ 109 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 110 | "(a jsonlines or csv file)." 111 | }, 112 | ) 113 | max_validation_samples: Optional[int] = field( 114 | default=None, 115 | metadata={ 116 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 117 | "value if set." 118 | }, 119 | ) 120 | test_file: Optional[str] = field( 121 | default=None, 122 | metadata={ 123 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 124 | }, 125 | ) 126 | max_test_samples: Optional[int] = field( 127 | default=None, 128 | metadata={ 129 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this " 130 | "value if set." 131 | }, 132 | ) 133 | dataset_name: Optional[str] = field( 134 | default='lspc', 135 | metadata={ 136 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 137 | "(a jsonlines or csv file)." 138 | }, 139 | ) 140 | def __post_init__(self): 141 | if self.train_file is None and self.validation_file is None: 142 | raise ValueError("Need a training file.") 143 | 144 | 145 | 146 | def main(): 147 | 148 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 149 | 150 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 151 | 152 | # Setup logging 153 | logging.basicConfig( 154 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 155 | datefmt="%m/%d/%Y %H:%M:%S", 156 | handlers=[logging.StreamHandler(sys.stdout)], 157 | ) 158 | log_level = training_args.get_process_log_level() 159 | logger.setLevel(log_level) 160 | transformers.utils.logging.set_verbosity(log_level) 161 | transformers.utils.logging.enable_default_handler() 162 | transformers.utils.logging.enable_explicit_format() 163 | 164 | # Log on each process the small summary: 165 | logger.warning( 166 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 167 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 168 | ) 169 | logger.info(f"Training/evaluation parameters {training_args}") 170 | 171 | # Detecting last checkpoint. 
172 | last_checkpoint = None 173 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 174 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 175 | 176 | # Set seed before initializing model. 177 | set_seed(training_args.seed) 178 | 179 | data_files = {} 180 | if data_args.train_file is not None: 181 | data_files["train"] = data_args.train_file 182 | if data_args.validation_file is not None: 183 | data_files["validation"] = data_args.validation_file 184 | if data_args.test_file is not None: 185 | data_files["test"] = data_args.test_file 186 | raw_datasets = data_files 187 | 188 | if training_args.do_train: 189 | if "train" not in raw_datasets: 190 | raise ValueError("--do_train requires a train dataset") 191 | train_dataset = raw_datasets["train"] 192 | if data_args.interm_file is not None: 193 | train_dataset = ContrastivePretrainDatasetDeepmatcher(train_dataset, tokenizer=model_args.tokenizer, intermediate_set=data_args.interm_file, clean=data_args.clean, dataset=data_args.dataset_name, deduction_set=data_args.id_deduction_set, aug=data_args.augment) 194 | else: 195 | train_dataset = ContrastivePretrainDatasetDeepmatcher(train_dataset, tokenizer=model_args.tokenizer, clean=data_args.clean, dataset=data_args.dataset_name, deduction_set=data_args.id_deduction_set, aug=data_args.augment) 196 | 197 | # Data collator 198 | data_collator = DataCollatorContrastivePretrainDeepmatcher(tokenizer=train_dataset.tokenizer) 199 | 200 | if model_args.model_pretrained_checkpoint: 201 | model = ContrastivePretrainModel(model_args.model_pretrained_checkpoint, len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, temperature=model_args.temperature) 202 | if model_args.grad_checkpoint: 203 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 204 | else: 205 | model = ContrastivePretrainModel(len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, temperature=model_args.temperature) 206 | if model_args.grad_checkpoint: 207 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 208 | 209 | # Initialize our Trainer 210 | trainer = Trainer( 211 | model=model, 212 | args=training_args, 213 | train_dataset=train_dataset if training_args.do_train else None, 214 | eval_dataset=validation_dataset if training_args.do_eval else None, 215 | data_collator=data_collator, 216 | compute_metrics=compute_metrics_bce 217 | ) 218 | trainer.args.save_total_limit = 1 219 | 220 | # Training 221 | if training_args.do_train: 222 | 223 | checkpoint = None 224 | if training_args.resume_from_checkpoint is not None: 225 | checkpoint = training_args.resume_from_checkpoint 226 | elif last_checkpoint is not None: 227 | checkpoint = last_checkpoint 228 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 229 | trainer.save_model() # Saves the tokenizer too for easy upload 230 | 231 | metrics = train_result.metrics 232 | max_train_samples = ( 233 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 234 | ) 235 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 236 | 237 | trainer.log_metrics("train", metrics) 238 | trainer.save_metrics("train", metrics) 239 | trainer.save_state() 240 | 241 | if __name__ == "__main__": 242 | main() -------------------------------------------------------------------------------- /src/contrastive/run_pretraining_deepmatcher_nosplit.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Run contrastive pre-training 3 | """ 4 | import numpy as np 5 | np.random.seed(42) 6 | import random 7 | random.seed(42) 8 | 9 | import logging 10 | import os 11 | import sys 12 | from dataclasses import dataclass, field 13 | from typing import Optional 14 | import json 15 | 16 | import torch 17 | 18 | import transformers as transformers 19 | 20 | from transformers import ( 21 | HfArgumentParser, 22 | Trainer, 23 | TrainingArguments, 24 | set_seed 25 | ) 26 | from transformers.file_utils import is_offline_mode 27 | from transformers.trainer_utils import get_last_checkpoint 28 | from transformers.utils import check_min_version 29 | from transformers.utils.versions import require_version 30 | 31 | from src.contrastive.models.modeling import ContrastivePretrainModel 32 | from src.contrastive.data.datasets import ContrastivePretrainDatasetDeepmatcher 33 | from src.contrastive.data.data_collators import DataCollatorContrastivePretrainDeepmatcher 34 | from src.contrastive.models.metrics import compute_metrics_bce 35 | 36 | from transformers import EarlyStoppingCallback 37 | 38 | from pdb import set_trace 39 | 40 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 41 | check_min_version("4.8.2") 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | MODEL_PARAMS=['pool'] 46 | 47 | @dataclass 48 | class ModelArguments: 49 | """ 50 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 51 | """ 52 | 53 | model_pretrained_checkpoint: Optional[str] = field( 54 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 55 | ) 56 | do_param_opt: Optional[bool] = field( 57 | default=False, metadata={"help": "If aou want to do hyperparamter optimization"} 58 | ) 59 | grad_checkpoint: Optional[bool] = field( 60 | default=True, metadata={"help": "If aou want to use gradient checkpointing"} 61 | ) 62 | temperature: Optional[float] = field( 63 | default=0.07, 64 | metadata={ 65 | "help": "Temperature for contrastive loss" 66 | }, 67 | ) 68 | tokenizer: Optional[str] = field( 69 | default='huawei-noah/TinyBERT_General_4L_312D', 70 | metadata={ 71 | "help": "Tokenizer to use" 72 | }, 73 | ) 74 | 75 | @dataclass 76 | class DataTrainingArguments: 77 | """ 78 | Arguments pertaining to what data we are going to input our model for training and eval. 79 | """ 80 | 81 | train_file: Optional[str] = field( 82 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 83 | ) 84 | interm_file: Optional[str] = field( 85 | default=None, metadata={"help": "The intermediate training set."} 86 | ) 87 | clean: Optional[bool] = field( 88 | default=False, metadata={"help": "Only use intermediate training set"} 89 | ) 90 | augment: Optional[str] = field( 91 | default=None, metadata={"help": "The data augmentation to use."} 92 | ) 93 | id_deduction_set: Optional[str] = field( 94 | default=None, metadata={"help": "The size of the training set."} 95 | ) 96 | train_size: Optional[str] = field( 97 | default=None, metadata={"help": "The size of the training set."} 98 | ) 99 | max_train_samples: Optional[int] = field( 100 | default=None, 101 | metadata={ 102 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 103 | "value if set." 
104 | }, 105 | ) 106 | validation_file: Optional[str] = field( 107 | default=None, 108 | metadata={ 109 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 110 | "(a jsonlines or csv file)." 111 | }, 112 | ) 113 | max_validation_samples: Optional[int] = field( 114 | default=None, 115 | metadata={ 116 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 117 | "value if set." 118 | }, 119 | ) 120 | test_file: Optional[str] = field( 121 | default=None, 122 | metadata={ 123 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 124 | }, 125 | ) 126 | max_test_samples: Optional[int] = field( 127 | default=None, 128 | metadata={ 129 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this " 130 | "value if set." 131 | }, 132 | ) 133 | dataset_name: Optional[str] = field( 134 | default='lspc', 135 | metadata={ 136 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 137 | "(a jsonlines or csv file)." 138 | }, 139 | ) 140 | def __post_init__(self): 141 | if self.train_file is None and self.validation_file is None: 142 | raise ValueError("Need a training file.") 143 | 144 | 145 | 146 | def main(): 147 | 148 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 149 | 150 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 151 | 152 | # Setup logging 153 | logging.basicConfig( 154 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 155 | datefmt="%m/%d/%Y %H:%M:%S", 156 | handlers=[logging.StreamHandler(sys.stdout)], 157 | ) 158 | log_level = training_args.get_process_log_level() 159 | logger.setLevel(log_level) 160 | transformers.utils.logging.set_verbosity(log_level) 161 | transformers.utils.logging.enable_default_handler() 162 | transformers.utils.logging.enable_explicit_format() 163 | 164 | # Log on each process the small summary: 165 | logger.warning( 166 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 167 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 168 | ) 169 | logger.info(f"Training/evaluation parameters {training_args}") 170 | 171 | # Detecting last checkpoint. 172 | last_checkpoint = None 173 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 174 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 175 | 176 | # Set seed before initializing model. 
177 | set_seed(training_args.seed) 178 | 179 | data_files = {} 180 | if data_args.train_file is not None: 181 | data_files["train"] = data_args.train_file 182 | if data_args.validation_file is not None: 183 | data_files["validation"] = data_args.validation_file 184 | if data_args.test_file is not None: 185 | data_files["test"] = data_args.test_file 186 | raw_datasets = data_files 187 | 188 | if training_args.do_train: 189 | if "train" not in raw_datasets: 190 | raise ValueError("--do_train requires a train dataset") 191 | train_dataset = raw_datasets["train"] 192 | if data_args.interm_file is not None: 193 | train_dataset = ContrastivePretrainDatasetDeepmatcher(train_dataset, tokenizer=model_args.tokenizer, intermediate_set=data_args.interm_file, clean=data_args.clean, dataset=data_args.dataset_name, deduction_set=data_args.id_deduction_set, aug=data_args.augment, split=False) 194 | else: 195 | train_dataset = ContrastivePretrainDatasetDeepmatcher(train_dataset, tokenizer=model_args.tokenizer, clean=data_args.clean, dataset=data_args.dataset_name, deduction_set=data_args.id_deduction_set, aug=data_args.augment, split=False) 196 | 197 | # Data collator 198 | data_collator = DataCollatorContrastivePretrainDeepmatcher(tokenizer=train_dataset.tokenizer) 199 | 200 | if model_args.model_pretrained_checkpoint: 201 | model = ContrastivePretrainModel(model_args.model_pretrained_checkpoint, len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, temperature=model_args.temperature) 202 | if model_args.grad_checkpoint: 203 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 204 | else: 205 | model = ContrastivePretrainModel(len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, temperature=model_args.temperature) 206 | if model_args.grad_checkpoint: 207 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 208 | 209 | # Initialize our Trainer 210 | trainer = Trainer( 211 | model=model, 212 | args=training_args, 213 | train_dataset=train_dataset if training_args.do_train else None, 214 | eval_dataset=validation_dataset if training_args.do_eval else None, 215 | data_collator=data_collator, 216 | compute_metrics=compute_metrics_bce 217 | ) 218 | trainer.args.save_total_limit = 1 219 | 220 | # Training 221 | if training_args.do_train: 222 | 223 | checkpoint = None 224 | if training_args.resume_from_checkpoint is not None: 225 | checkpoint = training_args.resume_from_checkpoint 226 | elif last_checkpoint is not None: 227 | checkpoint = last_checkpoint 228 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 229 | trainer.save_model() # Saves the tokenizer too for easy upload 230 | 231 | metrics = train_result.metrics 232 | max_train_samples = ( 233 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 234 | ) 235 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 236 | 237 | trainer.log_metrics("train", metrics) 238 | trainer.save_metrics("train", metrics) 239 | trainer.save_state() 240 | 241 | if __name__ == "__main__": 242 | main() -------------------------------------------------------------------------------- /src/contrastive/run_pretraining_ssv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run contrastive pre-training self-supervised 3 | """ 4 | import numpy as np 5 | np.random.seed(42) 6 | import random 7 | random.seed(42) 8 | 9 | import logging 10 | 
import os 11 | import sys 12 | from dataclasses import dataclass, field 13 | from typing import Optional 14 | import json 15 | 16 | import torch 17 | 18 | import transformers as transformers 19 | 20 | from transformers import ( 21 | HfArgumentParser, 22 | Trainer, 23 | TrainingArguments, 24 | set_seed 25 | ) 26 | from transformers.file_utils import is_offline_mode 27 | from transformers.trainer_utils import get_last_checkpoint 28 | from transformers.utils import check_min_version 29 | from transformers.utils.versions import require_version 30 | 31 | from src.contrastive.models.modeling import ContrastiveSelfSupervisedPretrainModel 32 | from src.contrastive.data.datasets import ContrastivePretrainDataset 33 | from src.contrastive.data.data_collators import DataCollatorContrastivePretrainSelfSupervised 34 | from src.contrastive.models.metrics import compute_metrics_bce 35 | 36 | from transformers import EarlyStoppingCallback 37 | 38 | from pdb import set_trace 39 | 40 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 41 | check_min_version("4.8.2") 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | MODEL_PARAMS=['pool'] 46 | 47 | @dataclass 48 | class ModelArguments: 49 | """ 50 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 51 | """ 52 | 53 | model_pretrained_checkpoint: Optional[str] = field( 54 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 55 | ) 56 | do_param_opt: Optional[bool] = field( 57 | default=False, metadata={"help": "If aou want to do hyperparamter optimization"} 58 | ) 59 | grad_checkpoint: Optional[bool] = field( 60 | default=True, metadata={"help": "If aou want to use gradient checkpointing"} 61 | ) 62 | temperature: Optional[float] = field( 63 | default=0.07, 64 | metadata={ 65 | "help": "Temperature for contrastive loss" 66 | }, 67 | ) 68 | tokenizer: Optional[str] = field( 69 | default='huawei-noah/TinyBERT_General_4L_312D', 70 | metadata={ 71 | "help": "Tokenizer to use" 72 | }, 73 | ) 74 | 75 | @dataclass 76 | class DataTrainingArguments: 77 | """ 78 | Arguments pertaining to what data we are going to input our model for training and eval. 79 | """ 80 | 81 | train_file: Optional[str] = field( 82 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 83 | ) 84 | interm_file: Optional[str] = field( 85 | default=None, metadata={"help": "The intermediate training set."} 86 | ) 87 | only_interm: Optional[bool] = field( 88 | default=False, metadata={"help": "Only use intermediate training set"} 89 | ) 90 | id_deduction_set: Optional[str] = field( 91 | default=None, metadata={"help": "The size of the training set."} 92 | ) 93 | augment: Optional[str] = field( 94 | default=None, metadata={"help": "The data augmentation to use."} 95 | ) 96 | train_size: Optional[str] = field( 97 | default=None, metadata={"help": "The size of the training set."} 98 | ) 99 | max_train_samples: Optional[int] = field( 100 | default=None, 101 | metadata={ 102 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 103 | "value if set." 104 | }, 105 | ) 106 | validation_file: Optional[str] = field( 107 | default=None, 108 | metadata={ 109 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 110 | "(a jsonlines or csv file)." 
111 | }, 112 | ) 113 | max_validation_samples: Optional[int] = field( 114 | default=None, 115 | metadata={ 116 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 117 | "value if set." 118 | }, 119 | ) 120 | test_file: Optional[str] = field( 121 | default=None, 122 | metadata={ 123 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 124 | }, 125 | ) 126 | max_test_samples: Optional[int] = field( 127 | default=None, 128 | metadata={ 129 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this " 130 | "value if set." 131 | }, 132 | ) 133 | dataset_name: Optional[str] = field( 134 | default='lspc', 135 | metadata={ 136 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 137 | "(a jsonlines or csv file)." 138 | }, 139 | ) 140 | def __post_init__(self): 141 | if self.train_file is None and self.validation_file is None: 142 | raise ValueError("Need a training file.") 143 | 144 | 145 | 146 | def main(): 147 | 148 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 149 | 150 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 151 | 152 | # Setup logging 153 | logging.basicConfig( 154 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 155 | datefmt="%m/%d/%Y %H:%M:%S", 156 | handlers=[logging.StreamHandler(sys.stdout)], 157 | ) 158 | log_level = training_args.get_process_log_level() 159 | logger.setLevel(log_level) 160 | transformers.utils.logging.set_verbosity(log_level) 161 | transformers.utils.logging.enable_default_handler() 162 | transformers.utils.logging.enable_explicit_format() 163 | 164 | # Log on each process the small summary: 165 | logger.warning( 166 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 167 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 168 | ) 169 | logger.info(f"Training/evaluation parameters {training_args}") 170 | 171 | # Detecting last checkpoint. 172 | last_checkpoint = None 173 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 174 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 175 | 176 | # Set seed before initializing model. 
177 | set_seed(training_args.seed) 178 | 179 | data_files = {} 180 | if data_args.train_file is not None: 181 | data_files["train"] = data_args.train_file 182 | if data_args.validation_file is not None: 183 | data_files["validation"] = data_args.validation_file 184 | if data_args.test_file is not None: 185 | data_files["test"] = data_args.test_file 186 | raw_datasets = data_files 187 | 188 | if training_args.do_train: 189 | if "train" not in raw_datasets: 190 | raise ValueError("--do_train requires a train dataset") 191 | train_dataset = raw_datasets["train"] 192 | if data_args.interm_file is not None: 193 | train_dataset = ContrastivePretrainDataset(train_dataset, tokenizer=model_args.tokenizer, intermediate_set=data_args.interm_file, only_interm=data_args.only_interm, dataset=data_args.dataset_name, deduction_set=data_args.id_deduction_set, aug=data_args.augment) 194 | else: 195 | train_dataset = ContrastivePretrainDataset(train_dataset, tokenizer=model_args.tokenizer, only_interm=data_args.only_interm, dataset=data_args.dataset_name, deduction_set=data_args.id_deduction_set, aug=data_args.augment) 196 | 197 | # Data collator 198 | data_collator = DataCollatorContrastivePretrainSelfSupervised(tokenizer=train_dataset.tokenizer) 199 | 200 | if model_args.model_pretrained_checkpoint: 201 | model = ContrastiveSelfSupervisedPretrainModel(model_args.model_pretrained_checkpoint, len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, temperature=model_args.temperature) 202 | if model_args.grad_checkpoint: 203 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 204 | else: 205 | model = ContrastiveSelfSupervisedPretrainModel(len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, temperature=model_args.temperature) 206 | if model_args.grad_checkpoint: 207 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 208 | 209 | # Initialize our Trainer 210 | trainer = Trainer( 211 | model=model, 212 | args=training_args, 213 | train_dataset=train_dataset if training_args.do_train else None, 214 | eval_dataset=validation_dataset if training_args.do_eval else None, 215 | data_collator=data_collator, 216 | compute_metrics=compute_metrics_bce 217 | ) 218 | trainer.args.save_total_limit = 1 219 | 220 | # Training 221 | if training_args.do_train: 222 | 223 | checkpoint = None 224 | if training_args.resume_from_checkpoint is not None: 225 | checkpoint = training_args.resume_from_checkpoint 226 | elif last_checkpoint is not None: 227 | checkpoint = last_checkpoint 228 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 229 | trainer.save_model() # Saves the tokenizer too for easy upload 230 | 231 | metrics = train_result.metrics 232 | max_train_samples = ( 233 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 234 | ) 235 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 236 | 237 | trainer.log_metrics("train", metrics) 238 | trainer.save_metrics("train", metrics) 239 | trainer.save_state() 240 | 241 | if __name__ == "__main__": 242 | main() -------------------------------------------------------------------------------- /src/contrastive/run_pretraining.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run contrastive pre-training 3 | """ 4 | import numpy as np 5 | np.random.seed(42) 6 | import random 7 | random.seed(42) 8 | 9 | import logging 10 | import os 11 | 
import sys 12 | from dataclasses import dataclass, field 13 | from typing import Optional 14 | import json 15 | 16 | import torch 17 | 18 | import transformers as transformers 19 | 20 | from transformers import ( 21 | HfArgumentParser, 22 | Trainer, 23 | TrainingArguments, 24 | set_seed 25 | ) 26 | from transformers.file_utils import is_offline_mode 27 | from transformers.trainer_utils import get_last_checkpoint 28 | from transformers.utils import check_min_version 29 | from transformers.utils.versions import require_version 30 | 31 | from src.contrastive.models.modeling import ContrastivePretrainModel 32 | from src.contrastive.data.datasets import ContrastivePretrainDataset 33 | from src.contrastive.data.data_collators import DataCollatorContrastivePretrain 34 | from src.contrastive.models.metrics import compute_metrics_bce 35 | 36 | from transformers import EarlyStoppingCallback 37 | 38 | from pdb import set_trace 39 | 40 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 41 | check_min_version("4.8.2") 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | MODEL_PARAMS=['pool'] 46 | 47 | @dataclass 48 | class ModelArguments: 49 | """ 50 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 51 | """ 52 | 53 | model_pretrained_checkpoint: Optional[str] = field( 54 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 55 | ) 56 | do_param_opt: Optional[bool] = field( 57 | default=False, metadata={"help": "If you want to do hyperparameter optimization"} 58 | ) 59 | grad_checkpoint: Optional[bool] = field( 60 | default=True, metadata={"help": "If you want to use gradient checkpointing"} 61 | ) 62 | temperature: Optional[float] = field( 63 | default=0.07, 64 | metadata={ 65 | "help": "Temperature for contrastive loss" 66 | }, 67 | ) 68 | tokenizer: Optional[str] = field( 69 | default='huawei-noah/TinyBERT_General_4L_312D', 70 | metadata={ 71 | "help": "Tokenizer to use" 72 | }, 73 | ) 74 | 75 | @dataclass 76 | class DataTrainingArguments: 77 | """ 78 | Arguments pertaining to what data we are going to input our model for training and eval. 79 | """ 80 | 81 | train_file: Optional[str] = field( 82 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 83 | ) 84 | interm_file: Optional[str] = field( 85 | default=None, metadata={"help": "The intermediate training set."} 86 | ) 87 | only_interm: Optional[bool] = field( 88 | default=False, metadata={"help": "Only use intermediate training set"} 89 | ) 90 | clean: Optional[bool] = field( 91 | default=False, metadata={"help": "Only use intermediate training set"} 92 | ) 93 | augment: Optional[str] = field( 94 | default=None, metadata={"help": "The data augmentation to use."} 95 | ) 96 | id_deduction_set: Optional[str] = field( 97 | default=None, metadata={"help": "The size of the training set."} 98 | ) 99 | train_size: Optional[str] = field( 100 | default=None, metadata={"help": "The size of the training set."} 101 | ) 102 | max_train_samples: Optional[int] = field( 103 | default=None, 104 | metadata={ 105 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 106 | "value if set." 107 | }, 108 | ) 109 | validation_file: Optional[str] = field( 110 | default=None, 111 | metadata={ 112 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 113 | "(a jsonlines or csv file)." 
114 | }, 115 | ) 116 | max_validation_samples: Optional[int] = field( 117 | default=None, 118 | metadata={ 119 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 120 | "value if set." 121 | }, 122 | ) 123 | test_file: Optional[str] = field( 124 | default=None, 125 | metadata={ 126 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 127 | }, 128 | ) 129 | max_test_samples: Optional[int] = field( 130 | default=None, 131 | metadata={ 132 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this " 133 | "value if set." 134 | }, 135 | ) 136 | dataset_name: Optional[str] = field( 137 | default='lspc', 138 | metadata={ 139 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 140 | "(a jsonlines or csv file)." 141 | }, 142 | ) 143 | def __post_init__(self): 144 | if self.train_file is None and self.validation_file is None: 145 | raise ValueError("Need a training file.") 146 | 147 | 148 | 149 | def main(): 150 | 151 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 152 | 153 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 154 | 155 | # Setup logging 156 | logging.basicConfig( 157 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 158 | datefmt="%m/%d/%Y %H:%M:%S", 159 | handlers=[logging.StreamHandler(sys.stdout)], 160 | ) 161 | log_level = training_args.get_process_log_level() 162 | logger.setLevel(log_level) 163 | transformers.utils.logging.set_verbosity(log_level) 164 | transformers.utils.logging.enable_default_handler() 165 | transformers.utils.logging.enable_explicit_format() 166 | 167 | # Log on each process the small summary: 168 | logger.warning( 169 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 170 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 171 | ) 172 | logger.info(f"Training/evaluation parameters {training_args}") 173 | 174 | # Detecting last checkpoint. 175 | last_checkpoint = None 176 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 177 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 178 | 179 | # Set seed before initializing model. 
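# Apart from using ContrastivePretrainModel and DataCollatorContrastivePretrain instead of the
# self-supervised variants, and passing the extra --clean flag through to ContrastivePretrainDataset,
# the body of main() below mirrors run_pretraining_ssv.py: seed the run, build the training dataset,
# construct the model (optionally from --model_pretrained_checkpoint, with gradient checkpointing),
# and train it with the Hugging Face Trainer while keeping a single checkpoint.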
180 | set_seed(training_args.seed) 181 | 182 | data_files = {} 183 | if data_args.train_file is not None: 184 | data_files["train"] = data_args.train_file 185 | if data_args.validation_file is not None: 186 | data_files["validation"] = data_args.validation_file 187 | if data_args.test_file is not None: 188 | data_files["test"] = data_args.test_file 189 | raw_datasets = data_files 190 | 191 | if training_args.do_train: 192 | if "train" not in raw_datasets: 193 | raise ValueError("--do_train requires a train dataset") 194 | train_dataset = raw_datasets["train"] 195 | if data_args.interm_file is not None: 196 | train_dataset = ContrastivePretrainDataset(train_dataset, tokenizer=model_args.tokenizer, intermediate_set=data_args.interm_file, only_interm=data_args.only_interm, clean=data_args.clean, dataset=data_args.dataset_name, deduction_set=data_args.id_deduction_set, aug=data_args.augment) 197 | else: 198 | train_dataset = ContrastivePretrainDataset(train_dataset, tokenizer=model_args.tokenizer, only_interm=data_args.only_interm, clean=data_args.clean, dataset=data_args.dataset_name, deduction_set=data_args.id_deduction_set, aug=data_args.augment) 199 | 200 | # Data collator 201 | data_collator = DataCollatorContrastivePretrain(tokenizer=train_dataset.tokenizer) 202 | 203 | if model_args.model_pretrained_checkpoint: 204 | model = ContrastivePretrainModel(model_args.model_pretrained_checkpoint, len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, temperature=model_args.temperature) 205 | if model_args.grad_checkpoint: 206 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 207 | else: 208 | model = ContrastivePretrainModel(len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, temperature=model_args.temperature) 209 | if model_args.grad_checkpoint: 210 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 211 | 212 | # Initialize our Trainer 213 | trainer = Trainer( 214 | model=model, 215 | args=training_args, 216 | train_dataset=train_dataset if training_args.do_train else None, 217 | eval_dataset=validation_dataset if training_args.do_eval else None, 218 | data_collator=data_collator, 219 | compute_metrics=compute_metrics_bce 220 | ) 221 | trainer.args.save_total_limit = 1 222 | 223 | # Training 224 | if training_args.do_train: 225 | 226 | checkpoint = None 227 | if training_args.resume_from_checkpoint is not None: 228 | checkpoint = training_args.resume_from_checkpoint 229 | elif last_checkpoint is not None: 230 | checkpoint = last_checkpoint 231 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 232 | trainer.save_model() # Saves the tokenizer too for easy upload 233 | 234 | metrics = train_result.metrics 235 | max_train_samples = ( 236 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 237 | ) 238 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 239 | 240 | trainer.log_metrics("train", metrics) 241 | trainer.save_metrics("train", metrics) 242 | trainer.save_state() 243 | 244 | if __name__ == "__main__": 245 | main() -------------------------------------------------------------------------------- /contrastive-product-matching.yml: -------------------------------------------------------------------------------- 1 | name: contrastive-product-matching 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - 
alabaster=0.7.12=pyhd3eb1b0_0 10 | - alembic=1.6.5=pyhd8ed1ab_0 11 | - anaconda=2021.04=py38_0 12 | - anaconda-client=1.7.2=py38_0 13 | - anaconda-project=0.9.1=pyhd3eb1b0_1 14 | - anyio=2.2.0=py38h06a4308_1 15 | - appdirs=1.4.4=py_0 16 | - argh=0.26.2=py38_0 17 | - argon2-cffi=20.1.0=py38h27cfd23_1 18 | - asn1crypto=1.4.0=py_0 19 | - astroid=2.5=py38h06a4308_1 20 | - astropy=4.2.1=py38h27cfd23_1 21 | - async_generator=1.10=pyhd3eb1b0_0 22 | - atomicwrites=1.4.0=py_0 23 | - attrs=20.3.0=pyhd3eb1b0_0 24 | - autopep8=1.5.6=pyhd3eb1b0_0 25 | - babel=2.9.0=pyhd3eb1b0_0 26 | - backcall=0.2.0=pyhd3eb1b0_0 27 | - backports=1.0=pyhd3eb1b0_2 28 | - backports.shutil_get_terminal_size=1.0.0=pyhd3eb1b0_3 29 | - beautifulsoup4=4.9.3=pyha847dfd_0 30 | - bitarray=1.9.2=py38h27cfd23_1 31 | - bkcharts=0.2=py38_0 32 | - black=19.10b0=py_0 33 | - blas=1.0=mkl 34 | - bleach=3.3.0=pyhd3eb1b0_0 35 | - blosc=1.21.0=h8c45485_0 36 | - bokeh=2.3.1=py38h06a4308_0 37 | - boto=2.49.0=py38_0 38 | - bottleneck=1.3.2=py38heb32a55_1 39 | - brotlipy=0.7.0=py38h27cfd23_1003 40 | - bzip2=1.0.8=h7b6447c_0 41 | - c-ares=1.17.1=h27cfd23_0 42 | - ca-certificates=2021.4.13=h06a4308_1 43 | - cairo=1.16.0=hf32fb01_1 44 | - certifi=2020.12.5=py38h06a4308_0 45 | - cffi=1.14.5=py38h261ae71_0 46 | - chardet=4.0.0=py38h06a4308_1003 47 | - click=7.1.2=pyhd3eb1b0_0 48 | - cliff=3.8.0=pyhd8ed1ab_0 49 | - cloudpickle=1.6.0=py_0 50 | - clyent=1.2.2=py38_1 51 | - cmaes=0.8.2=pyh44b312d_0 52 | - cmd2=2.1.2=py38h578d9bd_0 53 | - colorama=0.4.4=pyhd3eb1b0_0 54 | - colorlog=5.0.1=py38h578d9bd_0 55 | - contextlib2=0.6.0.post1=py_0 56 | - cryptography=3.4.7=py38hd23ed53_0 57 | - cudatoolkit=11.1.74=h6bb024c_0 58 | - curl=7.71.1=hbc83047_1 59 | - cycler=0.10.0=py38_0 60 | - cython=0.29.23=py38h2531618_0 61 | - cytoolz=0.11.0=py38h7b6447c_0 62 | - dask=2021.4.0=pyhd3eb1b0_0 63 | - dask-core=2021.4.0=pyhd3eb1b0_0 64 | - dbus=1.13.18=hb2f20db_0 65 | - decorator=5.0.6=pyhd3eb1b0_0 66 | - defusedxml=0.7.1=pyhd3eb1b0_0 67 | - diff-match-patch=20200713=py_0 68 | - distributed=2021.4.0=py38h06a4308_0 69 | - docutils=0.17=py38h06a4308_1 70 | - entrypoints=0.3=py38_0 71 | - et_xmlfile=1.0.1=py_1001 72 | - expat=2.3.0=h2531618_2 73 | - fastcache=1.1.0=py38h7b6447c_0 74 | - ffmpeg=4.3=hf484d3e_0 75 | - filelock=3.0.12=pyhd3eb1b0_1 76 | - flake8=3.9.0=pyhd3eb1b0_0 77 | - flask=1.1.2=pyhd3eb1b0_0 78 | - fontconfig=2.13.1=h6c09931_0 79 | - freetype=2.10.4=h5ab3b9f_0 80 | - fribidi=1.0.10=h7b6447c_0 81 | - fsspec=0.9.0=pyhd3eb1b0_0 82 | - future=0.18.2=py38_1 83 | - get_terminal_size=1.0.0=haa9412d_0 84 | - gevent=21.1.2=py38h27cfd23_1 85 | - glib=2.68.1=h36276a3_0 86 | - glob2=0.7=pyhd3eb1b0_0 87 | - gmp=6.2.1=h2531618_2 88 | - gmpy2=2.0.8=py38hd5f6e3b_3 89 | - gnutls=3.6.15=he1e5248_0 90 | - graphite2=1.3.14=h23475e2_0 91 | - greenlet=1.0.0=py38h2531618_2 92 | - gst-plugins-base=1.14.0=h8213a91_2 93 | - gstreamer=1.14.0=h28cd5cc_2 94 | - h5py=2.10.0=py38h7918eee_0 95 | - harfbuzz=2.8.0=h6f93f22_0 96 | - hdf5=1.10.4=hb1b8bf9_0 97 | - heapdict=1.0.1=py_0 98 | - html5lib=1.1=py_0 99 | - icu=58.2=he6710b0_3 100 | - idna=2.10=pyhd3eb1b0_0 101 | - imageio=2.9.0=pyhd3eb1b0_0 102 | - imagesize=1.2.0=pyhd3eb1b0_0 103 | - importlib-metadata=3.10.0=py38h06a4308_0 104 | - importlib_metadata=3.10.0=hd3eb1b0_0 105 | - iniconfig=1.1.1=pyhd3eb1b0_0 106 | - intel-openmp=2021.2.0=h06a4308_610 107 | - intervaltree=3.1.0=py_0 108 | - ipykernel=5.3.4=py38h5ca1d4c_0 109 | - ipython=7.22.0=py38hb070fc8_0 110 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 111 | - ipywidgets=7.6.3=pyhd3eb1b0_1 
112 | - isort=5.8.0=pyhd3eb1b0_0 113 | - itsdangerous=1.1.0=pyhd3eb1b0_0 114 | - jbig=2.1=hdba287a_0 115 | - jdcal=1.4.1=py_0 116 | - jedi=0.17.2=py38h06a4308_1 117 | - jeepney=0.6.0=pyhd3eb1b0_0 118 | - jinja2=2.11.3=pyhd3eb1b0_0 119 | - joblib=1.0.1=pyhd3eb1b0_0 120 | - jpeg=9b=h024ee3a_2 121 | - json5=0.9.5=py_0 122 | - jsonschema=3.2.0=py_2 123 | - jupyter=1.0.0=py38_7 124 | - jupyter-packaging=0.7.12=pyhd3eb1b0_0 125 | - jupyter_client=6.1.12=pyhd3eb1b0_0 126 | - jupyter_console=6.4.0=pyhd3eb1b0_0 127 | - jupyter_core=4.7.1=py38h06a4308_0 128 | - jupyter_server=1.4.1=py38h06a4308_0 129 | - jupyterlab=3.0.14=pyhd3eb1b0_1 130 | - jupyterlab_pygments=0.1.2=py_0 131 | - jupyterlab_server=2.4.0=pyhd3eb1b0_0 132 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 133 | - keyring=22.3.0=py38h06a4308_0 134 | - kiwisolver=1.3.1=py38h2531618_0 135 | - krb5=1.18.2=h173b8e3_0 136 | - lame=3.100=h7b6447c_0 137 | - lazy-object-proxy=1.6.0=py38h27cfd23_0 138 | - lcms2=2.12=h3be6417_0 139 | - ld_impl_linux-64=2.33.1=h53a641e_7 140 | - libarchive=3.4.2=h62408e4_0 141 | - libcurl=7.71.1=h20c2e04_1 142 | - libedit=3.1.20210216=h27cfd23_1 143 | - libev=4.33=h7b6447c_0 144 | - libffi=3.3=he6710b0_2 145 | - libgcc-ng=9.1.0=hdf63c60_0 146 | - libgfortran-ng=7.3.0=hdf63c60_0 147 | - libiconv=1.15=h63c8f33_5 148 | - libidn2=2.3.1=h27cfd23_0 149 | - liblief=0.10.1=he6710b0_0 150 | - libllvm10=10.0.1=hbcb73fb_5 151 | - libpng=1.6.37=hbc83047_0 152 | - libsodium=1.0.18=h7b6447c_0 153 | - libspatialindex=1.9.3=h2531618_0 154 | - libssh2=1.9.0=h1ba5d50_1 155 | - libstdcxx-ng=9.1.0=hdf63c60_0 156 | - libtasn1=4.16.0=h27cfd23_0 157 | - libtiff=4.2.0=h85742a9_0 158 | - libtool=2.4.6=h7b6447c_1005 159 | - libunistring=0.9.10=h27cfd23_0 160 | - libuuid=1.0.3=h1bed415_2 161 | - libuv=1.40.0=h7b6447c_0 162 | - libwebp-base=1.2.0=h27cfd23_0 163 | - libxcb=1.14=h7b6447c_0 164 | - libxml2=2.9.10=hb55368b_3 165 | - libxslt=1.1.34=hc22bd24_0 166 | - llvmlite=0.36.0=py38h612dafd_4 167 | - locket=0.2.1=py38h06a4308_1 168 | - lxml=4.6.3=py38h9120a33_0 169 | - lz4-c=1.9.3=h2531618_0 170 | - lzo=2.10=h7b6447c_2 171 | - mako=1.1.4=pyh44b312d_0 172 | - markupsafe=1.1.1=py38h7b6447c_0 173 | - matplotlib=3.3.4=py38h06a4308_0 174 | - matplotlib-base=3.3.4=py38h62a2d02_0 175 | - mccabe=0.6.1=py38_1 176 | - mistune=0.8.4=py38h7b6447c_1000 177 | - mkl=2021.2.0=h06a4308_296 178 | - mkl-service=2.3.0=py38h27cfd23_1 179 | - mkl_fft=1.3.0=py38h42c9631_2 180 | - mkl_random=1.2.1=py38ha9443f7_2 181 | - mock=4.0.3=pyhd3eb1b0_0 182 | - more-itertools=8.7.0=pyhd3eb1b0_0 183 | - mpc=1.1.0=h10f8cd9_1 184 | - mpfr=4.0.2=hb69a4c5_1 185 | - mpmath=1.2.1=py38h06a4308_0 186 | - msgpack-python=1.0.2=py38hff7bd54_1 187 | - multipledispatch=0.6.0=py38_0 188 | - mypy_extensions=0.4.3=py38_0 189 | - nbclassic=0.2.6=pyhd3eb1b0_0 190 | - nbclient=0.5.3=pyhd3eb1b0_0 191 | - nbconvert=6.0.7=py38_0 192 | - nbformat=5.1.3=pyhd3eb1b0_0 193 | - ncurses=6.2=he6710b0_1 194 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 195 | - nettle=3.7.3=hbbd107a_1 196 | - networkx=2.5=py_0 197 | - ninja=1.10.2=hff7bd54_1 198 | - nltk=3.6.1=pyhd3eb1b0_0 199 | - nose=1.3.7=pyhd3eb1b0_1006 200 | - notebook=6.3.0=py38h06a4308_0 201 | - numexpr=2.7.3=py38h22e1b3c_1 202 | - numpy=1.20.1=py38h93e21f0_0 203 | - numpy-base=1.20.1=py38h7d8b39e_0 204 | - numpydoc=1.1.0=pyhd3eb1b0_1 205 | - olefile=0.46=py_0 206 | - openh264=2.1.0=hd408876_0 207 | - openpyxl=3.0.7=pyhd3eb1b0_0 208 | - openssl=1.1.1k=h27cfd23_0 209 | - optuna=2.8.0=pyhd8ed1ab_0 210 | - packaging=20.9=pyhd3eb1b0_0 211 | - pandas=1.2.4=py38h2531618_0 
212 | - pandoc=2.12=h06a4308_0 213 | - pandocfilters=1.4.3=py38h06a4308_1 214 | - pango=1.45.3=hd140c19_0 215 | - parso=0.7.0=py_0 216 | - partd=1.2.0=pyhd3eb1b0_0 217 | - patchelf=0.12=h2531618_1 218 | - path=15.1.2=py38h06a4308_0 219 | - path.py=12.5.0=0 220 | - pathlib2=2.3.5=py38h06a4308_2 221 | - pathspec=0.7.0=py_0 222 | - patsy=0.5.1=py38_0 223 | - pbr=5.6.0=pyhd8ed1ab_0 224 | - pcre=8.44=he6710b0_0 225 | - pep8=1.7.1=py38_0 226 | - pexpect=4.8.0=pyhd3eb1b0_3 227 | - pickleshare=0.7.5=pyhd3eb1b0_1003 228 | - pillow=8.2.0=py38he98fc37_0 229 | - pip=21.0.1=py38h06a4308_0 230 | - pixman=0.40.0=h7b6447c_0 231 | - pkginfo=1.7.0=py38h06a4308_0 232 | - pluggy=0.13.1=py38h06a4308_0 233 | - ply=3.11=py38_0 234 | - prettytable=2.1.0=pyhd8ed1ab_0 235 | - prometheus_client=0.10.1=pyhd3eb1b0_0 236 | - prompt-toolkit=3.0.17=pyh06a4308_0 237 | - prompt_toolkit=3.0.17=hd3eb1b0_0 238 | - psutil=5.8.0=py38h27cfd23_1 239 | - ptyprocess=0.7.0=pyhd3eb1b0_2 240 | - py=1.10.0=pyhd3eb1b0_0 241 | - py-lief=0.10.1=py38h403a769_0 242 | - pycodestyle=2.6.0=pyhd3eb1b0_0 243 | - pycosat=0.6.3=py38h7b6447c_1 244 | - pycparser=2.20=py_2 245 | - pycurl=7.43.0.6=py38h1ba5d50_0 246 | - pydocstyle=6.0.0=pyhd3eb1b0_0 247 | - pyerfa=1.7.3=py38h27cfd23_0 248 | - pyflakes=2.2.0=pyhd3eb1b0_0 249 | - pygments=2.8.1=pyhd3eb1b0_0 250 | - pylint=2.7.4=py38h06a4308_1 251 | - pyls-black=0.4.6=hd3eb1b0_0 252 | - pyls-spyder=0.3.2=pyhd3eb1b0_0 253 | - pyodbc=4.0.30=py38he6710b0_0 254 | - pyopenssl=20.0.1=pyhd3eb1b0_1 255 | - pyparsing=2.4.7=pyhd3eb1b0_0 256 | - pyperclip=1.8.2=pyhd8ed1ab_2 257 | - pyqt=5.9.2=py38h05f1152_4 258 | - pyrsistent=0.17.3=py38h7b6447c_0 259 | - pysocks=1.7.1=py38h06a4308_0 260 | - pytables=3.6.1=py38h9fd0a39_0 261 | - pytest=6.2.3=py38h06a4308_2 262 | - python=3.8.8=hdb3f193_5 263 | - python-dateutil=2.8.1=pyhd3eb1b0_0 264 | - python-editor=1.0.4=py_0 265 | - python-jsonrpc-server=0.4.0=py_0 266 | - python-language-server=0.36.2=pyhd3eb1b0_0 267 | - python-libarchive-c=2.9=pyhd3eb1b0_1 268 | - python_abi=3.8=2_cp38 269 | - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0 270 | - pytz=2021.1=pyhd3eb1b0_0 271 | - pywavelets=1.1.1=py38h7b6447c_2 272 | - pyxdg=0.27=pyhd3eb1b0_0 273 | - pyyaml=5.4.1=py38h27cfd23_1 274 | - pyzmq=20.0.0=py38h2531618_1 275 | - qdarkstyle=2.8.1=py_0 276 | - qt=5.9.7=h5867ecd_1 277 | - qtawesome=1.0.2=pyhd3eb1b0_0 278 | - qtconsole=5.0.3=pyhd3eb1b0_0 279 | - qtpy=1.9.0=py_0 280 | - readline=8.1=h27cfd23_0 281 | - regex=2021.4.4=py38h27cfd23_0 282 | - requests=2.25.1=pyhd3eb1b0_0 283 | - ripgrep=12.1.1=0 284 | - rope=0.18.0=py_0 285 | - rtree=0.9.7=py38h06a4308_1 286 | - ruamel_yaml=0.15.100=py38h27cfd23_0 287 | - scikit-image=0.18.1=py38ha9443f7_0 288 | - scikit-learn=0.24.1=py38ha9443f7_0 289 | - scipy=1.6.2=py38had2a1c9_1 290 | - seaborn=0.11.1=pyhd3eb1b0_0 291 | - secretstorage=3.3.1=py38h06a4308_0 292 | - send2trash=1.5.0=pyhd3eb1b0_1 293 | - setuptools=52.0.0=py38h06a4308_0 294 | - simplegeneric=0.8.1=py38_2 295 | - singledispatch=3.6.1=pyhd3eb1b0_1001 296 | - sip=4.19.13=py38he6710b0_0 297 | - six=1.15.0=py38h06a4308_0 298 | - sniffio=1.2.0=py38h06a4308_1 299 | - snowballstemmer=2.1.0=pyhd3eb1b0_0 300 | - sortedcollections=2.1.0=pyhd3eb1b0_0 301 | - sortedcontainers=2.3.0=pyhd3eb1b0_0 302 | - soupsieve=2.2.1=pyhd3eb1b0_0 303 | - sphinx=3.5.3=pyhd3eb1b0_0 304 | - sphinxcontrib=1.0=py38_1 305 | - sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0 306 | - sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0 307 | - sphinxcontrib-htmlhelp=1.0.3=pyhd3eb1b0_0 308 | - sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0 
309 | - sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0 310 | - sphinxcontrib-serializinghtml=1.1.4=pyhd3eb1b0_0 311 | - sphinxcontrib-websupport=1.2.4=py_0 312 | - spyder=4.2.5=py38h06a4308_0 313 | - spyder-kernels=1.10.2=py38h06a4308_0 314 | - sqlalchemy=1.4.7=py38h27cfd23_0 315 | - sqlite=3.35.4=hdfb4753_0 316 | - statsmodels=0.12.2=py38h27cfd23_0 317 | - stevedore=3.3.0=py38h578d9bd_1 318 | - sympy=1.8=py38h06a4308_0 319 | - tblib=1.7.0=py_0 320 | - terminado=0.9.4=py38h06a4308_0 321 | - testpath=0.4.4=pyhd3eb1b0_0 322 | - textdistance=4.2.1=pyhd3eb1b0_0 323 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 324 | - three-merge=0.1.1=pyhd3eb1b0_0 325 | - tifffile=2020.10.1=py38hdd07704_2 326 | - tk=8.6.10=hbc83047_0 327 | - toml=0.10.2=pyhd3eb1b0_0 328 | - toolz=0.11.1=pyhd3eb1b0_0 329 | - torchaudio=0.9.0=py38 330 | - torchvision=0.10.0=py38_cu111 331 | - tornado=6.1=py38h27cfd23_0 332 | - tqdm=4.59.0=pyhd3eb1b0_1 333 | - traitlets=5.0.5=pyhd3eb1b0_0 334 | - typed-ast=1.4.2=py38h27cfd23_1 335 | - typing_extensions=3.7.4.3=pyha847dfd_0 336 | - ujson=4.0.2=py38h2531618_0 337 | - unicodecsv=0.14.1=py38_0 338 | - unixodbc=2.3.9=h7b6447c_0 339 | - urllib3=1.26.4=pyhd3eb1b0_0 340 | - watchdog=1.0.2=py38h06a4308_1 341 | - wcwidth=0.2.5=py_0 342 | - webencodings=0.5.1=py38_1 343 | - werkzeug=1.0.1=pyhd3eb1b0_0 344 | - wheel=0.36.2=pyhd3eb1b0_0 345 | - widgetsnbextension=3.5.1=py38_0 346 | - wrapt=1.12.1=py38h7b6447c_1 347 | - wurlitzer=2.1.0=py38h06a4308_0 348 | - xlrd=2.0.1=pyhd3eb1b0_0 349 | - xlsxwriter=1.3.8=pyhd3eb1b0_0 350 | - xlwt=1.3.0=py38_0 351 | - xz=5.2.5=h7b6447c_0 352 | - yaml=0.2.5=h7b6447c_0 353 | - yapf=0.31.0=pyhd3eb1b0_0 354 | - zeromq=4.3.4=h2531618_0 355 | - zict=2.0.0=pyhd3eb1b0_0 356 | - zipp=3.4.1=pyhd3eb1b0_0 357 | - zlib=1.2.11=h7b6447c_3 358 | - zope=1.0=py38_1 359 | - zope.event=4.5.0=py38_0 360 | - zope.interface=5.3.0=py38h27cfd23_0 361 | - zstd=1.4.5=h9ceee32_0 362 | - pip: 363 | - absl-py==0.13.0 364 | - aiohttp==3.7.4.post0 365 | - aiohttp-cors==0.7.0 366 | - aioredis==1.3.1 367 | - async-timeout==3.0.1 368 | - autograd==1.3 369 | - axial-attention==0.5.0 370 | - blessings==1.7 371 | - cachetools==4.2.2 372 | - cma==2.7.0 373 | - google-api-core==1.31.0 374 | - google-auth==1.32.0 375 | - google-auth-oauthlib==0.4.4 376 | - googleapis-common-protos==1.53.0 377 | - gpustat==0.6.0 378 | - gpy==1.10.0 379 | - gpytorch==1.5.0 380 | - grpcio==1.38.1 381 | - hebo==0.1.0 382 | - hiredis==2.0.0 383 | - huggingface-hub==0.1.2 384 | - markdown==3.3.4 385 | - multidict==5.1.0 386 | - nlpaug==1.1.9 387 | - nvidia-ml-py3==7.352.0 388 | - oauthlib==3.1.1 389 | - opencensus==0.7.13 390 | - opencensus-context==0.1.2 391 | - paramz==0.9.5 392 | - protobuf==3.17.3 393 | - py-spy==0.3.7 394 | - pyasn1==0.4.8 395 | - pyasn1-modules==0.2.8 396 | - pydantic==1.8.2 397 | - pymoo==0.4.2.2 398 | - ray==1.4.1 399 | - redis==3.5.3 400 | - requests-oauthlib==1.3.0 401 | - rsa==4.7.2 402 | - sacremoses==0.0.45 403 | - scikit-multilearn==0.2.0 404 | - tabulate==0.8.9 405 | - tensorboard==2.5.0 406 | - tensorboard-data-server==0.6.1 407 | - tensorboard-plugin-wit==1.8.0 408 | - tensorboardx==2.4 409 | - tokenizers==0.10.3 410 | - transformers==4.12.5 411 | - yarl==1.6.3 412 | prefix: /home/rpeeters/anaconda3/envs/contrastive-product-matching 413 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | aiohttp==3.7.4.post0 3 | aiohttp-cors==0.7.0 4 | aioredis==1.3.1 
5 | alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work 6 | alembic @ file:///home/conda/feedstock_root/build_artifacts/alembic_1622150326904/work 7 | anaconda-client==1.7.2 8 | anaconda-project @ file:///tmp/build/80754af9/anaconda-project_1610472525955/work 9 | anyio @ file:///tmp/build/80754af9/anyio_1617783275907/work/dist 10 | appdirs==1.4.4 11 | argh==0.26.2 12 | argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037097816/work 13 | asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work 14 | astroid @ file:///tmp/build/80754af9/astroid_1613500854201/work 15 | astropy @ file:///tmp/build/80754af9/astropy_1617745353437/work 16 | async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work 17 | async-timeout==3.0.1 18 | atomicwrites==1.4.0 19 | attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work 20 | autograd==1.3 21 | autopep8 @ file:///tmp/build/80754af9/autopep8_1615918855173/work 22 | axial-attention==0.5.0 23 | Babel @ file:///tmp/build/80754af9/babel_1607110387436/work 24 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work 25 | backports.shutil-get-terminal-size @ file:///tmp/build/80754af9/backports.shutil_get_terminal_size_1608222128777/work 26 | beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work 27 | bitarray @ file:///tmp/build/80754af9/bitarray_1618431750766/work 28 | bkcharts==0.2 29 | black==19.10b0 30 | bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work 31 | blessings==1.7 32 | bokeh @ file:///tmp/build/80754af9/bokeh_1617824541184/work 33 | boto==2.49.0 34 | Bottleneck==1.3.2 35 | brotlipy==0.7.0 36 | cachetools==4.2.2 37 | certifi==2020.12.5 38 | cffi @ file:///tmp/build/80754af9/cffi_1613246945912/work 39 | chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work 40 | click @ file:///home/linux1/recipes/ci/click_1610990599742/work 41 | cliff @ file:///home/conda/feedstock_root/build_artifacts/cliff_1622119770880/work 42 | cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work 43 | clyent==1.2.2 44 | cma==2.7.0 45 | cmaes @ file:///home/conda/feedstock_root/build_artifacts/cmaes_1613785714721/work 46 | cmd2 @ file:///home/conda/feedstock_root/build_artifacts/cmd2_1625509081997/work 47 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work 48 | colorlog==5.0.1 49 | contextlib2==0.6.0.post1 50 | cryptography @ file:///tmp/build/80754af9/cryptography_1616769286105/work 51 | cycler==0.10.0 52 | Cython @ file:///tmp/build/80754af9/cython_1618435160151/work 53 | cytoolz==0.11.0 54 | dask @ file:///tmp/build/80754af9/dask-core_1617390489108/work 55 | decorator @ file:///tmp/build/80754af9/decorator_1617916966915/work 56 | defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work 57 | diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work 58 | distributed @ file:///tmp/build/80754af9/distributed_1617381497899/work 59 | docutils @ file:///tmp/build/80754af9/docutils_1617624660125/work 60 | entrypoints==0.3 61 | et-xmlfile==1.0.1 62 | fastcache==1.1.0 63 | filelock @ file:///home/linux1/recipes/ci/filelock_1610993975404/work 64 | flake8 @ file:///tmp/build/80754af9/flake8_1615834841867/work 65 | Flask @ file:///home/ktietz/src/ci/flask_1611932660458/work 66 | fsspec @ file:///tmp/build/80754af9/fsspec_1617959894824/work 67 | future==0.18.2 68 | gevent @ file:///tmp/build/80754af9/gevent_1616770671827/work 69 | glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work 70 | 
gmpy2==2.0.8 71 | google-api-core==1.31.0 72 | google-auth==1.32.0 73 | google-auth-oauthlib==0.4.4 74 | googleapis-common-protos==1.53.0 75 | gpustat==0.6.0 76 | GPy==1.10.0 77 | gpytorch==1.5.0 78 | greenlet @ file:///tmp/build/80754af9/greenlet_1611957705398/work 79 | grpcio==1.38.1 80 | h5py==2.10.0 81 | HeapDict==1.0.1 82 | HEBO==0.1.0 83 | hiredis==2.0.0 84 | html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work 85 | huggingface-hub==0.1.2 86 | idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work 87 | imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work 88 | imagesize @ file:///home/ktietz/src/ci/imagesize_1611921604382/work 89 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617874469820/work 90 | iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work 91 | intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work 92 | ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl 93 | ipython @ file:///tmp/build/80754af9/ipython_1617120885885/work 94 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work 95 | ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work 96 | isort @ file:///tmp/build/80754af9/isort_1616355431277/work 97 | itsdangerous @ file:///home/ktietz/src/ci/itsdangerous_1611932585308/work 98 | jdcal==1.4.1 99 | jedi @ file:///tmp/build/80754af9/jedi_1606932564285/work 100 | jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work 101 | Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work 102 | joblib @ file:///tmp/build/80754af9/joblib_1613502643832/work 103 | json5==0.9.5 104 | jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work 105 | jupyter==1.0.0 106 | jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work 107 | jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work 108 | jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213311222/work 109 | jupyter-packaging @ file:///tmp/build/80754af9/jupyter-packaging_1613502826984/work 110 | jupyter-server @ file:///tmp/build/80754af9/jupyter_server_1616083640759/work 111 | jupyterlab @ file:///tmp/build/80754af9/jupyterlab_1619133235951/work 112 | jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work 113 | jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1617134334258/work 114 | jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work 115 | keyring @ file:///tmp/build/80754af9/keyring_1614616740399/work 116 | kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work 117 | lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1616526917483/work 118 | libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work 119 | llvmlite==0.36.0 120 | locket==0.2.1 121 | lxml @ file:///tmp/build/80754af9/lxml_1616443220220/work 122 | Mako @ file:///home/conda/feedstock_root/build_artifacts/mako_1610659158978/work 123 | Markdown==3.3.4 124 | MarkupSafe==1.1.1 125 | matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1613407855456/work 126 | mccabe==0.6.1 127 | mistune==0.8.4 128 | mkl-fft==1.3.0 129 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853849286/work 130 | mkl-service==2.3.0 131 | mock @ file:///tmp/build/80754af9/mock_1607622725907/work 132 | more-itertools @ file:///tmp/build/80754af9/more-itertools_1613676688952/work 133 | 
mpmath==1.2.1 134 | msgpack @ file:///tmp/build/80754af9/msgpack-python_1612287151062/work 135 | multidict==5.1.0 136 | multipledispatch==0.6.0 137 | mypy-extensions==0.4.3 138 | nbclassic @ file:///tmp/build/80754af9/nbclassic_1616085367084/work 139 | nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work 140 | nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914830498/work 141 | nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work 142 | nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work 143 | networkx @ file:///tmp/build/80754af9/networkx_1598376031484/work 144 | nlpaug==1.1.9 145 | nltk @ file:///tmp/build/80754af9/nltk_1618327084230/work 146 | nose @ file:///tmp/build/80754af9/nose_1606773131901/work 147 | notebook @ file:///tmp/build/80754af9/notebook_1616443462982/work 148 | numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work 149 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1618497241363/work 150 | numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work 151 | nvidia-ml-py3==7.352.0 152 | oauthlib==3.1.1 153 | olefile==0.46 154 | opencensus==0.7.13 155 | opencensus-context==0.1.2 156 | openpyxl @ file:///tmp/build/80754af9/openpyxl_1615411699337/work 157 | optuna @ file:///home/conda/feedstock_root/build_artifacts/optuna_1623058031662/work 158 | packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work 159 | pandas==1.2.4 160 | pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work 161 | paramz==0.9.5 162 | parso==0.7.0 163 | partd @ file:///tmp/build/80754af9/partd_1618000087440/work 164 | path @ file:///tmp/build/80754af9/path_1614022220526/work 165 | pathlib2 @ file:///tmp/build/80754af9/pathlib2_1607024983162/work 166 | pathspec==0.7.0 167 | patsy==0.5.1 168 | pbr @ file:///home/conda/feedstock_root/build_artifacts/pbr_1619460527081/work 169 | pep8==1.7.1 170 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 171 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 172 | Pillow @ file:///tmp/build/80754af9/pillow_1617383569452/work 173 | pkginfo==1.7.0 174 | pluggy @ file:///tmp/build/80754af9/pluggy_1615976321666/work 175 | ply==3.11 176 | prettytable @ file:///home/conda/feedstock_root/build_artifacts/prettytable_1614725168556/work 177 | prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1618088486455/work 178 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work 179 | protobuf==3.17.3 180 | psutil @ file:///tmp/build/80754af9/psutil_1612298023621/work 181 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 182 | py @ file:///tmp/build/80754af9/py_1607971587848/work 183 | py-spy==0.3.7 184 | pyasn1==0.4.8 185 | pyasn1-modules==0.2.8 186 | pycodestyle @ file:///home/ktietz/src/ci_mi/pycodestyle_1612807597675/work 187 | pycosat==0.6.3 188 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work 189 | pycurl==7.43.0.6 190 | pydantic==1.8.2 191 | pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1616182067796/work 192 | pyerfa @ file:///tmp/build/80754af9/pyerfa_1619390903914/work 193 | pyflakes @ file:///home/ktietz/src/ci_ipy2/pyflakes_1612551159640/work 194 | Pygments @ file:///tmp/build/80754af9/pygments_1615143339740/work 195 | pylint @ file:///tmp/build/80754af9/pylint_1617135829881/work 196 | pyls-black @ file:///tmp/build/80754af9/pyls-black_1607553132291/work 197 | pyls-spyder @ 
file:///tmp/build/80754af9/pyls-spyder_1613849700860/work 198 | pymoo==0.4.2.2 199 | pyodbc===4.0.0-unsupported 200 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work 201 | pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work 202 | pyperclip @ file:///home/conda/feedstock_root/build_artifacts/pyperclip_1622337600177/work 203 | pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work 204 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work 205 | pytest==6.2.3 206 | python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work 207 | python-editor==1.0.4 208 | python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work 209 | python-language-server @ file:///tmp/build/80754af9/python-language-server_1607972495879/work 210 | pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work 211 | PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work 212 | pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work 213 | PyYAML==5.4.1 214 | pyzmq==20.0.0 215 | QDarkStyle==2.8.1 216 | QtAwesome @ file:///tmp/build/80754af9/qtawesome_1615991616277/work 217 | qtconsole @ file:///tmp/build/80754af9/qtconsole_1616775094278/work 218 | QtPy==1.9.0 219 | ray==1.4.1 220 | redis==3.5.3 221 | regex @ file:///tmp/build/80754af9/regex_1617569202463/work 222 | requests @ file:///tmp/build/80754af9/requests_1608241421344/work 223 | requests-oauthlib==1.3.0 224 | rope @ file:///tmp/build/80754af9/rope_1602264064449/work 225 | rsa==4.7.2 226 | Rtree @ file:///tmp/build/80754af9/rtree_1618420845272/work 227 | ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016699510/work 228 | sacremoses==0.0.45 229 | scikit-image==0.18.1 230 | scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1614446682169/work 231 | scikit-multilearn==0.2.0 232 | scipy @ file:///tmp/build/80754af9/scipy_1618855647378/work 233 | seaborn @ file:///tmp/build/80754af9/seaborn_1608578541026/work 234 | SecretStorage @ file:///tmp/build/80754af9/secretstorage_1614022784285/work 235 | Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work 236 | simplegeneric==0.8.1 237 | singledispatch @ file:///tmp/build/80754af9/singledispatch_1614366001199/work 238 | sip==4.19.13 239 | six @ file:///tmp/build/80754af9/six_1605205327372/work 240 | sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work 241 | snowballstemmer @ file:///tmp/build/80754af9/snowballstemmer_1611258885636/work 242 | sortedcollections @ file:///tmp/build/80754af9/sortedcollections_1611172717284/work 243 | sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1606865132123/work 244 | soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work 245 | Sphinx @ file:///tmp/build/80754af9/sphinx_1616268783226/work 246 | sphinxcontrib-applehelp @ file:///home/ktietz/src/ci/sphinxcontrib-applehelp_1611920841464/work 247 | sphinxcontrib-devhelp @ file:///home/ktietz/src/ci/sphinxcontrib-devhelp_1611920923094/work 248 | sphinxcontrib-htmlhelp @ file:///home/ktietz/src/ci/sphinxcontrib-htmlhelp_1611920974801/work 249 | sphinxcontrib-jsmath @ file:///home/ktietz/src/ci/sphinxcontrib-jsmath_1611920942228/work 250 | sphinxcontrib-qthelp @ file:///home/ktietz/src/ci/sphinxcontrib-qthelp_1611921055322/work 251 | sphinxcontrib-serializinghtml @ file:///home/ktietz/src/ci/sphinxcontrib-serializinghtml_1611920755253/work 252 | sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work 253 | 
spyder @ file:///tmp/build/80754af9/spyder_1616775618138/work 254 | spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1614030590686/work 255 | SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1618089170652/work 256 | -e git+https://github.com/Weyoun2211/di-research.git@748d15129f0fb1c3c4ba3b91f9ee1133a411892c#egg=src 257 | statsmodels @ file:///tmp/build/80754af9/statsmodels_1614023746358/work 258 | stevedore @ file:///home/conda/feedstock_root/build_artifacts/stevedore_1610093939235/work 259 | sympy @ file:///tmp/build/80754af9/sympy_1618252284338/work 260 | tables==3.6.1 261 | tabulate==0.8.9 262 | tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work 263 | tensorboard==2.5.0 264 | tensorboard-data-server==0.6.1 265 | tensorboard-plugin-wit==1.8.0 266 | tensorboardX==2.4 267 | terminado==0.9.4 268 | testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work 269 | textdistance @ file:///tmp/build/80754af9/textdistance_1612461398012/work 270 | threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl 271 | three-merge @ file:///tmp/build/80754af9/three-merge_1607553261110/work 272 | tifffile==2020.10.1 273 | tokenizers==0.10.3 274 | toml @ file:///tmp/build/80754af9/toml_1616166611790/work 275 | toolz @ file:///home/linux1/recipes/ci/toolz_1610987900194/work 276 | torch==1.9.0 277 | torchaudio==0.9.0a0+33b2469 278 | torchvision==0.10.0 279 | tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work 280 | tqdm @ file:///tmp/build/80754af9/tqdm_1615925068909/work 281 | traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work 282 | transformers==4.12.5 283 | typed-ast @ file:///tmp/build/80754af9/typed-ast_1610484547928/work 284 | typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work 285 | ujson @ file:///tmp/build/80754af9/ujson_1611259522456/work 286 | unicodecsv==0.14.1 287 | urllib3 @ file:///tmp/build/80754af9/urllib3_1615837158687/work 288 | watchdog @ file:///tmp/build/80754af9/watchdog_1612471027849/work 289 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work 290 | webencodings==0.5.1 291 | Werkzeug @ file:///home/ktietz/src/ci/werkzeug_1611932622770/work 292 | widgetsnbextension==3.5.1 293 | wrapt==1.12.1 294 | wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1617224664226/work 295 | xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work 296 | XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1617224712951/work 297 | xlwt==1.3.0 298 | yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work 299 | yarl==1.6.3 300 | zict==2.0.0 301 | zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work 302 | zope.event==4.5.0 303 | zope.interface @ file:///tmp/build/80754af9/zope.interface_1616357211867/work 304 | -------------------------------------------------------------------------------- /src/contrastive/run_finetune_siamese.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run column type annotation fine-tuning 3 | """ 4 | import numpy as np 5 | np.random.seed(42) 6 | import random 7 | random.seed(42) 8 | 9 | import pandas as pd 10 | from sklearn.metrics import classification_report 11 | 12 | import logging 13 | import os 14 | import sys 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | import json 18 | 19 | from copy import deepcopy 20 | 21 | import torch 22 | 23 | import transformers as transformers 24 | 25 | from transformers import ( 26 | HfArgumentParser, 27 | Trainer, 28 | TrainingArguments, 29 
| set_seed 30 | ) 31 | from transformers.file_utils import is_offline_mode 32 | from transformers.trainer_utils import get_last_checkpoint 33 | from transformers.utils import check_min_version 34 | from transformers.utils.versions import require_version 35 | 36 | from src.contrastive.models.modeling import ContrastiveClassifierModel 37 | from src.contrastive.data.datasets import ContrastiveClassificationDataset 38 | from src.contrastive.data.data_collators import DataCollatorContrastiveClassification 39 | from src.contrastive.models.metrics import compute_metrics_bce 40 | 41 | from transformers import EarlyStoppingCallback 42 | 43 | from transformers.utils.hp_naming import TrialShortNamer 44 | 45 | from pdb import set_trace 46 | 47 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk. 48 | check_min_version("4.8.2") 49 | 50 | logger = logging.getLogger(__name__) 51 | 52 | @dataclass 53 | class ModelArguments: 54 | """ 55 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 56 | """ 57 | 58 | model_pretrained_checkpoint: Optional[str] = field( 59 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 60 | ) 61 | do_param_opt: Optional[bool] = field( 62 | default=False, metadata={"help": "If you want to do hyperparameter optimization"} 63 | ) 64 | frozen: Optional[bool] = field( 65 | default=True, metadata={"help": "If encoder params should be frozen"} 66 | ) 67 | grad_checkpoint: Optional[bool] = field( 68 | default=True, metadata={"help": "If you want to use gradient checkpointing"} 69 | ) 70 | tokenizer: Optional[str] = field( 71 | default='huawei-noah/TinyBERT_General_4L_312D', 72 | metadata={ 73 | "help": "Tokenizer to use" 74 | }, 75 | ) 76 | 77 | @dataclass 78 | class DataTrainingArguments: 79 | """ 80 | Arguments pertaining to what data we are going to input our model for training and eval. 81 | """ 82 | 83 | train_file: Optional[str] = field( 84 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 85 | ) 86 | train_size: Optional[str] = field( 87 | default=None, metadata={"help": "The size of the training set."} 88 | ) 89 | max_train_samples: Optional[int] = field( 90 | default=None, 91 | metadata={ 92 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 93 | "value if set." 94 | }, 95 | ) 96 | augment: Optional[str] = field( 97 | default=None, metadata={"help": "The data augmentation to use."} 98 | ) 99 | validation_file: Optional[str] = field( 100 | default=None, 101 | metadata={ 102 | "help": "An optional input evaluation data file to evaluate the metrics on " 103 | "(a jsonlines or csv file)." 104 | }, 105 | ) 106 | max_validation_samples: Optional[int] = field( 107 | default=None, 108 | metadata={ 109 | "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " 110 | "value if set." 111 | }, 112 | ) 113 | test_file: Optional[str] = field( 114 | default=None, 115 | metadata={ 116 | "help": "An optional input test data file to evaluate the metrics on " "(a jsonlines or csv file)." 117 | }, 118 | ) 119 | max_test_samples: Optional[int] = field( 120 | default=None, 121 | metadata={ 122 | "help": "For debugging purposes or quicker training, truncate the number of test examples to this " 123 | "value if set."
124 | }, 125 | ) 126 | dataset_name: Optional[str] = field( 127 | default='lspc', 128 | metadata={ 129 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 130 | "(a jsonlines or csv file)." 131 | }, 132 | ) 133 | def __post_init__(self): 134 | if self.train_file is None and self.validation_file is None: 135 | raise ValueError("Need a training file.") 136 | 137 | 138 | 139 | def main(): 140 | 141 | def get_posneg(): 142 | if data_args.dataset_name == 'amazon-google' or data_args.dataset_name == 'abt-buy': 143 | return 9 144 | elif data_args.dataset_name == 'walmart-amazon': 145 | return 10 146 | else: 147 | if data_args.train_size == 'small': 148 | return 3 149 | elif data_args.train_size == 'medium': 150 | return 4 151 | elif data_args.train_size == 'large': 152 | return 5 153 | elif data_args.train_size == 'xlarge': 154 | return 6 155 | 156 | def model_init(trial): 157 | init_args = {} 158 | pos_neg = get_posneg() 159 | if model_args.model_pretrained_checkpoint: 160 | my_model = ContrastiveClassifierModel(checkpoint_path=model_args.model_pretrained_checkpoint, len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, frozen=model_args.frozen, pos_neg=pos_neg, **init_args) 161 | if model_args.grad_checkpoint: 162 | my_model.encoder.transformer._set_gradient_checkpointing(my_model.encoder.transformer.encoder, True) 163 | return my_model 164 | else: 165 | my_model = ContrastiveClassifierModel(len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, frozen=model_args.frozen, pos_neg=pos_neg, **init_args) 166 | if model_args.grad_checkpoint: 167 | my_model.encoder.transformer._set_gradient_checkpointing(my_model.encoder.transformer.encoder, True) 168 | return my_model 169 | 170 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 171 | 172 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 173 | 174 | # Setup logging 175 | logging.basicConfig( 176 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 177 | datefmt="%m/%d/%Y %H:%M:%S", 178 | handlers=[logging.StreamHandler(sys.stdout)], 179 | ) 180 | log_level = training_args.get_process_log_level() 181 | logger.setLevel(log_level) 182 | transformers.utils.logging.set_verbosity(log_level) 183 | transformers.utils.logging.enable_default_handler() 184 | transformers.utils.logging.enable_explicit_format() 185 | 186 | # Log on each process the small summary: 187 | logger.warning( 188 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 189 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 190 | ) 191 | logger.info(f"Training/evaluation parameters {training_args}") 192 | 193 | # Set seed before initializing model. 
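# The code below seeds the run, builds ContrastiveClassificationDataset objects for whichever of
# the train/validation/test splits the --do_train/--do_eval/--do_predict flags require, sets up
# DataCollatorContrastiveClassification and an early-stopping callback, optionally runs a Ray Tune
# hyperparameter search (HEBO, 24 trials, maximizing eval F1), and then fine-tunes three seeds
# (0, 1, 2), writing each run into its own numbered output directory.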
194 | set_seed(training_args.seed) 195 | 196 | data_files = {} 197 | if data_args.train_file is not None: 198 | data_files["train"] = data_args.train_file 199 | if data_args.validation_file is not None: 200 | data_files["validation"] = data_args.validation_file 201 | if data_args.test_file is not None: 202 | data_files["test"] = data_args.test_file 203 | raw_datasets = data_files 204 | 205 | if training_args.do_train: 206 | if "train" not in raw_datasets: 207 | raise ValueError("--do_train requires a train dataset") 208 | train_dataset = raw_datasets["train"] 209 | train_dataset = ContrastiveClassificationDataset(train_dataset, dataset_type='train', size=data_args.train_size, tokenizer=model_args.tokenizer, dataset=data_args.dataset_name, aug=data_args.augment) 210 | if training_args.evaluation_strategy != 'no': 211 | validation_dataset = raw_datasets["validation"] 212 | validation_dataset = ContrastiveClassificationDataset(validation_dataset, dataset_type='validation', size=data_args.train_size, tokenizer=model_args.tokenizer, dataset=data_args.dataset_name) 213 | if training_args.load_best_model_at_end: 214 | test_dataset = raw_datasets["test"] 215 | test_dataset = ContrastiveClassificationDataset(test_dataset, dataset_type='test', size=data_args.train_size, tokenizer=model_args.tokenizer, dataset=data_args.dataset_name) 216 | 217 | elif training_args.do_eval: 218 | if "validation" not in raw_datasets: 219 | raise ValueError("--do_eval requires a validation dataset") 220 | validation_dataset = raw_datasets["validation"] 221 | validation_dataset = ContrastiveClassificationDataset(validation_dataset, dataset_type='validation', size=data_args.train_size, tokenizer=model_args.tokenizer, dataset=data_args.dataset_name) 222 | 223 | elif training_args.do_predict: 224 | if "test" not in raw_datasets: 225 | raise ValueError("--do_predict requires a test dataset") 226 | test_dataset = raw_datasets["test"] 227 | test_dataset = ContrastiveClassificationDataset(test_dataset, dataset_type='test', size=data_args.train_size, tokenizer=model_args.tokenizer, dataset=data_args.dataset_name) 228 | 229 | # Data collator 230 | data_collator = DataCollatorContrastiveClassification(tokenizer=train_dataset.tokenizer) 231 | 232 | # Early stopping callback 233 | callback = EarlyStoppingCallback(early_stopping_patience=10) 234 | 235 | if training_args.do_train and model_args.do_param_opt: 236 | 237 | from ray import tune 238 | def my_hp_space(trial): 239 | return { 240 | "learning_rate": tune.loguniform(5e-5, 5e-3), 241 | "warmup_ratio": tune.choice([0.05, 0.075, 0.10]), 242 | "max_grad_norm": tune.choice([0.0, 1.0]), 243 | "weight_decay": tune.loguniform(0.001, 0.1), 244 | "seed": tune.randint(1, 50) 245 | } 246 | 247 | def my_objective(metrics): 248 | return metrics['eval_f1'] 249 | 250 | trainer = Trainer( 251 | model_init=model_init, 252 | args=training_args, 253 | train_dataset=train_dataset if training_args.do_train else None, 254 | eval_dataset=validation_dataset if training_args.do_eval else None, 255 | data_collator=data_collator, 256 | compute_metrics=compute_metrics_bce, 257 | callbacks=[callback] 258 | ) 259 | trainer.args.save_total_limit = 1 260 | 261 | def hp_name(trial): 262 | namer = TrialShortNamer() 263 | namer.set_defaults('hp', {'learning_rate': 1e-4, 'warmup_ratio': 0.0, 'max_grad_norm': 1.0, 'weight_decay': 0.01, 'seed':1}) 264 | return namer.shortname(trial) 265 | 266 | initial_configs = [ 267 | { 268 | "learning_rate": 1e-3, 269 | "warmup_ratio": 0.05, 270 | "max_grad_norm": 1.0, 271 | 
"weight_decay": 0.01, 272 | "seed": 42 273 | }, 274 | { 275 | "learning_rate": 1e-4, 276 | "warmup_ratio": 0.05, 277 | "max_grad_norm": 1.0, 278 | "weight_decay": 0.01, 279 | "seed": 42 280 | } 281 | ] 282 | 283 | from ray.tune.suggest.hebo import HEBOSearch 284 | hebo = HEBOSearch(metric="eval_f1", mode="max", points_to_evaluate=initial_configs, random_state_seed=42) 285 | 286 | best_run = trainer.hyperparameter_search(n_trials=24, direction="maximize", hp_space=my_hp_space, compute_objective=my_objective, backend='ray', 287 | resources_per_trial={'cpu':4,'gpu':1}, local_dir=f'{training_args.output_dir}ray_results/', hp_name=hp_name, search_alg=hebo) 288 | 289 | with open(f'{training_args.output_dir}best_params.json', 'w') as f: 290 | json.dump(best_run, f) 291 | 292 | output_dir = deepcopy(training_args.output_dir) 293 | for run in range(3): 294 | init_args = {} 295 | 296 | training_args.save_total_limit = 1 297 | training_args.seed = run 298 | training_args.output_dir = f'{output_dir}{run}' 299 | # if model_args.do_param_opt: 300 | # init_args = {k:v for k, v in best_run.hyperparameters.items() if k in MODEL_PARAMS} 301 | 302 | 303 | # Detecting last checkpoint. 304 | last_checkpoint = None 305 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 306 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 307 | 308 | pos_neg = get_posneg() 309 | if model_args.model_pretrained_checkpoint: 310 | model = ContrastiveClassifierModel(checkpoint_path=model_args.model_pretrained_checkpoint, len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, frozen=model_args.frozen, pos_neg=pos_neg, **init_args) 311 | if model_args.grad_checkpoint: 312 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 313 | else: 314 | model = ContrastiveClassifierModel(len_tokenizer=len(train_dataset.tokenizer), model=model_args.tokenizer, frozen=model_args.frozen, pos_neg=pos_neg, **init_args) 315 | if model_args.grad_checkpoint: 316 | model.encoder.transformer._set_gradient_checkpointing(model.encoder.transformer.encoder, True) 317 | 318 | # Initialize our Trainer 319 | trainer = Trainer( 320 | model=model, 321 | args=training_args, 322 | train_dataset=train_dataset if training_args.do_train else None, 323 | eval_dataset=validation_dataset if training_args.do_eval else None, 324 | data_collator=data_collator, 325 | compute_metrics=compute_metrics_bce, 326 | callbacks=[callback] 327 | ) 328 | 329 | # Training 330 | if training_args.do_train: 331 | if model_args.do_param_opt: 332 | for n, v in best_run.hyperparameters.items(): 333 | setattr(trainer.args, n, v) 334 | # if n not in MODEL_PARAMS: 335 | # setattr(trainer.args, n, v) 336 | 337 | checkpoint = None 338 | if training_args.resume_from_checkpoint is not None: 339 | checkpoint = training_args.resume_from_checkpoint 340 | elif last_checkpoint is not None: 341 | checkpoint = last_checkpoint 342 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 343 | trainer.save_model() # Saves the tokenizer too for easy upload 344 | 345 | metrics = train_result.metrics 346 | max_train_samples = ( 347 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 348 | ) 349 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 350 | 351 | trainer.log_metrics(f"train", metrics) 352 | trainer.save_metrics(f"train", metrics) 353 | trainer.save_state() 354 | 355 | # Evaluation 356 | 
results = {} 357 | if training_args.do_eval: 358 | logger.info("*** Evaluate ***") 359 | 360 | metrics = trainer.evaluate( 361 | metric_key_prefix="eval" 362 | ) 363 | max_eval_samples = len(validation_dataset) 364 | metrics["eval_samples"] = max_eval_samples 365 | 366 | trainer.log_metrics(f"eval", metrics) 367 | trainer.save_metrics(f"eval", metrics) 368 | 369 | if training_args.do_predict or training_args.do_train: 370 | logger.info("*** Predict ***") 371 | 372 | predict_results = trainer.predict( 373 | test_dataset, 374 | metric_key_prefix="predict" 375 | ) 376 | 377 | metrics = predict_results.metrics 378 | max_predict_samples = len(test_dataset) 379 | metrics["predict_samples"] = max_predict_samples 380 | 381 | trainer.log_metrics(f"predict", metrics) 382 | trainer.save_metrics(f"predict", metrics) 383 | 384 | 385 | return results 386 | 387 | if __name__ == "__main__": 388 | main() -------------------------------------------------------------------------------- /src/contrastive/data/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.random.seed(42) 3 | import random 4 | random.seed(42) 5 | 6 | import pandas as pd 7 | 8 | from pathlib import Path 9 | import glob 10 | import gzip 11 | import pickle 12 | from copy import deepcopy 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from transformers import AutoTokenizer, AutoConfig 18 | 19 | import nlpaug.augmenter.word as naw 20 | import nlpaug.augmenter.char as nac 21 | from sklearn.preprocessing import LabelEncoder 22 | 23 | from pdb import set_trace 24 | 25 | def assign_clusterid(identifier, cluster_id_dict, cluster_id_amount): 26 | try: 27 | result = cluster_id_dict[identifier] 28 | except KeyError: 29 | result = cluster_id_amount 30 | return result 31 | 32 | 33 | # Methods for serializing examples by dataset 34 | def serialize_sample_lspc(sample): 35 | string = '' 36 | string = f'{string}[COL] brand [VAL] {" ".join(sample[f"brand"].split(" ")[:5])}'.strip() 37 | string = f'{string} [COL] title [VAL] {" ".join(sample[f"title"].split(" ")[:50])}'.strip() 38 | string = f'{string} [COL] description [VAL] {" ".join(sample[f"description"].split(" ")[:100])}'.strip() 39 | string = f'{string} [COL] specTableContent [VAL] {" ".join(sample[f"specTableContent"].split(" ")[:200])}'.strip() 40 | 41 | return string 42 | 43 | def serialize_sample_abtbuy(sample): 44 | string = '' 45 | string = f'{string}[COL] brand [VAL] {" ".join(sample[f"brand"].split())}'.strip() 46 | string = f'{string} [COL] title [VAL] {" ".join(sample[f"name"].split())}'.strip() 47 | string = f'{string} [COL] price [VAL] {" ".join(str(sample[f"price"]).split())}'.strip() 48 | string = f'{string} [COL] description [VAL] {" ".join(sample[f"description"].split()[:100])}'.strip() 49 | 50 | return string 51 | 52 | def serialize_sample_amazongoogle(sample): 53 | string = '' 54 | string = f'{string}[COL] brand [VAL] {" ".join(sample[f"manufacturer"].split())}'.strip() 55 | string = f'{string} [COL] title [VAL] {" ".join(sample[f"title"].split())}'.strip() 56 | string = f'{string} [COL] price [VAL] {" ".join(str(sample[f"price"]).split())}'.strip() 57 | string = f'{string} [COL] description [VAL] {" ".join(sample[f"description"].split()[:100])}'.strip() 58 | 59 | return string 60 | 61 | # Class for Data Augmentation 62 | class Augmenter(): 63 | def __init__(self, aug): 64 | 65 | stopwords = ['[COL]', '[VAL]', 'title', 'name', 'description', 'manufacturer', 'brand', 'specTableContent'] 66 | 67 
| aug_typo = nac.KeyboardAug(stopwords=stopwords, aug_char_p=0.1, aug_word_p=0.1) 68 | aug_swap = naw.RandomWordAug(action="swap", stopwords=stopwords, aug_p=0.1) 69 | aug_del = naw.RandomWordAug(action="delete", stopwords=stopwords, aug_p=0.1) 70 | aug_crop = naw.RandomWordAug(action="crop", stopwords=stopwords, aug_p=0.1) 71 | aug_sub = naw.RandomWordAug(action="substitute", stopwords=stopwords, aug_p=0.1) 72 | aug_split = naw.SplitAug(stopwords=stopwords, aug_p=0.1) 73 | 74 | aug = aug.strip('-') 75 | 76 | if aug == 'all': 77 | self.augs = [aug_typo, aug_swap, aug_split, aug_sub, aug_del, aug_crop, None] 78 | 79 | if aug == 'typo': 80 | self.augs = [aug_typo, None] 81 | 82 | if aug == 'swap': 83 | self.augs = [aug_swap, None] 84 | 85 | if aug == 'delete': 86 | self.augs = [aug_del, None] 87 | 88 | if aug == 'crop': 89 | self.augs = [aug_crop, None] 90 | 91 | if aug == 'substitute': 92 | self.augs = [aug_sub, None] 93 | 94 | if aug == 'split': 95 | self.augs = [aug_split, None] 96 | 97 | def apply_aug(self, string): 98 | aug = random.choice(self.augs) 99 | if aug is None: 100 | return string 101 | else: 102 | return aug.augment(string) 103 | 104 | # Dataset class for general Contrastive Pre-training for WDC computers 105 | class ContrastivePretrainDataset(torch.utils.data.Dataset): 106 | def __init__(self, path, deduction_set, tokenizer='huawei-noah/TinyBERT_General_4L_312D', max_length=128, intermediate_set=None, clean=False, dataset='lspc', only_interm=False, aug=False): 107 | 108 | self.max_length = max_length 109 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, additional_special_tokens=('[COL]', '[VAL]')) 110 | self.dataset = dataset 111 | self.aug = aug 112 | 113 | if self.aug: 114 | self.augmenter = Augmenter(self.aug) 115 | 116 | data = pd.read_pickle(path) 117 | 118 | if dataset == 'abt-buy': 119 | data['brand'] = '' 120 | 121 | if dataset == 'amazon-google': 122 | data['description'] = '' 123 | 124 | if intermediate_set is not None: 125 | interm_data = pd.read_pickle(intermediate_set) 126 | if only_interm: 127 | data = interm_data 128 | else: 129 | data = data.append(interm_data) 130 | 131 | data = data.reset_index(drop=True) 132 | 133 | data = data.fillna('') 134 | data = self._prepare_data(data) 135 | 136 | self.data = data 137 | 138 | 139 | def __getitem__(self, idx): 140 | # for every example in batch, sample one positive from the dataset 141 | example = self.data.loc[idx].copy() 142 | selection = self.data[self.data['labels'] == example['labels']] 143 | # if len(selection) > 1: 144 | # selection = selection.drop(idx) 145 | pos = selection.sample(1).iloc[0].copy() 146 | 147 | # apply augmentation if set 148 | if self.aug: 149 | example['features'] = self.augmenter.apply_aug(example['features']) 150 | pos['features'] = self.augmenter.apply_aug(pos['features']) 151 | 152 | return (example, pos) 153 | 154 | def __len__(self): 155 | return len(self.data) 156 | 157 | def _prepare_data(self, data): 158 | 159 | if self.dataset == 'lspc': 160 | data['features'] = data.apply(serialize_sample_lspc, axis=1) 161 | 162 | elif self.dataset == 'abt-buy': 163 | data['features'] = data.apply(serialize_sample_abtbuy, axis=1) 164 | 165 | elif self.dataset == 'amazon-google': 166 | data['features'] = data.apply(serialize_sample_amazongoogle, axis=1) 167 | 168 | label_enc = LabelEncoder() 169 | data['labels'] = label_enc.fit_transform(data['cluster_id']) 170 | 171 | self.label_encoder = label_enc 172 | 173 | data = data[['features', 'labels']] 174 | 175 | return data 176 | 177 | # 
Dataset class for Contrastive Pre-training for Abt-Buy and Amazon-Google 178 | # builds correspondence graph from train+val and builds source-aware sampling datasets 179 | # if split=False, corresponds to not using source-aware sampling 180 | class ContrastivePretrainDatasetDeepmatcher(torch.utils.data.Dataset): 181 | def __init__(self, path, deduction_set, tokenizer='huawei-noah/TinyBERT_General_4L_312D', max_length=128, intermediate_set=None, clean=False, dataset='abt-buy', aug=False, split=True): 182 | 183 | self.max_length = max_length 184 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, additional_special_tokens=('[COL]', '[VAL]')) 185 | self.dataset = dataset 186 | self.aug = aug 187 | 188 | if self.aug: 189 | self.augmenter = Augmenter(self.aug) 190 | 191 | data = pd.read_pickle(path) 192 | 193 | if dataset == 'abt-buy': 194 | data['brand'] = '' 195 | 196 | if dataset == 'amazon-google': 197 | data['description'] = '' 198 | 199 | if clean: 200 | train_data = pd.read_json(deduction_set, lines=True) 201 | 202 | if dataset == 'abt-buy': 203 | val = pd.read_csv('../../data/interim/abt-buy/abt-buy-valid.csv') 204 | elif dataset == 'amazon-google': 205 | val = pd.read_csv('../../data/interim/amazon-google/amazon-google-valid.csv') 206 | 207 | # use 80% of train and val set positives to build correspondence graph 208 | val_set = train_data[train_data['pair_id'].isin(val['pair_id'])] 209 | val_set_pos = val_set[val_set['label'] == 1] 210 | val_set_pos = val_set_pos.sample(frac=0.80) 211 | val_ids = set() 212 | val_ids.update(val_set['pair_id']) 213 | 214 | train_data = train_data[~train_data['pair_id'].isin(val_ids)] 215 | train_data = train_data[train_data['label'] == 1] 216 | train_data = train_data.sample(frac=0.80) 217 | 218 | train_data = train_data.append(val_set_pos) 219 | 220 | # build the connected components by applying binning 221 | bucket_list = [] 222 | for i, row in train_data.iterrows(): 223 | left = f'{row["id_left"]}' 224 | right = f'{row["id_right"]}' 225 | found = False 226 | for bucket in bucket_list: 227 | if left in bucket and row['label'] == 1: 228 | bucket.add(right) 229 | found = True 230 | break 231 | elif right in bucket and row['label'] == 1: 232 | bucket.add(left) 233 | found = True 234 | break 235 | if not found: 236 | bucket_list.append(set([left, right])) 237 | 238 | cluster_id_amount = len(bucket_list) 239 | 240 | #assign labels to connected components and single nodes (at this point single nodes have same label) 241 | cluster_id_dict = {} 242 | for i, id_set in enumerate(bucket_list): 243 | for v in id_set: 244 | cluster_id_dict[v] = i 245 | data = data.set_index('id', drop=False) 246 | data['cluster_id'] = data['id'].apply(assign_clusterid, args=(cluster_id_dict, cluster_id_amount)) 247 | #data = data[data['cluster_id'] != cluster_id_amount] 248 | 249 | single_entities = data[data['cluster_id'] == cluster_id_amount].copy() 250 | 251 | index = single_entities.index 252 | 253 | if dataset == 'abt-buy': 254 | left_index = [x for x in index if 'abt' in x] 255 | right_index = [x for x in index if 'buy' in x] 256 | elif dataset == 'amazon-google': 257 | left_index = [x for x in index if 'amazon' in x] 258 | right_index = [x for x in index if 'google' in x] 259 | 260 | # assing increasing integer label to single nodes 261 | single_entities = single_entities.reset_index(drop=True) 262 | single_entities['cluster_id'] = single_entities['cluster_id'] + single_entities.index 263 | single_entities = single_entities.set_index('id', drop=False) 264 | 
single_entities_left = single_entities.loc[left_index] 265 | single_entities_right = single_entities.loc[right_index] 266 | 267 | # source aware sampling, build one sample per dataset 268 | if split: 269 | data1 = data.copy().drop(single_entities['id']) 270 | data1 = data1.append(single_entities_left) 271 | 272 | data2 = data.copy().drop(single_entities['id']) 273 | data2 = data2.append(single_entities_right) 274 | 275 | else: 276 | data1 = data.copy().drop(single_entities['id']) 277 | data1 = data1.append(single_entities_left) 278 | data1 = data1.append(single_entities_right) 279 | 280 | data2 = data.copy().drop(single_entities['id']) 281 | data2 = data2.append(single_entities_left) 282 | data2 = data2.append(single_entities_right) 283 | 284 | if intermediate_set is not None: 285 | interm_data = pd.read_pickle(intermediate_set) 286 | if dataset != 'lspc': 287 | cols = data.columns 288 | if 'name' in cols: 289 | interm_data = interm_data.rename(columns={'title':'name'}) 290 | if 'manufacturer' in cols: 291 | interm_data = interm_data.rename(columns={'brand':'manufacturer'}) 292 | interm_data['cluster_id'] = interm_data['cluster_id']+10000 293 | 294 | data1 = data1.append(interm_data) 295 | data2 = data2.append(interm_data) 296 | 297 | data1 = data1.reset_index(drop=True) 298 | data2 = data2.reset_index(drop=True) 299 | 300 | label_enc = LabelEncoder() 301 | cluster_id_set = set() 302 | cluster_id_set.update(data1['cluster_id']) 303 | cluster_id_set.update(data2['cluster_id']) 304 | label_enc.fit(list(cluster_id_set)) 305 | data1['labels'] = label_enc.transform(data1['cluster_id']) 306 | data2['labels'] = label_enc.transform(data2['cluster_id']) 307 | 308 | self.label_encoder = label_enc 309 | 310 | data1 = data1.reset_index(drop=True) 311 | 312 | data1 = data1.fillna('') 313 | data1 = self._prepare_data(data1) 314 | 315 | data2 = data2.reset_index(drop=True) 316 | 317 | data2 = data2.fillna('') 318 | data2 = self._prepare_data(data2) 319 | 320 | diff = abs(len(data1)-len(data2)) 321 | 322 | if len(data1) > len(data2): 323 | if len(data2) < diff: 324 | sample = data2.sample(diff, replace=True) 325 | else: 326 | sample = data2.sample(diff) 327 | data2 = data2.append(sample) 328 | data2 = data2.reset_index(drop=True) 329 | 330 | elif len(data2) > len(data1): 331 | if len(data1) < diff: 332 | sample = data1.sample(diff, replace=True) 333 | else: 334 | sample = data1.sample(diff) 335 | data1 = data1.append(sample) 336 | data1 = data1.reset_index(drop=True) 337 | 338 | self.data1 = data1 339 | self.data2 = data2 340 | 341 | def __getitem__(self, idx): 342 | # for every example, sample one positive from the respective sampling dataset 343 | example1 = self.data1.loc[idx].copy() 344 | selection1 = self.data1[self.data1['labels'] == example1['labels']] 345 | # if len(selection1) > 1: 346 | # selection1 = selection1.drop(idx) 347 | pos1 = selection1.sample(1).iloc[0].copy() 348 | 349 | example2 = self.data2.loc[idx].copy() 350 | selection2 = self.data2[self.data2['labels'] == example2['labels']] 351 | # if len(selection2) > 1: 352 | # selection2 = selection2.drop(idx) 353 | pos2 = selection2.sample(1).iloc[0].copy() 354 | 355 | # apply augmentation if set 356 | if self.aug: 357 | example1['features'] = self.augmenter.apply_aug(example1['features']) 358 | pos1['features'] = self.augmenter.apply_aug(pos1['features']) 359 | example2['features'] = self.augmenter.apply_aug(example2['features']) 360 | pos2['features'] = self.augmenter.apply_aug(pos2['features']) 361 | 362 | return ((example1, pos1), 
(example2, pos2)) 363 | 364 | def __len__(self): 365 | return len(self.data1) 366 | 367 | def _prepare_data(self, data): 368 | 369 | if self.dataset == 'lspc': 370 | data['features'] = data.apply(serialize_sample_lspc, axis=1) 371 | 372 | elif self.dataset == 'abt-buy': 373 | data['features'] = data.apply(serialize_sample_abtbuy, axis=1) 374 | 375 | elif self.dataset == 'amazon-google': 376 | data['features'] = data.apply(serialize_sample_amazongoogle, axis=1) 377 | 378 | data = data[['features', 'labels']] 379 | 380 | return data 381 | 382 | # Dataset class for pair-wise cross-entropy fine-tuning 383 | class ContrastiveClassificationDataset(torch.utils.data.Dataset): 384 | def __init__(self, path, dataset_type, size=None, tokenizer='huawei-noah/TinyBERT_General_4L_312D', max_length=128, dataset='lspc', aug=False): 385 | 386 | self.max_length = max_length 387 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, additional_special_tokens=('[COL]', '[VAL]')) 388 | self.dataset_type = dataset_type 389 | self.dataset = dataset 390 | self.aug = aug 391 | 392 | if self.aug: 393 | self.augmenter = Augmenter(self.aug) 394 | 395 | if dataset == 'lspc': 396 | data = pd.read_pickle(path) 397 | else: 398 | data = pd.read_json(path, lines=True) 399 | 400 | if dataset == 'abt-buy': 401 | data['brand_left'] = '' 402 | data['brand_right'] = '' 403 | 404 | if dataset == 'amazon-google': 405 | data['description_left'] = '' 406 | data['description_right'] = '' 407 | 408 | data = data.fillna('') 409 | 410 | if self.dataset_type != 'test': 411 | if dataset == 'lspc': 412 | validation_ids = pd.read_csv(f'../../data/raw/wdc-lspc/validation-sets/computers_valid_{size}.csv') 413 | elif dataset == 'abt-buy': 414 | validation_ids = pd.read_csv(f'../../data/interim/abt-buy/abt-buy-valid.csv') 415 | elif dataset == 'amazon-google': 416 | validation_ids = pd.read_csv(f'../../data/interim/amazon-google/amazon-google-valid.csv') 417 | if self.dataset_type == 'train': 418 | data = data[~data['pair_id'].isin(validation_ids['pair_id'])] 419 | else: 420 | data = data[data['pair_id'].isin(validation_ids['pair_id'])] 421 | 422 | data = data.reset_index(drop=True) 423 | 424 | data = self._prepare_data(data) 425 | 426 | self.data = data 427 | 428 | 429 | def __getitem__(self, idx): 430 | example = self.data.loc[idx].copy() 431 | 432 | if self.aug: 433 | example['features_left'] = self.augmenter.apply_aug(example['features_left']) 434 | example['features_right'] = self.augmenter.apply_aug(example['features_right']) 435 | 436 | return example 437 | 438 | def __len__(self): 439 | return len(self.data) 440 | 441 | def _prepare_data(self, data): 442 | 443 | if self.dataset == 'lspc': 444 | data['features_left'] = data.apply(self.serialize_sample_lspc, args=('left',), axis=1) 445 | data['features_right'] = data.apply(self.serialize_sample_lspc, args=('right',), axis=1) 446 | elif self.dataset == 'abt-buy': 447 | data['features_left'] = data.apply(self.serialize_sample_abtbuy, args=('left',), axis=1) 448 | data['features_right'] = data.apply(self.serialize_sample_abtbuy, args=('right',), axis=1) 449 | elif self.dataset == 'amazon-google': 450 | data['features_left'] = data.apply(self.serialize_sample_amazongoogle, args=('left',), axis=1) 451 | data['features_right'] = data.apply(self.serialize_sample_amazongoogle, args=('right',), axis=1) 452 | 453 | data = data[['features_left', 'features_right', 'label']] 454 | data = data.rename(columns={'label': 'labels'}) 455 | 456 | return data 457 | 458 | def serialize_sample_lspc(self, 
sample, side): 459 | 460 | string = '' 461 | string = f'{string}[COL] brand [VAL] {" ".join(sample[f"brand_{side}"].split(" ")[:5])}'.strip() 462 | string = f'{string} [COL] title [VAL] {" ".join(sample[f"title_{side}"].split(" ")[:50])}'.strip() 463 | string = f'{string} [COL] description [VAL] {" ".join(sample[f"description_{side}"].split(" ")[:100])}'.strip() 464 | string = f'{string} [COL] specTableContent [VAL] {" ".join(sample[f"specTableContent_{side}"].split(" ")[:200])}'.strip() 465 | 466 | return string 467 | 468 | def serialize_sample_abtbuy(self, sample, side): 469 | 470 | string = '' 471 | string = f'{string}[COL] brand [VAL] {" ".join(sample[f"brand_{side}"].split())}'.strip() 472 | string = f'{string} [COL] title [VAL] {" ".join(sample[f"name_{side}"].split())}'.strip() 473 | string = f'{string} [COL] price [VAL] {" ".join(str(sample[f"price_{side}"]).split())}'.strip() 474 | string = f'{string} [COL] description [VAL] {" ".join(sample[f"description_{side}"].split()[:100])}'.strip() 475 | 476 | 477 | return string 478 | 479 | def serialize_sample_amazongoogle(self, sample, side): 480 | 481 | string = '' 482 | string = f'{string}[COL] brand [VAL] {" ".join(sample[f"manufacturer_{side}"].split())}'.strip() 483 | string = f'{string} [COL] title [VAL] {" ".join(sample[f"title_{side}"].split())}'.strip() 484 | string = f'{string} [COL] price [VAL] {" ".join(str(sample[f"price_{side}"]).split())}'.strip() 485 | string = f'{string} [COL] description [VAL] {" ".join(sample[f"description_{side}"].split()[:100])}'.strip() 486 | 487 | return string --------------------------------------------------------------------------------
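
The serialize_sample_* helpers above all flatten a product offer into a single "[COL] <attribute> [VAL] <value>" string, with '[COL]' and '[VAL]' registered as additional special tokens on the tokenizer and long attributes truncated to a fixed number of whitespace tokens. The snippet below is a minimal, self-contained illustration of that format for an invented Abt-Buy record; the record and its values are made up, and only the string building and truncation limits mirror serialize_sample_abtbuy.

# Illustration only: an invented Abt-Buy offer serialized the way
# serialize_sample_abtbuy does it. For abt-buy the dataset classes blank
# the brand column first, and the description is capped at 100 tokens.
sample = {
    'brand': '',
    'name': 'Sony BDP-S360 Blu-ray Disc Player',
    'price': 199.99,
    'description': '1080p playback with BD-Live support',
}

string = ''
string = f'{string}[COL] brand [VAL] {" ".join(sample["brand"].split())}'.strip()
string = f'{string} [COL] title [VAL] {" ".join(sample["name"].split())}'.strip()
string = f'{string} [COL] price [VAL] {" ".join(str(sample["price"]).split())}'.strip()
string = f'{string} [COL] description [VAL] {" ".join(sample["description"].split()[:100])}'.strip()

print(string)
# [COL] brand [VAL] [COL] title [VAL] Sony BDP-S360 Blu-ray Disc Player [COL] price [VAL] 199.99 [COL] description [VAL] 1080p playback with BD-Live support

The pair-wise fine-tuning dataset builds the same kind of string twice per pair (features_left and features_right), pulling the values from the *_left and *_right columns.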
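
ContrastivePretrainDatasetDeepmatcher derives pseudo cluster ids by growing "buckets" of record ids from the positive train/validation pairs; every id that lands in the same bucket later receives the same contrastive label. The toy replay below uses invented pair ids to show the effect of that loop; it is not part of the repository code.

# Toy replay of the bucketing loop in ContrastivePretrainDatasetDeepmatcher,
# run on invented positive pairs. Each bucket is one pseudo-cluster and its
# list index becomes the cluster id.
positive_pairs = [
    ('abt_1', 'buy_7'),
    ('abt_2', 'buy_9'),
    ('abt_1', 'buy_3'),  # shares abt_1, so it joins the first bucket
]

bucket_list = []
for left, right in positive_pairs:
    found = False
    for bucket in bucket_list:
        if left in bucket:
            bucket.add(right)
            found = True
            break
        elif right in bucket:
            bucket.add(left)
            found = True
            break
    if not found:
        bucket_list.append({left, right})

cluster_id_dict = {v: i for i, id_set in enumerate(bucket_list) for v in id_set}
# bucket_list     -> [{'abt_1', 'buy_7', 'buy_3'}, {'abt_2', 'buy_9'}]  (set order may vary)
# cluster_id_dict -> {'abt_1': 0, 'buy_7': 0, 'buy_3': 0, 'abt_2': 1, 'buy_9': 1}

Ids that never appear in a positive pair keep the sentinel value len(bucket_list) and are later given unique, increasing cluster ids before source-aware sampling splits them by origin; note also that the loop breaks at the first matching bucket, so a pair whose two ids already sit in different buckets does not merge those buckets.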
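
Finally, the fine-tuning script earlier in this listing constructs its datasets and collator roughly as sketched below. The import paths follow the package layout of this repository, while the data path, dataset name, and augmentation choice are placeholders picked for illustration; only the keyword arguments mirror the calls in the script.

# Hypothetical instantiation, mirroring how the fine-tuning script builds its
# train dataset and data collator (path and argument values are placeholders).
from src.contrastive.data.datasets import ContrastiveClassificationDataset
from src.contrastive.data.data_collators import DataCollatorContrastiveClassification

train_dataset = ContrastiveClassificationDataset(
    'path/to/abt-buy-train.json.gz',  # placeholder; non-lspc datasets are loaded with read_json
    dataset_type='train',
    size=None,                        # only used to pick the validation-id CSV for the 'lspc' dataset
    tokenizer='roberta-base',
    dataset='abt-buy',
    aug='all',                        # or False to disable augmentation
)
data_collator = DataCollatorContrastiveClassification(tokenizer=train_dataset.tokenizer)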