Skip to content



Exporter class for turning torch models into TorchScript models. Requires a torch Module and a corresponding DataLoader populated with complete data samples. This means the DataLoader must yield at least one full batch that can be used as input to the torch Module in order to run torch.jit.trace()


>>> exporter = TorchScriptExporter(model=BertQAModule(...), dataloader=dataloader, use_gpu=True)
>>> exporter.export()


  • model - Torch model/module, preferably from fastnn.nn
  • dataloader - Torch DataLoader that corresponds to the model
  • use_gpu - Bool for whether to use a GPU. If set to True but no GPU device is available, the model defaults to using the CPU
Source code in fastnn/
class TorchScriptExporter:
    """Exporter Class for turning torch models into torch script models. Requires a torch Module and corresponding `DataLoader` populated with complete data samples.
    This means the `DataLoader` has atleast one full batch sample that can be used as input to the torch Module in order to run `torch.jit.trace()`

    >>> exporter = TorchScriptExporter(model=BertQAModule(...), dataloader=dataloader, use_gpu=True)
    >>> exporter.export()


    * **model** - Torch model/module, preferably from `fastnn.nn`
    * **dataloader** - Torch `DataLoader` that corresponds to the model
    * **use_gpu** - Bool for using gpu or cpu. If set True but no gpu devices available, model will default to using cpu


    def __init__(
        self, model: torch.nn.Module, dataloader: DataLoader, use_gpu: bool = False
        if use_gpu:
            if torch.cuda.is_available():
                device = torch.device("cuda")
      "Torch model set to device {device}")
                device = torch.device("cpu")
      "GPU not available...device set to {device}")
            device = torch.device("cpu")
  "Torch model set to device {device}")

        self.model = model
        self.dataloader = dataloader
        self.torchscript_model = None

    def export(self) -> torch.jit.ScriptModule:
        """Traces pytorch model and returns `ScriptModule` model"""
        batch_input = next(iter(self.dataloader))
        self.torchscript_model = torch.jit.trace(self.model, tuple(batch_input))
        return self.torchscript_model

    def serialize(self, file_path: Union[Path, str]):
        """Serialize and save model

        * **file_path** - String file path to save serialized torchscript model
        if isinstance(file_path, str):
            file_path = Path(file_path)

        if self.torchscript_model:
            file_path.parent.mkdir(parents=True, exist_ok=True)


Traces the PyTorch model and returns the resulting ScriptModule

Source code in fastnn/
def export(self) -> torch.jit.ScriptModule:
    """Trace the PyTorch model and return the resulting ``ScriptModule``.

    The first batch drawn from ``self.dataloader`` is used as the example
    input for ``torch.jit.trace``. The traced module is also stored on
    ``self.torchscript_model``, which ``serialize`` checks before saving.

    Returns:
        The traced ``torch.jit.ScriptModule``.

    Raises:
        ValueError: If ``self.dataloader`` yields no batches.
    """
    try:
        batch_input = next(iter(self.dataloader))
    except StopIteration:
        # A bare StopIteration is easy to misread (and is swallowed if this is
        # ever called from inside a generator); report the real problem.
        raise ValueError(
            "DataLoader is empty; at least one batch is required to trace the model"
        ) from None
    # A batch that is a single tensor must be passed as ONE argument:
    # ``tuple(tensor)`` would iterate over the batch dimension and hand the
    # model one positional argument per sample.
    if isinstance(batch_input, torch.Tensor):
        example_inputs = (batch_input,)
    else:
        # e.g. a list/tuple of tensors from a TensorDataset-style collate
        example_inputs = tuple(batch_input)
    self.torchscript_model = torch.jit.trace(self.model, example_inputs)
    return self.torchscript_model

serialize(self, file_path)

Serialize and save model

  • file_path - File path (str or Path) at which to save the serialized TorchScript model
Source code in fastnn/
def serialize(self, file_path: Union[Path, str]):
    """Serialize and save model

    * **file_path** - String file path to save serialized torchscript model
    if isinstance(file_path, str):
        file_path = Path(file_path)

    if self.torchscript_model:
        file_path.parent.mkdir(parents=True, exist_ok=True)