Skip to content

Commit 3be86c6

Browse files
author
ouyhlan
committed
修复Trainer里check_code函数忽略pin_memory参数导致的内存bug
1 parent 9ac7d09 commit 3be86c6

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

fastNLP/core/tester.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=No
113113
self.verbose = verbose
114114
self.use_tqdm = use_tqdm
115115
self.logger = logger
116-
self.pin_memory = kwargs.get('pin_memory', True)
116+
self.pin_memory = kwargs.get('pin_memory', False)
117117

118118
if isinstance(data, DataSet):
119119
sampler = kwargs.get('sampler', None)

fastNLP/core/trainer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,6 @@ def on_epoch_end(self):
334334
except:
335335
from .utils import _pseudo_tqdm as tqdm
336336
import warnings
337-
from pkg_resources import parse_version
338337

339338
from .batch import DataSetIter, BatchIter
340339
from .callback import CallbackManager, CallbackException, Callback
@@ -475,7 +474,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
475474
if drop_last:
476475
warnings.warn("drop_last is ignored when train_data is BatchIter.")
477476
# concerning issue from https://github.com/pytorch/pytorch/issues/57273
478-
self.pin_memory = kwargs.get('pin_memory', False if parse_version(torch.__version__)==parse_version('1.9') else True)
477+
self.pin_memory = kwargs.get('pin_memory', False)
479478
if isinstance(model, nn.parallel.DistributedDataParallel): # 如果是分布式的
480479
# device为None
481480
if device is not None:

0 commit comments

Comments
 (0)