diff --git a/nipype/interfaces/fsl/dti.py b/nipype/interfaces/fsl/dti.py index cd46067daa..c65b10b1f2 100644 --- a/nipype/interfaces/fsl/dti.py +++ b/nipype/interfaces/fsl/dti.py @@ -6,6 +6,7 @@ """ import os import warnings +from shutil import which from ...utils.filemanip import fname_presuffix, split_filename, copyfile from ..base import ( @@ -383,6 +384,7 @@ class BEDPOSTX5InputSpec(FSLXCommandInputSpec): ) grad_dev = File(exists=True, desc="grad_dev file, if gradnonlin, -g is True") use_gpu = traits.Bool(False, desc="Use the GPU version of bedpostx") + num_threads = traits.Int(nohash=True, desc="Number of threads to use") class BEDPOSTX5OutputSpec(TraitedSpec): @@ -451,13 +453,25 @@ class BEDPOSTX5(FSLXCommand): def __init__(self, **inputs): super().__init__(**inputs) self.inputs.on_trait_change(self._cuda_update, "use_gpu") + self.inputs.on_trait_change(self._num_threads_update, "num_threads") def _cuda_update(self): - if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu: + if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu and which("bedpostx_gpu") is not None: self._cmd = "bedpostx_gpu" + self.inputs.num_threads = 1 else: self._cmd = self._default_cmd + def _num_threads_update(self): + if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu and which("bedpostx_gpu") is not None: + self.inputs.num_threads = 1 + self._num_threads = self.inputs.num_threads + if not isdefined(self.inputs.num_threads): + if "FSLSUB_PARALLEL" in self.inputs.environ: + del self.inputs.environ["FSLSUB_PARALLEL"] + else: + self.inputs.environ["FSLSUB_PARALLEL"] = str(self.inputs.num_threads) + def _run_interface(self, runtime): subjectdir = os.path.abspath(self.inputs.out_dir) if not os.path.exists(subjectdir): @@ -1024,6 +1038,7 @@ class ProbTrackX2InputSpec(ProbTrackXBaseInputSpec): '"vox"' ), ) + use_gpu = traits.Bool(False, desc="Use the GPU version of probtrackx2") class ProbTrackX2OutputSpec(ProbTrackXOutputSpec): @@ -1059,9 +1074,20 @@ class ProbTrackX2(ProbTrackX): """ _cmd = "probtrackx2" + _default_cmd = _cmd input_spec = ProbTrackX2InputSpec output_spec = ProbTrackX2OutputSpec + def __init__(self, **inputs): + super().__init__(**inputs) + self.inputs.on_trait_change(self._cuda_update, "use_gpu") + + def _cuda_update(self): + if isdefined(self.inputs.use_gpu) and self.inputs.use_gpu and which("probtrackx2_gpu") is not None: + self._cmd = "probtrackx2_gpu" + else: + self._cmd = self._default_cmd + def _list_outputs(self): outputs = super()._list_outputs()