Source code for medigan.execute_model.synthetic_dataset

# -*- coding: utf-8 -*-
# ! /usr/bin/env python
""" `SyntheticDataset` allows to return a generative model as torch dataset. """

from torch.utils.data import Dataset


class SyntheticDataset(Dataset):
    """A synthetic dataset containing data generated by a model of medigan.

    Parameters
    ----------
    samples: list
        List of data points in the dataset e.g. generated images as numpy array.
    masks: list
        List of segmentation masks, if applicable, pertaining to the `samples` items.
    other_imaging_output: list
        List of other imaging output produced by the generative model
        (e.g. specific types of other masks/modalities).
    labels: list
        List of labels, if applicable, pertaining to the `samples` items.
    transform: callable
        Transform function(s) applied to each item on access. It is called
        positionally with the sample followed by every imaging output that is
        present for the item (mask and/or other imaging output) and must return
        the same number of values in the same order. When only the sample is
        passed, the return value is used as the sample directly.

    Attributes
    ----------
    samples: list
        List of data points in the dataset e.g. generated images as numpy array.
    masks: list
        List of segmentation masks, if applicable, pertaining to the `samples` items.
    other_imaging_output: list
        List of other imaging output produced by the generative model
        (e.g. specific types of other masks/modalities).
    labels: list
        List of labels, if applicable, pertaining to the `samples` items.
    transform: callable
        Transform function(s) applied to each item on access.
    """

    def __init__(
        self,
        samples,
        masks=None,
        other_imaging_output=None,
        labels=None,
        transform=None,
    ):
        self.samples = samples
        self.masks = masks
        self.other_imaging_output = other_imaging_output
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return the item at `index` as a dictionary.

        The dictionary always contains the key ``"sample"``. The keys
        ``"label"``, ``"mask"`` and ``"other_imaging_output"`` are only added
        when the corresponding value for this item is not None.
        """
        x = self.samples[index]
        y = self.labels[index] if self.labels is not None else None
        mask = self.masks[index] if self.masks is not None else None
        other_imaging_output = (
            self.other_imaging_output[index]
            if self.other_imaging_output is not None
            else None
        )

        if self.transform:
            # The transform needs to be applied to the image together with its
            # mask / other imaging output. The branches are mutually exclusive
            # so that the transform is called exactly once per item; calling it
            # a second time on (x, mask) would leave them transformed twice
            # while other_imaging_output was transformed only once.
            if mask is not None and other_imaging_output is not None:
                x, mask, other_imaging_output = self.transform(
                    x, mask, other_imaging_output
                )
            elif mask is not None:
                x, mask = self.transform(x, mask)
            elif other_imaging_output is not None:
                x, other_imaging_output = self.transform(x, other_imaging_output)
            else:
                x = self.transform(x)

        item = {"sample": x}  # extendable dictionary
        if y is not None:
            item["label"] = y
        if mask is not None:
            item["mask"] = mask
        if other_imaging_output is not None:
            item["other_imaging_output"] = other_imaging_output

        return item

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)