76 using typename Superclass::CoordinateRepresentationType;
77 using typename Superclass::MovingImageType;
78 using typename Superclass::MovingImagePixelType;
79 using typename Superclass::MovingImageConstPointer;
80 using typename Superclass::FixedImageType;
81 using typename Superclass::FixedImageConstPointer;
82 using typename Superclass::FixedImageRegionType;
83 using typename Superclass::TransformType;
84 using typename Superclass::TransformPointer;
85 using typename Superclass::InputPointType;
86 using typename Superclass::OutputPointType;
87 using typename Superclass::TransformJacobianType;
89 using typename Superclass::InterpolatorType;
90 using typename Superclass::InterpolatorPointer;
91 using typename Superclass::RealType;
92 using typename Superclass::GradientPixelType;
93 using typename Superclass::GradientImageType;
94 using typename Superclass::GradientImagePointer;
99 using typename Superclass::MeasureType;
100 using typename Superclass::DerivativeType;
102 using typename Superclass::ParametersType;
134 GetValue(
const ParametersType & parameters)
const override;
140 GetDerivative(
const ParametersType & parameters, DerivativeType & derivative)
const override;
147 DerivativeType & derivative)
const;
156 DerivativeType & derivative)
const override;
168 itkSetMacro(FixedModelsConfiguration, std::vector<ImpactModelConfiguration>);
169 itkGetConstReferenceMacro(FixedModelsConfiguration, std::vector<ImpactModelConfiguration>);
174 itkSetMacro(MovingModelsConfiguration, std::vector<ImpactModelConfiguration>);
175 itkGetConstReferenceMacro(MovingModelsConfiguration, std::vector<ImpactModelConfiguration>);
180 itkSetMacro(SubsetFeatures, std::vector<unsigned int>);
181 itkGetConstMacro(SubsetFeatures, std::vector<unsigned int>);
186 itkSetMacro(LayersWeight, std::vector<float>);
187 itkGetConstMacro(LayersWeight, std::vector<float>);
193 itkSetMacro(Distance, std::vector<std::string>);
194 itkGetConstMacro(Distance, std::vector<std::string>);
199 itkSetMacro(PCA, std::vector<unsigned int>);
200 itkGetConstMacro(PCA, std::vector<unsigned int>);
205 itkSetMacro(Device, torch::Device);
206 itkGetConstMacro(Device, torch::Device);
211 itkSetMacro(UseMixedPrecision,
bool);
212 itkGetConstMacro(UseMixedPrecision,
bool);
217 itkSetMacro(WriteFeatureMaps,
bool);
218 itkGetConstMacro(WriteFeatureMaps,
bool);
223 itkSetMacro(FeatureMapsPath, std::string);
224 itkGetConstMacro(FeatureMapsPath, std::string);
230 itkSetMacro(Mode, std::string);
231 itkGetConstMacro(Mode, std::string);
235 itkSetMacro(CurrentLevel,
unsigned int);
236 itkGetConstMacro(CurrentLevel,
unsigned int);
240 itkSetMacro(Seed,
unsigned int);
241 itkGetConstMacro(Seed,
unsigned int);
246 itkSetMacro(FeaturesMapUpdateInterval,
int);
247 itkGetConstMacro(FeaturesMapUpdateInterval,
int);
275 std::vector<std::unique_ptr<ImpactLoss::Loss>>
m_Losses;
282 init(std::vector<std::string> distanceName, std::vector<float> layersWeight,
unsigned int seed)
293 for (std::string name : distanceName)
305 m_Losses[l]->setNumberOfParameters(numberOfParameters);
313 for (std::unique_ptr<ImpactLoss::Loss> & loss :
m_Losses)
322 MeasureType value = MeasureType{};
338 for (
int i = 0; i < d.size(0); ++i)
340 derivative[i] += d[i].item<
float>();
350 if (lossPerThreadStructOther)
353 for (
int i = 0; i < lossPerThreadStructOther->m_Losses.size(); ++i)
355 *
m_Losses[i] += *lossPerThreadStructOther->m_Losses[i];
363 using typename Superclass::FixedImageIndexType;
364 using typename Superclass::FixedImageIndexValueType;
365 using typename Superclass::MovingImageIndexType;
366 using typename Superclass::FixedImagePointType;
367 using typename Superclass::MovingImagePointType;
368 using typename Superclass::MovingImageContinuousIndexType;
369 using typename Superclass::BSplineInterpolatorType;
370 using typename Superclass::MovingImageDerivativeType;
371 using typename Superclass::NonZeroJacobianIndicesType;
390 const std::vector<std::vector<float>> & patchIndex)
const;
464 ComputeValue(
const std::vector<FixedImagePointType> & fixedPoints, LossPerThreadStruct & loss)
const;
478 ComputeValueStatic(
const std::vector<FixedImagePointType> & fixedPoints, LossPerThreadStruct & loss)
const;
494 LossPerThreadStruct & loss)
const;
510 LossPerThreadStruct & loss)
const;
535 using FixedInterpolatorType = BSplineInterpolateImageFunction<FixedImageType, CoordinateRepresentationType, double>;
541 BSplineInterpolateImageFunction<itk::Image<float, FixedImageDimension>, CoordinateRepresentationType,
float>>;
581 const std::vector<std::vector<float>> & patchIndex,
582 const std::vector<int64_t> & patchSize)
const;
599 const std::vector<std::vector<float>> & patchIndex,
600 const std::vector<int64_t> & patchSize)
const;
619 torch::Tensor & movingImagesPatchesJacobians,
620 const std::vector<std::vector<float>> & patchIndex,
621 const std::vector<int64_t> & patchSize,
638 template <
typename ImagePo
intType>
639 std::vector<ImagePointType>
641 std::mt19937 & randomGenerator,
642 const std::vector<ImagePointType> & fixedPointsTmp,
643 std::vector<std::vector<std::vector<std::vector<float>>>> & patchIndex)
const;
657 torch::Device
m_Device = torch::Device(torch::kCPU);
675 const auto interpolator = FixedInterpolatorType::New();
676 interpolator->SetSplineOrder(3);
693 std::vector<unsigned int>
694 GetSubsetOfFeatures(
const std::vector<unsigned int> & featuresIndex, std::mt19937 & randomGenerator,
int n)
const;
697 itkPadStruct(ITK_CACHE_LINE_ALIGNMENT, LossPerThreadStruct, PaddedLossPerThreadStruct);
699 itkAlignedTypedef(ITK_CACHE_LINE_ALIGNMENT, PaddedLossPerThreadStruct, AlignedLossPerThreadStruct);