go home Home | Main Page | Topics | Namespace List | Class Hierarchy | Alphabetical List | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages
itkImpactImageToImageMetric.h
Go to the documentation of this file.
1/*=========================================================================
2 *
3 * Copyright UMC Utrecht and contributors
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
19#ifndef itkImpactImageToImageMetric_h
20#define itkImpactImageToImageMetric_h
21
23#include "itkBSplineInterpolateImageFunction.h"
26#include "ImpactTensorUtils.h"
27#include "ImpactLoss.h"
28
29#include <torch/script.h>
30#include <torch/torch.h>
31
32#include <string>
33#include <vector>
34#include <random>
35
36namespace itk
37{
38
56
57template <typename TFixedImage, typename TMovingImage>
58class ITK_TEMPLATE_EXPORT ImpactImageToImageMetric : public AdvancedImageToImageMetric<TFixedImage, TMovingImage>
59{
60public:
62
66 using Pointer = SmartPointer<Self>;
67 using ConstPointer = SmartPointer<const Self>;
68
70 itkNewMacro(Self);
71
74
76 using typename Superclass::CoordinateRepresentationType;
77 using typename Superclass::MovingImageType;
78 using typename Superclass::MovingImagePixelType;
79 using typename Superclass::MovingImageConstPointer;
80 using typename Superclass::FixedImageType;
81 using typename Superclass::FixedImageConstPointer;
82 using typename Superclass::FixedImageRegionType;
83 using typename Superclass::TransformType;
84 using typename Superclass::TransformPointer;
85 using typename Superclass::InputPointType;
86 using typename Superclass::OutputPointType;
87 using typename Superclass::TransformJacobianType;
89 using typename Superclass::InterpolatorType;
90 using typename Superclass::InterpolatorPointer;
91 using typename Superclass::RealType;
92 using typename Superclass::GradientPixelType;
93 using typename Superclass::GradientImageType;
94 using typename Superclass::GradientImagePointer;
95 using typename Superclass::FixedImageMaskType;
99 using typename Superclass::MeasureType;
100 using typename Superclass::DerivativeType;
101 using typename Superclass::DerivativeValueType;
102 using typename Superclass::ParametersType;
103 using typename Superclass::FixedImagePixelType;
105 using typename Superclass::ImageSamplerType;
106 using typename Superclass::ImageSamplerPointer;
114 using typename Superclass::ThreadInfoType;
115
117 itkStaticConstMacro(FixedImageDimension, unsigned int, FixedImageType::ImageDimension);
118
120 itkStaticConstMacro(MovingImageDimension, unsigned int, MovingImageType::ImageDimension);
121
126 virtual MeasureType
127 GetValueSingleThreaded(const ParametersType & parameters) const;
128
133 MeasureType
134 GetValue(const ParametersType & parameters) const override;
135
139 void
140 GetDerivative(const ParametersType & parameters, DerivativeType & derivative) const override;
141
144 void
145 GetValueAndDerivativeSingleThreaded(const ParametersType & parameters,
146 MeasureType & value,
147 DerivativeType & derivative) const;
148
153 void
154 GetValueAndDerivative(const ParametersType & parameters,
155 MeasureType & value,
156 DerivativeType & derivative) const override;
157
162 void
163 Initialize() override;
164
168 itkSetMacro(FixedModelsConfiguration, std::vector<ImpactModelConfiguration>);
169 itkGetConstReferenceMacro(FixedModelsConfiguration, std::vector<ImpactModelConfiguration>);
170
174 itkSetMacro(MovingModelsConfiguration, std::vector<ImpactModelConfiguration>);
175 itkGetConstReferenceMacro(MovingModelsConfiguration, std::vector<ImpactModelConfiguration>);
176
180 itkSetMacro(SubsetFeatures, std::vector<unsigned int>);
181 itkGetConstMacro(SubsetFeatures, std::vector<unsigned int>);
182
186 itkSetMacro(LayersWeight, std::vector<float>);
187 itkGetConstMacro(LayersWeight, std::vector<float>);
188
189
193 itkSetMacro(Distance, std::vector<std::string>);
194 itkGetConstMacro(Distance, std::vector<std::string>);
195
199 itkSetMacro(PCA, std::vector<unsigned int>);
200 itkGetConstMacro(PCA, std::vector<unsigned int>);
201
205 itkSetMacro(Device, torch::Device);
206 itkGetConstMacro(Device, torch::Device);
207
211 itkSetMacro(UseMixedPrecision, bool);
212 itkGetConstMacro(UseMixedPrecision, bool);
213
217 itkSetMacro(WriteFeatureMaps, bool);
218 itkGetConstMacro(WriteFeatureMaps, bool);
219
223 itkSetMacro(FeatureMapsPath, std::string);
224 itkGetConstMacro(FeatureMapsPath, std::string);
225
230 itkSetMacro(Mode, std::string);
231 itkGetConstMacro(Mode, std::string);
232
235 itkSetMacro(CurrentLevel, unsigned int);
236 itkGetConstMacro(CurrentLevel, unsigned int);
237
240 itkSetMacro(Seed, unsigned int);
241 itkGetConstMacro(Seed, unsigned int);
242
246 itkSetMacro(FeaturesMapUpdateInterval, int);
247 itkGetConstMacro(FeaturesMapUpdateInterval, int);
248
249protected:
251 ~ImpactImageToImageMetric() override = default;
252
257 void
259
261
274 {
275 std::vector<std::unique_ptr<ImpactLoss::Loss>> m_Losses;
276 std::vector<float> m_LayersWeight;
279 std::mt19937 m_RandomGenerator;
280
281 void
282 init(std::vector<std::string> distanceName, std::vector<float> layersWeight, unsigned int seed)
283 {
284 if (seed > 0)
285 {
286 m_RandomGenerator = std::mt19937(seed);
287 }
288 else
289 {
290 m_RandomGenerator = std::mt19937(time(nullptr));
291 }
292 m_LayersWeight = layersWeight;
293 for (std::string name : distanceName)
294 {
295 m_Losses.push_back(ImpactLoss::LossFactory::Instance().Create(name));
296 }
297 }
298
299 void
300 setNumberOfParameters(int numberOfParameters)
301 {
302 m_NumberOfParameters = numberOfParameters;
303 for (int l = 0; l < m_LayersWeight.size(); ++l)
304 {
305 m_Losses[l]->setNumberOfParameters(numberOfParameters);
306 }
307 }
308
// Calls reset() on every per-layer loss held in m_Losses.
// NOTE(review): presumably invoked before reusing this struct for a new
// evaluation — confirm against the caller.
309 void
311 {
313 for (std::unique_ptr<ImpactLoss::Loss> & loss : m_Losses)
314 {
315 loss->reset();
316 }
317 }
318
// Combines the per-layer losses into a single scalar: the sum over layers l of
// m_LayersWeight[l] * m_Losses[l]->GetValue(m_NumberOfPixelsCounted), with the
// pixel count passed as a double.
// NOTE(review): the loop bound is m_LayersWeight.size() while m_Losses is indexed
// with the same l. Both vectors must therefore have the same length. The
// `int l` vs size_t comparison also triggers sign-compare warnings.
319 double
321 {
322 MeasureType value = MeasureType{};
323 for (int l = 0; l < m_LayersWeight.size(); ++l)
324 {
325 value += m_LayersWeight[l] * m_Losses[l]->GetValue(static_cast<double>(m_NumberOfPixelsCounted));
326 }
327 return value;
328 }
329
// Returns a derivative vector of length m_NumberOfParameters. It is zero-filled,
// then, for every layer l, the tensor
//   m_LayersWeight[l] * m_Losses[l]->GetDerivative(m_NumberOfPixelsCounted)
// is added to it element by element.
// NOTE(review):
//  - d[i].item<float>() materialises one scalar tensor per element. This is slow
//    for transforms with many parameters, and it would force a device sync per
//    element if d lives on a GPU (unconfirmed). Consider converting d once to a
//    contiguous CPU tensor and reading it through an accessor.
//  - Nothing checks that d.size(0) <= m_NumberOfParameters before writing into
//    `derivative`.
//  - Values are read back as float before being added to the
//    DerivativeType entries.
//  - As in the value loop, m_Losses is indexed with the indices of m_LayersWeight.
330 DerivativeType
332 {
333 DerivativeType derivative = DerivativeType(m_NumberOfParameters);
334 derivative.Fill(DerivativeValueType{});
335 for (int l = 0; l < m_LayersWeight.size(); ++l)
336 {
337 torch::Tensor d = m_LayersWeight[l] * m_Losses[l]->GetDerivative(static_cast<double>(m_NumberOfPixelsCounted));
338 for (int i = 0; i < d.size(0); ++i)
339 {
340 derivative[i] += d[i].item<float>();
341 }
342 }
343 return derivative;
344 }
345
348 {
// Merges `other` into this struct:
//  - adds its m_NumberOfPixelsCounted to ours;
//  - folds each of its per-layer losses into the corresponding entry of m_Losses
//    via Loss::operator+=.
// Nothing is merged if the dynamic_cast yields nullptr. Returns *this.
// NOTE(review): if the parameter is already declared as const LossPerThreadStruct&
// (as the generated documentation signature suggests), the dynamic_cast always
// succeeds and is redundant. The loop bound comes from other.m_Losses.size()
// while this->m_Losses is indexed with the same i, so both must have the same
// length.
349 const auto * lossPerThreadStructOther = dynamic_cast<const LossPerThreadStruct *>(&other);
350 if (lossPerThreadStructOther)
351 {
352 m_NumberOfPixelsCounted += lossPerThreadStructOther->m_NumberOfPixelsCounted;
353 for (int i = 0; i < lossPerThreadStructOther->m_Losses.size(); ++i)
354 {
355 *m_Losses[i] += *lossPerThreadStructOther->m_Losses[i];
356 }
357 }
358 return *this;
359 }
360 };
361
363 using typename Superclass::FixedImageIndexType;
364 using typename Superclass::FixedImageIndexValueType;
365 using typename Superclass::MovingImageIndexType;
366 using typename Superclass::FixedImagePointType;
367 using typename Superclass::MovingImagePointType;
368 using typename Superclass::MovingImageContinuousIndexType;
369 using typename Superclass::BSplineInterpolatorType;
370 using typename Superclass::MovingImageDerivativeType;
371 using typename Superclass::NonZeroJacobianIndicesType;
372
388 bool
389 SampleCheck(const FixedImagePointType & fixedImageCenterCoordinate,
390 const std::vector<std::vector<float>> & patchIndex) const;
391
402 bool
403 SampleCheck(const FixedImagePointType & fixedImageCenterCoordinate) const;
404
413 void
414 ThreadedGetValue(ThreadIdType threadID) const override;
415
425 void
426 AfterThreadedGetValue(MeasureType & value) const override;
427
435 void
436 ThreadedGetValueAndDerivative(ThreadIdType threadID) const override;
437
448 void
449 AfterThreadedGetValueAndDerivative(MeasureType & value, DerivativeType & derivative) const override;
450
463 unsigned int
464 ComputeValue(const std::vector<FixedImagePointType> & fixedPoints, LossPerThreadStruct & loss) const;
465
477 unsigned int
478 ComputeValueStatic(const std::vector<FixedImagePointType> & fixedPoints, LossPerThreadStruct & loss) const;
479
492 unsigned int
493 ComputeValueAndDerivativeJacobian(const std::vector<FixedImagePointType> & fixedPoints,
494 LossPerThreadStruct & loss) const;
495
508 unsigned int
509 ComputeValueAndDerivativeStatic(const std::vector<FixedImagePointType> & fixedPoints,
510 LossPerThreadStruct & loss) const;
511
512
520 void
522
530 void
532
533private:
535 using FixedInterpolatorType = BSplineInterpolateImageFunction<FixedImageType, CoordinateRepresentationType, double>;
537 using FeaturesImageType = itk::VectorImage<float, FixedImageDimension>;
541 BSplineInterpolateImageFunction<itk::Image<float, FixedImageDimension>, CoordinateRepresentationType, float>>;
542
552 {
553 typename FeaturesImageType::Pointer m_FeaturesMaps;
555
556 FeaturesMaps(typename FeaturesImageType::Pointer featuresMaps)
557 : m_FeaturesMaps(featuresMaps)
558 {
560 m_FeaturesMapsInterpolator.SetInputImage(featuresMaps);
561 }
562 };
563
565
579 torch::Tensor
580 EvaluateFixedImagesPatchValue(const FixedImagePointType & fixedImageCenterCoordinate,
581 const std::vector<std::vector<float>> & patchIndex,
582 const std::vector<int64_t> & patchSize) const;
583
597 torch::Tensor
598 EvaluateMovingImagesPatchValue(const FixedImagePointType & fixedImageCenterCoordinate,
599 const std::vector<std::vector<float>> & patchIndex,
600 const std::vector<int64_t> & patchSize) const;
601
617 torch::Tensor
619 torch::Tensor & movingImagesPatchesJacobians,
620 const std::vector<std::vector<float>> & patchIndex,
621 const std::vector<int64_t> & patchSize,
622 int s) const;
623
638 template <typename ImagePointType>
639 std::vector<ImagePointType>
640 GeneratePatchIndex(const std::vector<ImpactModelConfiguration> & modelConfig,
641 std::mt19937 & randomGenerator,
642 const std::vector<ImagePointType> & fixedPointsTmp,
643 std::vector<std::vector<std::vector<std::vector<float>>>> & patchIndex) const;
644
646 std::vector<ImpactModelConfiguration> m_FixedModelsConfiguration;
647 std::vector<ImpactModelConfiguration> m_MovingModelsConfiguration;
648
649 std::vector<unsigned int> m_SubsetFeatures;
650 std::vector<unsigned int> m_PCA;
651 std::vector<float> m_LayersWeight;
652 std::vector<std::string> m_Distance;
654 std::string m_Mode;
656 std::string m_FeatureMapsPath;
657 torch::Device m_Device = torch::Device(torch::kCPU);
659 unsigned int m_CurrentLevel;
660 unsigned int m_Seed;
661
662
663 std::vector<FeaturesMaps> m_FixedFeaturesMaps;
664 std::vector<FeaturesMaps> m_MovingFeaturesMaps;
665 std::vector<torch::Tensor> m_PrincipalComponents;
666
667 std::vector<std::vector<unsigned int>> m_FeaturesIndexes;
668
669
674 InterpolatorPointer m_FixedInterpolator = [this] {
675 const auto interpolator = FixedInterpolatorType::New();
676 interpolator->SetSplineOrder(3);
677 return interpolator;
678 }();
679
693 std::vector<unsigned int>
694 GetSubsetOfFeatures(const std::vector<unsigned int> & featuresIndex, std::mt19937 & randomGenerator, int n) const;
695
697 itkPadStruct(ITK_CACHE_LINE_ALIGNMENT, LossPerThreadStruct, PaddedLossPerThreadStruct);
698
699 itkAlignedTypedef(ITK_CACHE_LINE_ALIGNMENT, PaddedLossPerThreadStruct, AlignedLossPerThreadStruct);
700
702 mutable std::unique_ptr<AlignedLossPerThreadStruct[]> m_LossThreadStruct{ nullptr };
703
704 mutable int m_LossThreadStructSize = 0;
705};
706
707} // end namespace itk
708
709#ifndef ITK_MANUAL_INSTANTIATION
710# include "itkImpactImageToImageMetric.hxx"
711#endif
712
713#endif // end #ifndef itkImpactImageToImageMetric_h
Utilities for converting ITK images to Torch tensors and extracting features using TorchScript models...
Helper class to interpolate each component of a VectorImage using separate B-Spline interpolators.
static LossFactory & Instance()
Definition ImpactLoss.h:173
typename ImageSamplerType::OutputVectorContainerPointer ImageSampleContainerPointer
typename MovingImageType::RegionType MovingImageRegionType
FixedArray< double, Self::MovingImageDimension > MovingImageDerivativeScalesType
typename FixedImageType::PixelType FixedImagePixelType
typename DerivativeType::ValueType DerivativeValueType
typename ImageSamplerType::OutputVectorContainerType ImageSampleContainerType
ImageMaskSpatialObject< Self::FixedImageDimension > FixedImageMaskType
SmartPointer< MovingImageMaskType > MovingImageMaskPointer
LimiterFunctionBase< RealType, FixedImageDimension > FixedImageLimiterType
ImageSamplerBase< FixedImageType > ImageSamplerType
MultiThreaderBase::WorkUnitInfo ThreadInfoType
LimiterFunctionBase< RealType, MovingImageDimension > MovingImageLimiterType
typename MovingImageLimiterType::OutputType MovingImageLimiterOutputType
typename FixedImageLimiterType::OutputType FixedImageLimiterOutputType
SmartPointer< FixedImageMaskType > FixedImageMaskPointer
typename ImageSamplerType::Pointer ImageSamplerPointer
typename AdvancedTransformType::NumberOfParametersType NumberOfParametersType
ImageMaskSpatialObject< Self::MovingImageDimension > MovingImageMaskType
MeasureType GetValue(const ParametersType &parameters) const override
itkAlignedTypedef(ITK_CACHE_LINE_ALIGNMENT, PaddedLossPerThreadStruct, AlignedLossPerThreadStruct)
torch::Tensor EvaluateMovingImagesPatchValuesAndJacobians(const FixedImagePointType &fixedImageCenterCoordinate, torch::Tensor &movingImagesPatchesJacobians, const std::vector< std::vector< float > > &patchIndex, const std::vector< int64_t > &patchSize, int s) const
Extracts moving image patch values and computes the spatial Jacobians with respect to image coordinates...
typename DerivativeType::ValueType DerivativeValueType
std::vector< ImagePointType > GeneratePatchIndex(const std::vector< ImpactModelConfiguration > &modelConfig, std::mt19937 &randomGenerator, const std::vector< ImagePointType > &fixedPointsTmp, std::vector< std::vector< std::vector< std::vector< float > > > > &patchIndex) const
Generates valid patch indices and filters out invalid points.
virtual MeasureType GetValueSingleThreaded(const ParametersType &parameters) const
void ThreadedGetValueAndDerivative(ThreadIdType threadID) const override
Computes both the similarity value and its gradient for a given thread.
void GetDerivative(const ParametersType &parameters, DerivativeType &derivative) const override
void GetValueAndDerivative(const ParametersType &parameters, MeasureType &value, DerivativeType &derivative) const override
torch::Tensor EvaluateMovingImagesPatchValue(const FixedImagePointType &fixedImageCenterCoordinate, const std::vector< std::vector< float > > &patchIndex, const std::vector< int64_t > &patchSize) const
Extracts a moving image patch tensor (intensity values) corresponding to a fixed point.
void UpdateFixedFeaturesMaps()
Updates the fixed feature maps in static mode.
torch::Tensor EvaluateFixedImagesPatchValue(const FixedImagePointType &fixedImageCenterCoordinate, const std::vector< std::vector< float > > &patchIndex, const std::vector< int64_t > &patchSize) const
Extracts a fixed image patch tensor centered at a given point using the precomputed patch index.
itkStaticConstMacro(FixedImageDimension, unsigned int, FixedImageType::ImageDimension)
void AfterThreadedGetValue(MeasureType &value) const override
Combines the similarity values computed by all threads.
BSplineInterpolateVectorImageFunction< FeaturesImageType, BSplineInterpolateImageFunction< itk::Image< float, FixedImageDimension >, CoordinateRepresentationType, float > > FeaturesInterpolatorType
ITK_DISALLOW_COPY_AND_MOVE(ImpactImageToImageMetric)
unsigned int ComputeValueAndDerivativeStatic(const std::vector< FixedImagePointType > &fixedPoints, LossPerThreadStruct &loss) const
Computes the value and derivative in static mode (using precomputed features).
void UpdateMovingFeaturesMaps()
Updates the moving feature maps in static mode.
void GetValueAndDerivativeSingleThreaded(const ParametersType &parameters, MeasureType &value, DerivativeType &derivative) const
unsigned int ComputeValueStatic(const std::vector< FixedImagePointType > &fixedPoints, LossPerThreadStruct &loss) const
Computes the semantic similarity value in static mode (using precomputed feature maps).
unsigned int ComputeValue(const std::vector< FixedImagePointType > &fixedPoints, LossPerThreadStruct &loss) const
Computes the semantic similarity value using the current transform parameters.
bool SampleCheck(const FixedImagePointType &fixedImageCenterCoordinate, const std::vector< std::vector< float > > &patchIndex) const
Checks if a patch centered at the given fixed image point is valid for sampling.
~ImpactImageToImageMetric() override=default
AdvancedImageToImageMetric< typename MetricBase< TElastix >::FixedImageType, typename MetricBase< TElastix >::MovingImageType > Superclass
unsigned int ComputeValueAndDerivativeJacobian(const std::vector< FixedImagePointType > &fixedPoints, LossPerThreadStruct &loss) const
Computes both the semantic similarity value and its derivative using Jacobian mode.
bool SampleCheck(const FixedImagePointType &fixedImageCenterCoordinate) const
Checks if the fixed image point lies within valid bounds for sampling.
itkStaticConstMacro(MovingImageDimension, unsigned int, MovingImageType::ImageDimension)
void ThreadedGetValue(ThreadIdType threadID) const override
Computes the similarity value contribution for a given thread.
itkPadStruct(ITK_CACHE_LINE_ALIGNMENT, LossPerThreadStruct, PaddedLossPerThreadStruct)
void AfterThreadedGetValueAndDerivative(MeasureType &value, DerivativeType &derivative) const override
Combines the values and gradients computed by all threads.
void InitializeThreadingParameters() const override
std::vector< unsigned int > GetSubsetOfFeatures(const std::vector< unsigned int > &featuresIndex, std::mt19937 &randomGenerator, int n) const
Retrieves a subset of features based on the provided indices.
Encapsulates a feature map image and its associated interpolator.
FeaturesMaps(typename FeaturesImageType::Pointer featuresMaps)
Thread-local structure for accumulating loss values and gradients for each layer.
LossPerThreadStruct & operator+=(const LossPerThreadStruct &other)
void init(std::vector< std::string > distanceName, std::vector< float > layersWeight, unsigned int seed)
std::vector< std::unique_ptr< ImpactLoss::Loss > > m_Losses


Generated on 26-02-2026 for elastix by doxygen 1.16.1 (669aeeefca743c148e2d935b3d3c69535c7491e6) elastix logo