Skip to content

Exporter

TorchScriptExporter

Exporter class for turning torch models into TorchScript models. Requires a torch Module and a corresponding DataLoader populated with complete data samples, i.e. the DataLoader must yield at least one full batch that can be passed as input to the torch Module in order to run torch.jit.trace()

Usage:

>>> exporter = TorchScriptExporter(model=BertQAModule(...), dataloader=dataloader, use_gpu=True)
>>> exporter.export()

Parameters:

  • model - Torch model/module, preferably from fastnn.nn
  • dataloader - Torch DataLoader that corresponds to the model
  • use_gpu - Bool selecting GPU or CPU. If set to True but no GPU device is available, the model falls back to the CPU
Source code in fastnn/exporting.py
class TorchScriptExporter:
    """Exporter Class for turning torch models into torch script models. Requires a torch Module and corresponding `DataLoader` populated with complete data samples.
    This means the `DataLoader` has at least one full batch sample that can be used as input to the torch Module in order to run `torch.jit.trace()`

    Usage:
    ```python
    >>> exporter = TorchScriptExporter(model=BertQAModule(...), dataloader=dataloader, use_gpu=True)
    >>> exporter.export()
    ```

    **Parameters:**

    * **model** - Torch model/module, preferably from `fastnn.nn`
    * **dataloader** - Torch `DataLoader` that corresponds to the model
    * **use_gpu** - Bool for using gpu or cpu. If set True but no gpu devices available, model will default to using cpu

    """

    def __init__(
        self, model: torch.nn.Module, dataloader: DataLoader, use_gpu: bool = False
    ):
        # Pick the target device; fall back to CPU when a GPU was requested but
        # CUDA is not available.
        if use_gpu and torch.cuda.is_available():
            device = torch.device("cuda")
            logger.info("Torch model set to device %s", device)
        elif use_gpu:
            device = torch.device("cpu")
            logger.info("GPU not available...device set to %s", device)
        else:
            device = torch.device("cpu")
            logger.info("Torch model set to device %s", device)

        self.model = model
        # Tracing records a single forward pass, so put the model in inference
        # mode (disables dropout etc.) before it is traced.
        self.model.eval()
        self.model.to(device)
        self.dataloader = dataloader
        # Populated by `export()`; `serialize()` saves this object.
        self.torchscript_model = None

    def export(self) -> torch.jit.ScriptModule:
        """Traces pytorch model and returns `ScriptModule` model

        The first batch drawn from `self.dataloader` is used as the example
        input. A bare-tensor batch is treated as a single positional input;
        tuple/list batches are passed as positional inputs in order. Tensor
        inputs are moved onto the model's device before tracing.

        The traced module is stored on `self.torchscript_model` and returned.
        Raises `StopIteration` if the dataloader yields no batches.
        """
        batch_input = next(iter(self.dataloader))
        if isinstance(batch_input, torch.Tensor):
            # `tuple(tensor)` would split the batch along dim 0 into one
            # argument per sample; a bare tensor is a single model input.
            batch_input = (batch_input,)

        # `__init__` may have moved the model to CUDA, but nothing here moves
        # the batch, so align tensor inputs with the model's parameters.
        # Models without parameters are assumed to run on CPU.
        device = next(
            (param.device for param in self.model.parameters()), torch.device("cpu")
        )
        example_inputs = tuple(
            item.to(device) if isinstance(item, torch.Tensor) else item
            for item in batch_input
        )
        self.torchscript_model = torch.jit.trace(self.model, example_inputs)
        return self.torchscript_model

    def serialize(self, file_path: Union[Path, str]):
        """Serialize and save model

        * **file_path** - String or `Path` file path to save serialized torchscript model.
          Missing parent directories are created. If `export()` has not been
          called yet, nothing is written and a warning is logged.
        """
        file_path = Path(file_path)

        # Explicit `is None`: a module's truthiness is not a reliable
        # "has been exported" check.
        if self.torchscript_model is None:
            logger.warning(
                "No torchscript model to serialize; call export() first. "
                "Nothing saved to %s",
                file_path,
            )
            return

        file_path.parent.mkdir(parents=True, exist_ok=True)
        self.torchscript_model.save(str(file_path))

export(self)

Traces the PyTorch model and returns the resulting ScriptModule

Source code in fastnn/exporting.py
def export(self) -> torch.jit.ScriptModule:
    """Traces pytorch model and returns `ScriptModule` model

    The first batch drawn from `self.dataloader` is used as the example
    input. A bare-tensor batch is treated as a single positional input;
    tuple/list batches are passed as positional inputs in order. Tensor
    inputs are moved onto the model's device before tracing.

    The traced module is stored on `self.torchscript_model` and returned.
    Raises `StopIteration` if the dataloader yields no batches.
    """
    batch_input = next(iter(self.dataloader))
    if isinstance(batch_input, torch.Tensor):
        # `tuple(tensor)` would split the batch along dim 0 into one
        # argument per sample; a bare tensor is a single model input.
        batch_input = (batch_input,)

    # The model may live on CUDA while the batch does not, so align tensor
    # inputs with the model's parameters. Models without parameters are
    # assumed to run on CPU.
    device = next(
        (param.device for param in self.model.parameters()), torch.device("cpu")
    )
    example_inputs = tuple(
        item.to(device) if isinstance(item, torch.Tensor) else item
        for item in batch_input
    )
    self.torchscript_model = torch.jit.trace(self.model, example_inputs)
    return self.torchscript_model

serialize(self, file_path)

Serialize and save model

  • file_path - String or Path file path at which to save the serialized TorchScript model
Source code in fastnn/exporting.py
def serialize(self, file_path: Union[Path, str]):
    """Serialize and save model

    * **file_path** - String or `Path` file path to save serialized torchscript model.
      Missing parent directories are created. If `export()` has not been
      called yet, nothing is written and a warning is logged.
    """
    file_path = Path(file_path)

    # Explicit `is None`: a module's truthiness is not a reliable
    # "has been exported" check.
    if self.torchscript_model is None:
        logger.warning(
            "No torchscript model to serialize; call export() first. "
            "Nothing saved to %s",
            file_path,
        )
        return

    file_path.parent.mkdir(parents=True, exist_ok=True)
    self.torchscript_model.save(str(file_path))