Hello, I have solved the first two problems, but I encountered a new error at the 50% mark of the first training epoch:
Traceback (most recent call last):
File "/home/lizhaohui/DiffCSE/train.py", line 600, in
main()
File "/home/lizhaohui/DiffCSE/train.py", line 564, in main
train_result = trainer.train(model_path=model_path)
File "/home/lizhaohui/DiffCSE/diffcse/trainers.py", line 513, in train
tr_loss += self.training_step(model, inputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/trainer.py", line 1248, in training_step
loss = self.compute_loss(model, inputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/trainer.py", line 1277, in compute_loss
outputs = model(**inputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 162, in forward
return self.gather(outputs, self.output_device)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 174, in gather
return gather(outputs, output_device, dim=self.dim)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
res = gather_map(outputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in gather_map
return type(out)(((k, gather_map([d[k] for d in outputs]))
File "<string>", line 7, in __init__
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/file_utils.py", line 1383, in __post_init__
for element in iterator:
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 61, in <genexpr>
return type(out)(((k, gather_map([d[k] for d in outputs]))
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
return Gather.apply(target_device, dim, *outputs)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/_functions.py", line 71, in forward
return comm.gather(inputs, ctx.dim, ctx.target_device)
File "/home/lizhaohui/miniconda3/envs/py39/lib/python3.9/site-packages/torch/nn/parallel/comm.py", line 230, in gather
return torch._C._gather(tensors, dim, destination)
RuntimeError: Input tensor at index 3 has invalid shape [14, 14], but expected [14, 17]
50%|█████████████████████████████████████▍ | 3906/7814 [1:00:50<1:00:51, 1.07it/s]Fatal Python error: PyEval_SaveThread: the function must be called with the GIL held, but the GIL is released (the current Python thread state is NULL)
Python runtime state: finalizing (tstate=0x559cd4786400)
run_diffcse.sh: line 30: 3465377 Aborted (core dumped) python train.py --model_name_or_path bert-base-uncased --generator_name distilbert-base-uncased --train_file data/wiki1m_for_simcse.txt --output_dir output_dir --num_train_epochs 2 --per_device_train_batch_size 64 --learning_rate 7e-6 --max_seq_length 32 --evaluation_strategy steps --metric_for_best_model stsb_spearman --load_best_model_at_end --eval_steps 125 --pooler_type cls --mlp_only_train --overwrite_output_dir --logging_first_step --logging_dir log_dir --temp 0.05 --do_train --do_eval --batchnorm --lambda_weight 0.005 --fp16 --masking_ratio 0.30