diff --git a/CHANGELOG.md b/CHANGELOG.md index b3cadca2..351247db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,23 +1,22 @@ ## Added -1. C++ and CUDA bindings for `memtorch.bh.crossbar.Tile.tile_matmul`. - -Using an NVIDIA GeForce GTX 1080, a tile shape of (25, 25), and two tensors of size (500, 500), the runtime of `tile_matmul` without quantization support is reduced by 2.45x and 5.48x, for CPU-bound and GPU-bound operation, respectively. With an ADC resolution of 4 bits and an overflow rate of 0.0, the runtime of `tile_matmul` with quantization support is reduced by 2.30x and 105.27x, for CPU-bound and GPU-bound operation, respectively. - -| Implementation | Runtime Without Quantization Support (s) | Runtime With Quantization Support (s) | -| ---------------------- | ---------------------------------------- | ------------------------------------- | -| Pure Python (Previous) | 6.917784 | 27.099764 | -| C++ (CPU-bound) | 2.822265 | 11.736974 | -| CUDA (GPU-bound) | 1.262861 | 0.2574267 | - -3. `Eigen` integration with C++ and CUDA bindings. -4. Additional unit tests. +1. Added another version of the Data Driven Model defined using `memtorch.bh.memrsitor.Data_Driven2021`. +2. Added CPU- and GPU-bound C++ bindings for `gen_tiles`. +3. Exposed `use_bindings`. +4. Added unit tests for `use_bindings`. +5. Added `exemptAssignees` tag to `scale.yml`. +6. Created `memtorch.map.Input` to encapsulate customizable input scaling methods. +7. Added the `force_scale` input argument to the default scaling method to specify whether inputs are force scaled if they do not exceed `max_input_voltage`. +8. Added CPU and GPU bindings for `tiled_inference`. ## Enhanced -1. Modularized C++ and CUDA `quantize` bindings. -2. Enhanced functionality of `naive_progam` and added additional input arguments to dictate logic for stuck devices. +1. Modularized input scaling logic for all layer types. +2. Modularized `tile_inference` for all layer types. +3. Updated ReadTheDocs documentation. ## Fixed -1. Removed debugging code from `naive_progam`. +1. Fixed GitHub Action Workflows for external pull requests. +2. Fixed error raised by `memtorch.map.Parameter` when `p_l` is defined. +3. Fixed semantic error in `memtorch.cpp.gen_tiles`. diff --git a/docs/conf.py b/docs/conf.py index cf9fe943..13b92f5d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,7 @@ author = "Corey Lammie" # The full version, including alpha/beta/rc tags -release = "1.1.2" +release = "1.1.3" autodoc_inherit_docstrings = False # -- General configuration --------------------------------------------------- @@ -72,3 +72,5 @@ html_css_files = [ "my_theme.css", ] + +pygments_style = "autumn" diff --git a/docs/index.rst b/docs/index.rst index 6b4c6153..fbaf3759 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,12 +5,21 @@ :github_url: https://github.com/coreylammie/MemTorch -MemTorch documentation +MemTorch ==================================== `MemTorch `_ is a simulation framework for memristive deep learning systems that integrates directly with the well-known PyTorch Machine Learning (ML) library. +MemTorch is formally described in *MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems*, which is openly accessible `here `_. + +.. image:: https://raw.githubusercontent.com/coreylammie/MemTorch/master/overview.svg?raw=True + +The best place to get started is `here `__. + +Documentation +==================================== +We provide documentation in the form of a complete Python API, and numerous interactive tutorials. In addition, a Gitter chatroom is avaliable for discussions: .. toctree:: - :maxdepth: 4 + :maxdepth: 3 memtorch tutorials diff --git a/docs/memtorch.bh.memristor.rst b/docs/memtorch.bh.memristor.rst index 1aae7dae..d0d79e7e 100644 --- a/docs/memtorch.bh.memristor.rst +++ b/docs/memtorch.bh.memristor.rst @@ -1,6 +1,6 @@ memtorch.bh.memristor ===================== -Submodule containing various behavioral memristor models, that extend :ref:`base-class-label`. +Submodule containing various behavioral memristor models, that extend :class:`memtorch.bh.memristor.Memristor`. .. automodule:: memtorch.bh.memristor.window :members: @@ -9,12 +9,15 @@ Submodule containing various behavioral memristor models, that extend :ref:`base memtorch.bh.memristor.Memristor ------------------------------- +Base class used to model memristive device behavior. .. automodule:: memtorch.bh.memristor.Memristor :members: :undoc-members: :show-inheritance: +Currently supported memristor models are listed below: + memtorch.bh.memristor.LinearIonDrift ------------------------------------ @@ -39,6 +42,14 @@ memtorch.bh.memristor.Data_Driven :undoc-members: :show-inheritance: +memtorch.bh.memristor.Data_Driven2021 +--------------------------------- + +.. automodule:: memtorch.bh.memristor.Data_Driven2021 + :members: + :undoc-members: + :show-inheritance: + memtorch.bh.memristor.Stanford_PKU ---------------------------------- diff --git a/docs/memtorch.bh.nonideality.rst b/docs/memtorch.bh.nonideality.rst index b4c06841..5de78ecc 100644 --- a/docs/memtorch.bh.nonideality.rst +++ b/docs/memtorch.bh.nonideality.rst @@ -1,12 +1,40 @@ memtorch.bh.nonideality ======================= -Submodule containing various models, which can be used to introduce various non-ideal device characteristics using `memtorch.bh.nonideality.NonIdeality.apply_nonidealities`. - -.. toctree:: - memtorch.bh.nonideality.endurance_retention_models +Submodule containing various models, which can be used to introduce various non-ideal device characteristics using :class:`memtorch.bh.nonideality.NonIdeality.apply_nonidealities`. memtorch.bh.nonideality.NonIdeality ----------------------------------- +Class used to introduce/model non-ideal device and circuit characteristics. :class:`patched_model.apply_nonidealities` is commonly used to introduce such characteristics, as demonstrated by the following example: + +.. code-block:: python + + import copy + import Net + from memtorch.mn.Module import patch_model + from memtorch.map.Parameter import naive_map + from memtorch.map.Input import naive_scale + + model = Net() + reference_memristor = memtorch.bh.memristor.VTEAM + patched_model = patch_model(copy.deepcopy(model), + memristor_model=reference_memristor, + memristor_model_params={}, + module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d], + mapping_routine=naive_map, + transistor=True, + programming_routine=None, + tile_shape=(128, 128), + max_input_voltage=0.3, + scaling_routine=naive_scale, + ADC_resolution=8, + ADC_overflow_rate=0., + quant_method='linear') + # Example usage of memtorch.bh.nonideality.NonIdeality.DeviceFaults + patched_model = patched_model.apply_nonidealities(patched_model, + non_idealities=[memtorch.bh.nonideality.NonIdeality.DeviceFaults], + lrs_proportion=0.25, + hrs_proportion=0.10, + electroform_proportion=0) .. automodule:: memtorch.bh.nonideality.NonIdeality :members: @@ -15,7 +43,6 @@ memtorch.bh.nonideality.NonIdeality memtorch.bh.nonideality.FiniteConductanceStates ----------------------------------------------- -Used to model a finite number of conductance states. .. automodule:: memtorch.bh.nonideality.FiniteConductanceStates :members: @@ -24,6 +51,7 @@ Used to model a finite number of conductance states. memtorch.bh.nonideality.DeviceFaults ------------------------------------ +Methods used to model device faults. .. automodule:: memtorch.bh.nonideality.DeviceFaults :members: @@ -52,4 +80,21 @@ memtorch.bh.nonideality.Retention .. automodule:: memtorch.bh.nonideality.Retention :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: + +For both :class:`memtorch.bh.nonideality.Endurance` and :class:`memtorch.bh.nonideality.Retention`, the following internal endurance and retention models are natively supported: + +memtorch.bh.nonideality.endurance_retention_models.conductance_drift +-------------------------------------------------------------------- +.. automodule:: memtorch.bh.nonideality.endurance_retention_models.conductance_drift + :members: + :undoc-members: + :show-inheritance: + +memtorch.bh.nonideality.endurance_retention_models.empirical_metal_oxide_RRAM +----------------------------------------------------------------------------- +.. automodule:: memtorch.bh.nonideality.endurance_retention_models.empirical_metal_oxide_RRAM + :members: + :undoc-members: + :show-inheritance: + \ No newline at end of file diff --git a/docs/memtorch.bh.rst b/docs/memtorch.bh.rst index b942b57b..444c3b3a 100644 --- a/docs/memtorch.bh.rst +++ b/docs/memtorch.bh.rst @@ -1,41 +1,60 @@ memtorch.bh =========== -Submodule containing various behavioral models. +Submodule containing various memristive device behavioral models and methods to simualte non-ideal device and circuit behavior. -.. toctree:: - memtorch.bh.memristor - memtorch.bh.nonideality +memtorch.bh.memristor +--------------------- +All memristor models and window functions are encapsulated and documented in :doc:`memtorch.bh.memristor <../memtorch.bh.memristor>`. + +memtorch.bh.nonideality +----------------------- +All non-idealities modelled by MemTorch are encapsulated and documented in :doc:`memtorch.bh.nonideality <../memtorch.bh.nonideality>`. memtorch.bh.crossbar.Crossbar ----------------------------- -Class used to model memristor crossbars. +Class used to model memristor crossbars and to manage modular crossbar tiles. + +.. code-block:: python + + import torch + import memtorch + + crossbar = memtorch.bh.crossbar.Crossbar(memtorch.bh.memristor.VTEAM, + {"r_on": 1e2, "r_off": 1e4}, + shape=(100, 100), + tile_shape=(64, 64)) + crossbar.write_conductance_matrix(torch.zeros(100, 100).uniform_(1e-2, 1e-4), transistor=True) + crossbar.devices[0][0][0].set_conductance(1e-4) + crossbar.update(from_devices=True, parallelize=True) + +.. note:: + **use_bindings** is enabled by default, to accelerate operation using C++/CUDA (if supported) bindings. .. automodule:: memtorch.bh.crossbar.Crossbar :members: :undoc-members: :show-inheritance: -memtorch.bh.crossbar.Tile -------------------------- -Class used to create modular crossbar tiles to represent 2D matrices. +memtorch.bh.crossbar.Program +---------------------------- +Methods to program (alter) the conductance devices within a crossbar or modular crossbar tiles. -.. automodule:: memtorch.bh.crossbar.Tile +.. automodule:: memtorch.bh.crossbar.Program :members: :undoc-members: :show-inheritance: -memtorch.bh.crossbar.Program ----------------------------- -Methods to program (alter) the conductance devices within a crossbar. +memtorch.bh.crossbar.Tile +------------------------- -.. automodule:: memtorch.bh.crossbar.Program +.. automodule:: memtorch.bh.crossbar.Tile :members: :undoc-members: :show-inheritance: memtorch.bh.Quantize -------------------- -Wrapper for the pytorch-playground quant.py script. +Wrapper for C++ quantization bindings. .. automodule:: memtorch.bh.Quantize :members: @@ -44,7 +63,19 @@ Wrapper for the pytorch-playground quant.py script. memtorch.bh.StochasticParameter ------------------------------- -Methods to model stochastic parameters. +Methods to model stochastic parameters. + +**memtorch.bh.StochasticParameter** is most commonly used to define stochastic parameters when defining behavioural memristor models, as follows: + +.. code-block:: python + + import torch + import memtorch + + crossbar = memtorch.bh.crossbar.Crossbar(memtorch.bh.memristor.VTEAM, + {"r_on": memtorch.bh.StochasticParameter(min=1e3, max=1e2), "r_off": 1e4}, + shape=(100, 100), + tile_shape=(64, 64)) .. automodule:: memtorch.bh.StochasticParameter :members: diff --git a/docs/memtorch.cpp.rst b/docs/memtorch.cpp.rst new file mode 100644 index 00000000..e69de29b diff --git a/docs/memtorch.cu.rst b/docs/memtorch.cu.rst new file mode 100644 index 00000000..e69de29b diff --git a/docs/memtorch.map.rst b/docs/memtorch.map.rst index 44dab0dc..17a25e97 100644 --- a/docs/memtorch.map.rst +++ b/docs/memtorch.map.rst @@ -1,10 +1,54 @@ memtorch.map ============ -Submodule containing various mapping algorithms. +Submodule containing various mapping, scaling, and encoding methods. + +memtorch.map.Input +------------------- +Encapsulates internal methods to encode (scale) input values as bit-line voltages. Methods can either be specified when converting individual layers: + +.. code-block:: python + + from memtorch.map.Input import naive_scale + + m = memtorch.mn.Linear(torch.nn.Linear(10, 10), + memtorch.bh.memristor.VTEAM, + {}, + tile_shape=(64, 64), + scaling_routine=naive_scale) + +or when converting :class:`torch.nn.Module` instances: + +.. code-block:: python + + import copy + from memtorch.mn.Module import patch_model + from memtorch.map.Input import naive_scale + import Net + + model = Net() + patched_model = patch_model(copy.deepcopy(model), + memtorch.bh.memristor.VTEAM, + {}, + module_parameters_to_patch=[torch.nn.Linear], + scaling_routine=naive_scale) + +.. automodule:: memtorch.map.Input + :members: + :undoc-members: + :show-inheritance: + +.. note:: + **force_scale** is used to specify whether inputs smaller than or equal to **max_input_voltage** are scaled or not. memtorch.map.Module ------------------- -Methods to determine relationships between a memristive crossbar and the output for a given memristive module. +Encapsulates internal methods to determine relationships between readout currents of memristive crossbars and desired outputs. + +.. warning:: + Currently, only **naive_tune** is supported. In a future release, externally-defined methods will be supported. + + + .. automodule:: memtorch.map.Module :members: @@ -13,7 +57,33 @@ Methods to determine relationships between a memristive crossbar and the output memtorch.map.Parameter ---------------------- -Methods to naively map network parameters to memristive device conductance's. +Encapsulates internal methods to naively map network parameters to memristive device conductance values. Methods can either be specified when converting individual layers: + +.. code-block:: python + + from memtorch.map.Parameter import naive_map + + m = memtorch.mn.Linear(torch.nn.Linear(10, 10), + memtorch.bh.memristor.VTEAM, + {}, + tile_shape=(64, 64), + mapping_routine=naive_map) + +or when converting :class:`torch.nn.Module` instances: + +.. code-block:: python + + import copy + from memtorch.mn.Module import patch_model + from memtorch.map.Parameter import naive_map + import Net + + model = Net() + patched_model = patch_model(copy.deepcopy(model), + memtorch.bh.memristor.VTEAM, + {}, + module_parameters_to_patch=[torch.nn.Linear], + mapping_routine=naive_map) .. automodule:: memtorch.map.Parameter :members: diff --git a/docs/memtorch.mn.rst b/docs/memtorch.mn.rst index fab89bfc..adfe8500 100644 --- a/docs/memtorch.mn.rst +++ b/docs/memtorch.mn.rst @@ -1,19 +1,48 @@ memtorch.mn =========== -torch.nn equivalent submodule. +Memristive `torch.nn `_ equivalent submodule. memtorch.mn.Module ------------------ -Methods to convert and tune torch.nn models. +Encapsulates :class:`memtorch.bmn.Module.patch_model`, which can be used to convert `torch.nn `_ models. + +.. code-block:: python + + import copy + import Net + from memtorch.mn.Module import patch_model + from memtorch.map.Parameter import naive_map + from memtorch.map.Input import naive_scale + + model = Net() + reference_memristor = memtorch.bh.memristor.VTEAM + patched_model = patch_model(copy.deepcopy(model), + memristor_model=reference_memristor, + memristor_model_params={}, + module_parameters_to_patch=[torch.nn.Linear, torch.nn.Conv2d], + mapping_routine=naive_map, + transistor=True, + programming_routine=None, + tile_shape=(128, 128), + max_input_voltage=0.3, + scaling_routine=naive_scale, + ADC_resolution=8, + ADC_overflow_rate=0., + quant_method='linear') + +.. warning:: + It is strongly suggested to copy the original model using **copy.deepcopy** prior to conversion, as some values are overriden by-reference. .. automodule:: memtorch.mn.Module :members: :undoc-members: :show-inheritance: +The following layer/module types are currently supported: + memtorch.mn.Linear ------------------ -torch.nn.Linear equivalent. +`torch.nn.Linear `_ equivalent. .. automodule:: memtorch.mn.Linear :members: @@ -22,7 +51,7 @@ torch.nn.Linear equivalent. memtorch.mn.Conv1d ------------------ -torch.nn.Conv1d equivalent. +`torch.nn.Conv1d `_ equivalent. .. automodule:: memtorch.mn.Conv1d :members: @@ -31,7 +60,7 @@ torch.nn.Conv1d equivalent. memtorch.mn.Conv2d ------------------ -torch.nn.Conv2d equivalent. +`torch.nn.Conv2d `_ equivalent. .. automodule:: memtorch.mn.Conv2d :members: @@ -40,7 +69,7 @@ torch.nn.Conv2d equivalent. memtorch.mn.Conv3d ------------------ -torch.nn.Conv3d equivalent. +`torch.nn.Conv3d `_ equivalent. .. automodule:: memtorch.mn.Conv3d :members: diff --git a/docs/memtorch.rst b/docs/memtorch.rst index ffa4d6f7..a0e66ce4 100644 --- a/docs/memtorch.rst +++ b/docs/memtorch.rst @@ -1,7 +1,11 @@ Python API ========== +MemTorch consists of various submodules, as defined below: + .. toctree:: memtorch.bh + memtorch.cpp + memtorch.cu memtorch.map memtorch.mn diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 3e64e4fd..ea7d2be6 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -33,4 +33,7 @@ To learn how to use MemTorch using interactive tutorials, and to reproduce simul :alt: Open In Colab :target: https://colab.research.google.com/github/coreylammie/MemTorch/blob/master/memtorch/examples/legacy/NovelSimulations.ipynb -The development of more Jupyter notebooks and tutorials is currently ongoing. \ No newline at end of file +The development of more Jupyter notebooks and tutorials is currently ongoing. + +[1] C. Lammie, W. Xiang, B. Linares-Barranco, and Azghadi, Mostafa Rahimi, “MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems,” arXiv.org, 2020. https://arxiv.org/abs/2004.10971. +‌ \ No newline at end of file diff --git a/memtorch/bh/Quantize.py b/memtorch/bh/Quantize.py index d6ff8e2e..f698a278 100644 --- a/memtorch/bh/Quantize.py +++ b/memtorch/bh/Quantize.py @@ -22,7 +22,7 @@ def quantize( Parameters ---------- - tensor : tensor + tensor : torch.Tensor Input tensor. quant : int Bit width (if quant_method is not None) or the number of discrete quantization levels (if quant_method is None). @@ -39,7 +39,7 @@ def quantize( Returns ------- - tensor + torch.Tensor Quantized tensor. """ diff --git a/memtorch/bh/crossbar/Crossbar.py b/memtorch/bh/crossbar/Crossbar.py index 069d59b6..8d85f030 100644 --- a/memtorch/bh/crossbar/Crossbar.py +++ b/memtorch/bh/crossbar/Crossbar.py @@ -31,9 +31,9 @@ class Crossbar: Memristor model. memristor_model_params: **kwargs **kwargs to instantiate the memristor model with. - shape : (int, int) + shape : int, int Shape of the crossbar. - tile_shape : (int, int) + tile_shape : int, int Tile shape to use to store weights. If None, modular tiles are not used. use_bindings : bool Used to determine if C++/CUDA bindings are used (True) or not (False). @@ -162,7 +162,7 @@ def write_conductance_matrix( conductance_matrix : torch.FloatTensor Conductance matrix to write. transistor : bool - Used to determine if a 1T1R (True) or 1R arrangement (False) is simulated. + Used to determine if a 1T1R (True) or 0T1R arrangement (False) is simulated. programming_routine Programming routine (method) to use. programming_routine_params : **kwargs @@ -253,7 +253,7 @@ def init_crossbar( Parameters ---------- - weights : torch.tensor + weights : torch.Tensor Weights to map. memristor_model : memtorch.bh.memristor.Memristor.Memristor Memristor model. @@ -271,7 +271,7 @@ def init_crossbar( If not None, the proportion of weights to retain. scheme : memtorch.bh.Scheme Scheme enum. - tile_shape : (int, int) + tile_shape : int, int Tile shape to use to store weights. If None, modular tiles are not used. use_bindings : bool Used to determine if C++/CUDA bindings are used (True) or not (False). @@ -457,15 +457,15 @@ def simulate_matmul( Parameters ---------- - input : tensor + input : torch.Tensor Scaled input tensor. crossbar : memtorch.bh.Crossbar Crossbar containing devices to simulate. nl : bool Use lookup tables rather than simulating each device (True). - tiles_map: torch.tensor + tiles_map: torch.Tensor Tiles map for devices if tile_shape is not None. - crossbar_shape : (int, int) + crossbar_shape : int, int Crossbar shape if tile_shape is not None. max_input_voltage : float Maximum input voltage used to encode inputs. If None, inputs are unbounded. @@ -480,7 +480,7 @@ def simulate_matmul( Returns ------- - torch.tensor + torch.Tensor Output tensor. """ devices = crossbar.devices diff --git a/memtorch/bh/crossbar/Tile.py b/memtorch/bh/crossbar/Tile.py index ba4eb236..7e8fc37f 100644 --- a/memtorch/bh/crossbar/Tile.py +++ b/memtorch/bh/crossbar/Tile.py @@ -18,7 +18,7 @@ class Tile: Parameters ---------- - tile_shape : (int, int) + tile_shape : int, int Tile shape to use to store weights. patch_num : int Patch number. @@ -37,7 +37,7 @@ def update_array(self, new_array): Parameters ---------- - new_array : torch.tensor + new_array : torch.Tensor New array to construct the tile with. """ if new_array.shape == self.tile_shape or new_array.shape == ( @@ -64,16 +64,16 @@ def gen_tiles(tensor, tile_shape, input=False, use_bindings=True): Parameters ---------- - tensor : torch.tensor + tensor : torch.Tensor Tensor to represent using modular crossbar tiles. - tile_shape : (int, int) + tile_shape : int, int Tile shape to use to store weights. input : bool Used to determine if a tensor is an input (True). Returns ------- - (torch.tensor, torch.tensor) + torch.Tensor, torch.Tensor Tiles and tile_map. """ if use_bindings: @@ -166,15 +166,15 @@ def tile_matmul_row( Parameters ---------- - mat_a_row_tiles : torch.tensor + mat_a_row_tiles : torch.Tensor Tiles representing a row of matrix A. - mat_a_tiles_map : torch.tensor + mat_a_tiles_map : torch.Tensor Tiles map for matrix A. - mat_b_tiles : torch.tensor + mat_b_tiles : torch.Tensor Tiles representing matrix B. - mat_b_tiles_map : torch.tensor + mat_b_tiles_map : torch.Tensor Tiles map for matrix B. - mat_b_shape : (int, int) + mat_b_shape : int, int Shape of matrix B. ADC_resolution : int ADC resolution (bit width). If None, quantization noise is not accounted for. @@ -185,7 +185,7 @@ def tile_matmul_row( Returns ------- - torch.tensor + torch.Tensor Output tensor. """ device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda") @@ -242,17 +242,17 @@ def tile_matmul( Parameters ---------- - mat_a_tiles : torch.tensor + mat_a_tiles : torch.Tensor Tiles representing matrix A. - mat_a_tiles_map : torch.tensor + mat_a_tiles_map : torch.Tensor Tiles map for matrix A. - mat_a_shape : (int, int) + mat_a_shape : int, int Shape of matrix A. - mat_b_tiles : torch.tensor + mat_b_tiles : torch.Tensor Tiles representing matrix B. - mat_b_tiles_map : torch.tensor + mat_b_tiles_map : torch.Tensor Tiles map for matrix B. - mat_b_shape : (int, int) + mat_b_shape : int, int Shape of matrix B. ADC_resolution : int ADC resolution (bit width). If None, quantization noise is not accounted for. @@ -267,7 +267,7 @@ def tile_matmul( Returns ------- - torch.tensor + torch.Tensor Output tensor. """ assert ( @@ -329,3 +329,73 @@ def tile_matmul( quant_method, ) return result + + +def tiled_inference(input, m): + """Method to perform tiled inference. + + Parameters + ---------- + input : torch.Tensor + Input tensor (2-D). + m : memtorch.mn + Memristive MemTorch layer. + + Returns + ------- + torch.Tensor + Output tensor. + """ + tiles_map = m.crossbars[0].tiles_map + crossbar_shape = (m.crossbars[0].rows, m.crossbars[0].columns) + if m.use_bindings: + quant_method = m.quant_method + if quant_method is None: + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + m.crossbars[0].tiles_map, + (m.crossbars[0].rows, m.crossbars[0].columns), + ) + else: + assert ( + quant_method in memtorch.bh.Quantize.quant_methods + ), "quant_method is invalid." + return memtorch_bindings.tiled_inference( + input, + input.shape, + m.tile_shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + tiles_map, + crossbar_shape, + m.ADC_resolution, + m.ADC_overflow_rate, + memtorch.bh.Quantize.quant_methods.index(quant_method), + ) + else: + (input_tiles, input_tiles_map) = gen_tiles( + input, + m.tile_shape, + input=True, + use_bindings=False, + ) + return tile_matmul( + input_tiles, + input_tiles_map, + input.shape, + m.crossbar_operation( + m.crossbars, lambda crossbar: crossbar.conductance_matrix + ), + tiles_map, + crossbar_shape, + m.ADC_resolution, + m.ADC_overflow_rate, + m.quant_method, + use_bindings=False, + ) diff --git a/memtorch/bh/nonideality/NonLinear.py b/memtorch/bh/nonideality/NonLinear.py index 6a50aa1c..fa42bc9a 100644 --- a/memtorch/bh/nonideality/NonLinear.py +++ b/memtorch/bh/nonideality/NonLinear.py @@ -16,7 +16,7 @@ def apply_non_linear( num_conductance_states=None, simulate=False, ): - """Method to model non_linear iv characteristics for devices within a memristive layer. + """Method to model non_linear I/V characteristics for devices within a memristive layer. Parameters ---------- @@ -25,12 +25,12 @@ def apply_non_linear( sweep_duration : float Voltage sweep duration (s). sweep_voltage_signal_amplitude : float - Voltage sweep amplitude (v). + Voltage sweep amplitude (V). sweep_voltage_signal_frequency : float Voltage sweep frequency (Hz). - num_conductance_states : int + num_conductance_states : int, optional Number of finite conductance states to model. None indicates finite states are not to be modeled. - simulate : bool + simulate : bool, optional Each device is simulated during inference (True). Returns diff --git a/memtorch/cpp/bindings.cpp b/memtorch/cpp/bindings.cpp index 066c6781..793cc054 100644 --- a/memtorch/cpp/bindings.cpp +++ b/memtorch/cpp/bindings.cpp @@ -3,15 +3,18 @@ #include #include "gen_tiles.h" +#include "inference.h" #include "quantize.h" #include "tile_matmul.h" void quantize_bindings(py::module_ &); void gen_tiles_bindings(py::module_ &); void tile_matmul_bindings(py::module_ &); +void inference_bindings(py::module_ &); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { quantize_bindings(m); gen_tiles_bindings(m); tile_matmul_bindings(m); + inference_bindings(m); } \ No newline at end of file diff --git a/memtorch/cpp/gen_tiles.h b/memtorch/cpp/gen_tiles.h index 7aebe51e..6158e27a 100644 --- a/memtorch/cpp/gen_tiles.h +++ b/memtorch/cpp/gen_tiles.h @@ -1 +1,4 @@ -void gen_tiles_bindings(py::module_ &m); \ No newline at end of file +void gen_tiles_bindings(py::module_ &m); +std::tuple +gen_tiles(at::Tensor tensor, int tile_shape[2], bool input, + torch::TensorOptions tensor_options); \ No newline at end of file diff --git a/memtorch/cpp/inference.cpp b/memtorch/cpp/inference.cpp new file mode 100644 index 00000000..f23d0b66 --- /dev/null +++ b/memtorch/cpp/inference.cpp @@ -0,0 +1,81 @@ +#include +#include +#include + +#include "gen_tiles.h" +#include "tile_matmul.h" + +at::Tensor tiled_inference(at::Tensor input, int input_shape[2], + int tile_shape[2], at::Tensor weight_tiles, + at::Tensor weight_tiles_map, int weight_shape[2]) { + at::Tensor input_tiles; + at::Tensor input_tiles_map; + std::tie(input_tiles, input_tiles_map) = gen_tiles( + input, tile_shape, true, torch::TensorOptions().device(torch::kCPU)); + return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, + weight_tiles_map, weight_shape); +} + +at::Tensor tiled_inference(at::Tensor input, int input_shape[2], + int tile_shape[2], at::Tensor weight_tiles, + at::Tensor weight_tiles_map, int weight_shape[2], + int ADC_resolution, float ADC_overflow_rate, + int quant_method) { + at::Tensor input_tiles; + at::Tensor input_tiles_map; + std::tie(input_tiles, input_tiles_map) = gen_tiles( + input, tile_shape, true, torch::TensorOptions().device(torch::kCPU)); + return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, + weight_tiles_map, weight_shape, ADC_resolution, + ADC_overflow_rate, quant_method); +} + +void inference_bindings(py::module_ &m) { + // Binding without quantization support + m.def( + "tiled_inference", + [](at::Tensor input, std::tuple input_shape, + std::tuple tile_shape, at::Tensor weight_tiles, + at::Tensor weight_tiles_map, std::tuple weight_shape) { + assert((std::tuple_size(input_shape) == 2)); + assert((std::tuple_size(tile_shape) == 2)); + assert((std::tuple_size(weight_shape) == 3)); + int input_shape_array[2] = {(int)std::get<0>(input_shape), + (int)std::get<1>(input_shape)}; + int tile_shape_array[2] = {(int)std::get<0>(tile_shape), + (int)std::get<1>(tile_shape)}; + int weight_shape_array[2] = {(int)std::get<0>(weight_shape), + (int)std::get<1>(weight_shape)}; + return tiled_inference(input, input_shape_array, tile_shape_array, + weight_tiles, weight_tiles_map, + weight_shape_array); + }, + py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), + py::arg("weight_tiles"), py::arg("weight_tiles_map"), + py::arg("weight_shape")); + // Binding with quantization support + m.def( + "tiled_inference", + [](at::Tensor input, std::tuple input_shape, + std::tuple tile_shape, at::Tensor weight_tiles, + at::Tensor weight_tiles_map, std::tuple weight_shape, + int ADC_resolution, float ADC_overflow_rate, int quant_method) { + assert((std::tuple_size(input_shape) == 2)); + assert((std::tuple_size(tile_shape) == 2)); + assert((std::tuple_size(weight_shape) == 3)); + int input_shape_array[2] = {(int)std::get<0>(input_shape), + (int)std::get<1>(input_shape)}; + int tile_shape_array[2] = {(int)std::get<0>(tile_shape), + (int)std::get<1>(tile_shape)}; + int weight_shape_array[2] = {(int)std::get<0>(weight_shape), + (int)std::get<1>(weight_shape)}; + return tiled_inference(input, input_shape_array, tile_shape_array, + weight_tiles, weight_tiles_map, + weight_shape_array, ADC_resolution, + ADC_overflow_rate, quant_method); + }, + py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), + py::arg("weight_tiles"), py::arg("weight_tiles_map"), + py::arg("weight_shape"), py::arg("ADC_resolution"), + py::arg("ADC_overflow_rate"), py::arg("quant_method")); +} \ No newline at end of file diff --git a/memtorch/cpp/inference.h b/memtorch/cpp/inference.h new file mode 100644 index 00000000..f36363e2 --- /dev/null +++ b/memtorch/cpp/inference.h @@ -0,0 +1 @@ +void inference_bindings(py::module_ &m); \ No newline at end of file diff --git a/memtorch/cpp/tile_matmul.cpp b/memtorch/cpp/tile_matmul.cpp index be39ef09..88e86608 100644 --- a/memtorch/cpp/tile_matmul.cpp +++ b/memtorch/cpp/tile_matmul.cpp @@ -64,6 +64,7 @@ at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, } void tile_matmul_bindings(py::module_ &m) { + // Binding without quantization support m.def( "tile_matmul", [](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, @@ -83,6 +84,7 @@ void tile_matmul_bindings(py::module_ &m) { py::arg("mat_a_shape"), py::arg("mat_b_tiles"), py::arg("mat_b_tiles_map"), py::arg("mat_b_shape"), py::arg("cuda_malloc_heap_size") = NULL); + // Binding with quantization support m.def( "tile_matmul", [](at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, diff --git a/memtorch/cpp/tile_matmul.h b/memtorch/cpp/tile_matmul.h index 42f7ff2a..76660755 100644 --- a/memtorch/cpp/tile_matmul.h +++ b/memtorch/cpp/tile_matmul.h @@ -1 +1,9 @@ -void tile_matmul_bindings(py::module_ &m); \ No newline at end of file +void tile_matmul_bindings(py::module_ &m); +at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + int mat_a_shape[2], at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, int mat_b_shape[2]); +at::Tensor tile_matmul(at::Tensor mat_a_tiles, at::Tensor mat_a_tiles_map, + int mat_a_shape[2], at::Tensor mat_b_tiles, + at::Tensor mat_b_tiles_map, int mat_b_shape[2], + int ADC_resolution, float ADC_overflow_rate, + int quant_method); \ No newline at end of file diff --git a/memtorch/cu/bindings.cpp b/memtorch/cu/bindings.cpp index ec7956db..20fbe44c 100644 --- a/memtorch/cu/bindings.cpp +++ b/memtorch/cu/bindings.cpp @@ -3,12 +3,15 @@ #include #include "gen_tiles.h" +#include "inference.h" #include "tile_matmul.h" void tile_matmul_bindings(py::module_ &); void gen_tiles_bindings_gpu(py::module_ &); +void inference_bindings(py::module_ &); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { gen_tiles_bindings_gpu(m); tile_matmul_bindings(m); + inference_bindings(m); } \ No newline at end of file diff --git a/memtorch/cu/gen_tiles.h b/memtorch/cu/gen_tiles.h index 6883ff8a..956a9346 100644 --- a/memtorch/cu/gen_tiles.h +++ b/memtorch/cu/gen_tiles.h @@ -1 +1,4 @@ -void gen_tiles_bindings_gpu(py::module_ &m); \ No newline at end of file +void gen_tiles_bindings_gpu(py::module_ &m); +std::tuple +gen_tiles(at::Tensor tensor, int tile_shape[2], bool input, + torch::TensorOptions tensor_options); \ No newline at end of file diff --git a/memtorch/cu/inference.cpp b/memtorch/cu/inference.cpp new file mode 100644 index 00000000..8ee43420 --- /dev/null +++ b/memtorch/cu/inference.cpp @@ -0,0 +1,86 @@ +#include +#include +#include + +#include "gen_tiles.h" +#include "tile_matmul_kernels.cuh" + +at::Tensor tiled_inference(at::Tensor input, int input_shape[2], + int tile_shape[2], at::Tensor weight_tiles, + at::Tensor weight_tiles_map, int weight_shape[2], + int cuda_malloc_heap_size) { + at::Tensor input_tiles; + at::Tensor input_tiles_map; + std::tie(input_tiles, input_tiles_map) = gen_tiles( + input, tile_shape, true, torch::TensorOptions().device(torch::kCUDA, 0)); + return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, + weight_tiles_map, weight_shape, NULL, NULL, -1, + cuda_malloc_heap_size); +} + +at::Tensor tiled_inference(at::Tensor input, int input_shape[2], + int tile_shape[2], at::Tensor weight_tiles, + at::Tensor weight_tiles_map, int weight_shape[2], + int ADC_resolution, float ADC_overflow_rate, + int quant_method, int cuda_malloc_heap_size) { + at::Tensor input_tiles; + at::Tensor input_tiles_map; + std::tie(input_tiles, input_tiles_map) = gen_tiles( + input, tile_shape, true, torch::TensorOptions().device(torch::kCUDA, 0)); + return tile_matmul(input_tiles, input_tiles_map, input_shape, weight_tiles, + weight_tiles_map, weight_shape, ADC_resolution, + ADC_overflow_rate, quant_method, cuda_malloc_heap_size); +} + +void inference_bindings(py::module_ &m) { + // Binding without quantization support + m.def( + "tiled_inference", + [](at::Tensor input, std::tuple input_shape, + std::tuple tile_shape, at::Tensor weight_tiles, + at::Tensor weight_tiles_map, std::tuple weight_shape, + int cuda_malloc_heap_size) { + assert((std::tuple_size(input_shape) == 2)); + assert((std::tuple_size(tile_shape) == 2)); + assert((std::tuple_size(weight_shape) == 3)); + int input_shape_array[2] = {(int)std::get<0>(input_shape), + (int)std::get<1>(input_shape)}; + int tile_shape_array[2] = {(int)std::get<0>(tile_shape), + (int)std::get<1>(tile_shape)}; + int weight_shape_array[2] = {(int)std::get<0>(weight_shape), + (int)std::get<1>(weight_shape)}; + return tiled_inference(input, input_shape_array, tile_shape_array, + weight_tiles, weight_tiles_map, + weight_shape_array, cuda_malloc_heap_size); + }, + py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), + py::arg("weight_tiles"), py::arg("weight_tiles_map"), + py::arg("weight_shape"), py::arg("cuda_malloc_heap_size") = 50); + // Binding with quantization support + m.def( + "tiled_inference", + [](at::Tensor input, std::tuple input_shape, + std::tuple tile_shape, at::Tensor weight_tiles, + at::Tensor weight_tiles_map, std::tuple weight_shape, + int ADC_resolution, float ADC_overflow_rate, int quant_method, + int cuda_malloc_heap_size) { + assert((std::tuple_size(input_shape) == 2)); + assert((std::tuple_size(tile_shape) == 2)); + assert((std::tuple_size(weight_shape) == 3)); + int input_shape_array[2] = {(int)std::get<0>(input_shape), + (int)std::get<1>(input_shape)}; + int tile_shape_array[2] = {(int)std::get<0>(tile_shape), + (int)std::get<1>(tile_shape)}; + int weight_shape_array[2] = {(int)std::get<0>(weight_shape), + (int)std::get<1>(weight_shape)}; + return tiled_inference( + input, input_shape_array, tile_shape_array, weight_tiles, + weight_tiles_map, weight_shape_array, ADC_resolution, + ADC_overflow_rate, quant_method, cuda_malloc_heap_size); + }, + py::arg("input"), py::arg("input_shape"), py::arg("tile_shape"), + py::arg("weight_tiles"), py::arg("weight_tiles_map"), + py::arg("weight_shape"), py::arg("ADC_resolution"), + py::arg("ADC_overflow_rate"), py::arg("quant_method"), + py::arg("cuda_malloc_heap_size") = 50); +} \ No newline at end of file diff --git a/memtorch/cu/inference.h b/memtorch/cu/inference.h new file mode 100644 index 00000000..f36363e2 --- /dev/null +++ b/memtorch/cu/inference.h @@ -0,0 +1 @@ +void inference_bindings(py::module_ &m); \ No newline at end of file diff --git a/memtorch/examples/Tutorial.ipynb b/memtorch/examples/Tutorial.ipynb index 2d9ee042..0d1fd97b 100644 --- a/memtorch/examples/Tutorial.ipynb +++ b/memtorch/examples/Tutorial.ipynb @@ -2,16 +2,17 @@ "cells": [ { "cell_type": "markdown", - "metadata": {}, "source": [ "# MemTorch Tutorial\n", "## Introduction\n", "In this tutorial, you will learn how to use MemTorch to convert Deep Neural Networks (DNNs) to Memristive Deep Neural Networks (MDNNs), and how to simulate non-ideal device characteristics and key peripheral circuitry. MemTorch is a Simulation Framework for Memristive Deep Learning Systems, which integrates directly with the well-known PyTorch Machine Learning (ML) library. MemTorch is formally described in *MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems*, which is openly accessible [here](https://arxiv.org/abs/2004.10971).\n", "\n", "![Overview](https://raw.githubusercontent.com/coreylammie/MemTorch/master/overview.svg)\n" - ] + ], + "metadata": {} }, { + "cell_type": "markdown", "source": [ "## 1. Installation\n", "MemTorch can be installed from source using `python setup.py install`:\n", @@ -43,59 +44,58 @@ "\n", "MemTorch can be installed using Jupyter notebooks as follows:" ], - "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "# Installation of MemTorch (with CUDA functionality) from source using pip\n", - "!git clone --recursive https://github.com/coreylammie/MemTorch\n", - "%cd MemTorch\n", - "!sed -i 's/CUDA = False/CUDA = True/g' setup.py\n", + "# Installation of MemTorch (with CUDA functionality) from source using pip\r\n", + "!git clone --recursive https://github.com/coreylammie/MemTorch\r\n", + "%cd MemTorch\r\n", + "!sed -i 's/CUDA = False/CUDA = True/g' setup.py\r\n", "!pip install ." - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "# Installation of MemTorch (without CUDA functionality) from source using pip\n", - "!git clone --recursive https://github.com/coreylammie/MemTorch\n", - "%cd MemTorch\n", + "# Installation of MemTorch (without CUDA functionality) from source using pip\r\n", + "!git clone --recursive https://github.com/coreylammie/MemTorch\r\n", + "%cd MemTorch\r\n", "!pip install ." - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "# Installation of MemTorch (with CUDA functionality) using pip\n", + "# Installation of MemTorch (with CUDA functionality) using pip\r\n", "!pip install memtorch" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "# Installation of MemTorch (without CUDA functionality) using pip\n", + "# Installation of MemTorch (without CUDA functionality) using pip\r\n", "!pip install memtorch-cpu" - ] + ], + "outputs": [], + "metadata": {} }, { + "cell_type": "markdown", "source": [ "## 2. Training and Benchmarking a Deep Neural Network Using MNIST" ], - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qBYEsnSav5E1" @@ -103,10 +103,6 @@ }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "HCy9QqPHv5E3" - }, "source": [ "MemTorch can currently be used to simulate the inference routines of MDNNs. Consequently, prior to conversion, DNNs must be either defined and trained using PyTorch or imported using PyTorch. \n", "\n", @@ -118,102 +114,102 @@ "* Adam is used to optimize network parameters and Cross Entropy (CE) is used to determine network losses.\n", "* `memtorch.utils.LoadMNIST` is used to load the MNIST training and test sets. After each epoch, the model is evaluated using the MNIST test set. \n", "* The model that achieves the highest test set accuracy is saved as *trained_model.pt*." - ] + ], + "metadata": { + "colab_type": "text", + "id": "HCy9QqPHv5E3" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "import torch\r\n", + "from torch.autograd import Variable\r\n", + "import memtorch\r\n", + "import torch.nn as nn\r\n", + "import torch.nn.functional as F\r\n", + "import torch.optim as optim\r\n", + "from memtorch.utils import LoadMNIST\r\n", + "import numpy as np\r\n", + "\r\n", + "class Net(nn.Module):\r\n", + " def __init__(self):\r\n", + " super(Net, self).__init__()\r\n", + " self.conv1 = nn.Conv2d(1, 20, 5, 1)\r\n", + " self.conv2 = nn.Conv2d(20, 50, 5, 1)\r\n", + " self.fc1 = nn.Linear(4*4*50, 500)\r\n", + " self.fc2 = nn.Linear(500, 10)\r\n", + "\r\n", + " def forward(self, x):\r\n", + " x = F.relu(self.conv1(x))\r\n", + " x = F.max_pool2d(x, 2, 2)\r\n", + " x = F.relu(self.conv2(x))\r\n", + " x = F.max_pool2d(x, 2, 2)\r\n", + " x = x.view(-1, 4*4*50)\r\n", + " x = F.relu(self.fc1(x))\r\n", + " x = self.fc2(x)\r\n", + " return x\r\n", + "\r\n", + "def test(model, test_loader):\r\n", + " correct = 0\r\n", + " for batch_idx, (data, target) in enumerate(test_loader): \r\n", + " output = model(data.to(device))\r\n", + " pred = output.data.max(1)[1]\r\n", + " correct += pred.eq(target.to(device).data.view_as(pred)).cpu().sum()\r\n", + "\r\n", + " return 100. * float(correct) / float(len(test_loader.dataset))\r\n", + "\r\n", + "device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')\r\n", + "epochs = 10\r\n", + "learning_rate = 1e-1\r\n", + "step_lr = 5\r\n", + "batch_size = 256\r\n", + "train_loader, validation_loader, test_loader = LoadMNIST(batch_size=batch_size, validation=False)\r\n", + "model = Net().to(device)\r\n", + "criterion = nn.CrossEntropyLoss()\r\n", + "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\r\n", + "best_accuracy = 0\r\n", + "for epoch in range(0, epochs):\r\n", + " print('Epoch: [%d]\\t\\t' % (epoch + 1), end='')\r\n", + " if epoch % step_lr == 0:\r\n", + " learning_rate = learning_rate * 0.1\r\n", + " for param_group in optimizer.param_groups:\r\n", + " param_group['lr'] = learning_rate\r\n", + "\r\n", + " model.train()\r\n", + " for batch_idx, (data, target) in enumerate(train_loader):\r\n", + " optimizer.zero_grad()\r\n", + " output = model(data.to(device))\r\n", + " loss = criterion(output, target.to(device))\r\n", + " loss.backward()\r\n", + " optimizer.step()\r\n", + "\r\n", + " accuracy = test(model, test_loader)\r\n", + " print('%2.2f%%' % accuracy)\r\n", + " if accuracy > best_accuracy:\r\n", + " torch.save(model.state_dict(), 'trained_model.pt')\r\n", + " best_accuracy = accuracy" + ], + "outputs": [], "metadata": { "colab": {}, "colab_type": "code", "id": "jVH_tu3tv5E4" - }, - "outputs": [], - "source": [ - "import torch\n", - "from torch.autograd import Variable\n", - "import memtorch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.optim as optim\n", - "from memtorch.utils import LoadMNIST\n", - "import numpy as np\n", - "\n", - "class Net(nn.Module):\n", - " def __init__(self):\n", - " super(Net, self).__init__()\n", - " self.conv1 = nn.Conv2d(1, 20, 5, 1)\n", - " self.conv2 = nn.Conv2d(20, 50, 5, 1)\n", - " self.fc1 = nn.Linear(4*4*50, 500)\n", - " self.fc2 = nn.Linear(500, 10)\n", - "\n", - " def forward(self, x):\n", - " x = F.relu(self.conv1(x))\n", - " x = F.max_pool2d(x, 2, 2)\n", - " x = F.relu(self.conv2(x))\n", - " x = F.max_pool2d(x, 2, 2)\n", - " x = x.view(-1, 4*4*50)\n", - " x = F.relu(self.fc1(x))\n", - " x = self.fc2(x)\n", - " return x\n", - "\n", - "def test(model, test_loader):\n", - " correct = 0\n", - " for batch_idx, (data, target) in enumerate(test_loader): \n", - " output = model(data.to(device))\n", - " pred = output.data.max(1)[1]\n", - " correct += pred.eq(target.to(device).data.view_as(pred)).cpu().sum()\n", - "\n", - " return 100. * float(correct) / float(len(test_loader.dataset))\n", - "\n", - "device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')\n", - "epochs = 10\n", - "learning_rate = 1e-1\n", - "step_lr = 5\n", - "batch_size = 256\n", - "train_loader, validation_loader, test_loader = LoadMNIST(batch_size=batch_size, validation=False)\n", - "model = Net().to(device)\n", - "criterion = nn.CrossEntropyLoss()\n", - "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n", - "best_accuracy = 0\n", - "for epoch in range(0, epochs):\n", - " print('Epoch: [%d]\\t\\t' % (epoch + 1), end='')\n", - " if epoch % step_lr == 0:\n", - " learning_rate = learning_rate * 0.1\n", - " for param_group in optimizer.param_groups:\n", - " param_group['lr'] = learning_rate\n", - "\n", - " model.train()\n", - " for batch_idx, (data, target) in enumerate(train_loader):\n", - " optimizer.zero_grad()\n", - " output = model(data.to(device))\n", - " loss = criterion(output, target.to(device))\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " accuracy = test(model, test_loader)\n", - " print('%2.2f%%' % accuracy)\n", - " if accuracy > best_accuracy:\n", - " torch.save(model.state_dict(), 'trained_model.pt')\n", - " best_accuracy = accuracy" - ] + } }, { "cell_type": "markdown", + "source": [ + "## 3. Conversion of a Deep Neural Network to a Memristive Deep Neural Network " + ], "metadata": { "colab_type": "text", "id": "Ag8Z6Rn_v5E8" - }, - "source": [ - "## 3. Conversion of a Deep Neural Network to a Memristive Deep Neural Network " - ] + } }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "lyt9qHvAv5E9" - }, "source": [ "Within MemTorch, `memtorch.mn.Module.patch_model` can be used to convert DNNs to a MDNNs. Prior to conversion, a memristive device model must be defined and characterized in part (prior to the introduction of other non-ideal device characteristics).\n", "\n", @@ -223,11 +219,23 @@ "* A `memtorch.bh.memristor.Memristor` object is instantiated\n", "* The hysteresis loop of the instantiated memristor object is generated/plotted.\n", "* The bipolar switching behaviour of the instantiated memristor object is generated/plotted." - ] + ], + "metadata": { + "colab_type": "text", + "id": "lyt9qHvAv5E9" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "reference_memristor = memtorch.bh.memristor.VTEAM\r\n", + "reference_memristor_params = {'time_series_resolution': 1e-10}\r\n", + "memristor = reference_memristor(**reference_memristor_params)\r\n", + "memristor.plot_hysteresis_loop()\r\n", + "memristor.plot_bipolar_switching_behaviour()" + ], + "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -236,22 +244,10 @@ "colab_type": "code", "id": "dRMGKP-lv5E-", "outputId": "f22a800e-3eae-4060-a874-23d7a218f3cf" - }, - "outputs": [], - "source": [ - "reference_memristor = memtorch.bh.memristor.VTEAM\n", - "reference_memristor_params = {'time_series_resolution': 1e-10}\n", - "memristor = reference_memristor(**reference_memristor_params)\n", - "memristor.plot_hysteresis_loop()\n", - "memristor.plot_bipolar_switching_behaviour()" - ] + } }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "b0kLErl0v5FC" - }, "source": [ "In the cell below, the trained DNN from Section 2 is converted to an equivalent MDNN, where all convolutional layers are replaced with memristive-equivalent layers. While only *Conv2d* layers are converted for demonstration purposes, we note that MemTorch currently supports conversion of *Conv1d*, *Conv2d*, *Conv3d*, and *Linear* layers. Specifically:\n", "* `memtorch.bh.map.Parameter.naive_map` is used to convert the weights within all `torch.nn.Conv2d` layers to equivalent conductance values, to be programmed to the two memristive devices used to represent each weight (positive and negative, respectively). \n", @@ -265,97 +261,98 @@ "\n", "\n", "We note if `transistor` is `False` `programming_routine` must not be `None`. In which case, device-level simulation is performed for each device using `memtorch.bh.crossbar.gen_programming_signal` and `memtorch.bh.memristor.Memristor.simulate`, which use finite differences to model internal device dynamics. As `scheme` is not defined, a double-column parameter representation scheme is adopted. Finally, `max_input_voltage` is 0.3, meaning inputs to each layer are encoded between -0.3V and +0.3V." - ] + ], + "metadata": { + "colab_type": "text", + "id": "b0kLErl0v5FC" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "import copy\r\n", + "from memtorch.mn.Module import patch_model\r\n", + "from memtorch.map.Input import naive_scale\r\n", + "from memtorch.map.Parameter import naive_map\r\n", + "\r\n", + "\r\n", + "model = Net().to(device)\r\n", + "model.load_state_dict(torch.load('trained_model.pt'), strict=False)\r\n", + "patched_model = patch_model(copy.deepcopy(model),\r\n", + " memristor_model=reference_memristor,\r\n", + " memristor_model_params=reference_memristor_params,\r\n", + " module_parameters_to_patch=[torch.nn.Conv2d],\r\n", + " mapping_routine=naive_map,\r\n", + " transistor=True,\r\n", + " programming_routine=None,\r\n", + " tile_shape=(128, 128),\r\n", + " max_input_voltage=0.3,\r\n", + " scaling_routine=naive_scale,\r\n", + " ADC_resolution=8,\r\n", + " ADC_overflow_rate=0.,\r\n", + " quant_method='linear')" + ], + "outputs": [], "metadata": { "colab": {}, "colab_type": "code", "id": "oJWSTW5Qv5FD" - }, - "outputs": [], - "source": [ - "import copy\n", - "from memtorch.mn.Module import patch_model\n", - "from memtorch.map.Parameter import naive_map\n", - "from memtorch.bh.crossbar.Program import naive_program\n", - "\n", - "\n", - "model = Net().to(device)\n", - "model.load_state_dict(torch.load('trained_model.pt'), strict=False)\n", - "patched_model = patch_model(copy.deepcopy(model),\n", - " memristor_model=reference_memristor,\n", - " memristor_model_params=reference_memristor_params,\n", - " module_parameters_to_patch=[torch.nn.Conv2d],\n", - " mapping_routine=naive_map,\n", - " transistor=True,\n", - " programming_routine=None,\n", - " tile_shape=(128, 128),\n", - " max_input_voltage=0.3,\n", - " ADC_resolution=8,\n", - " ADC_overflow_rate=0.,\n", - " quant_method='linear')" - ] + } }, { + "cell_type": "markdown", "source": [ "In the cell below, all patched `torch.nn.Conv2d` layers are tuned using linear regression. A randomly generated tensor of size (8, `self.in_channels`, 32, 32) is propagated through each memristive layer and each legacy layer (accessible using `layer.forward_legacy`). `sklearn.linear_model.LinearRegression` is used to determine the coefficient and intercept between the linear relationship of each set of outputs, which is used to define the `transform_output` lamdba function, that maps the output of each layer to their equivalent representations." ], - "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, + "source": [ + "patched_model.tune_()" + ], + "outputs": [], "metadata": { "colab": {}, "colab_type": "code", "id": "Mam3ggffv5FG" - }, - "outputs": [], - "source": [ - "patched_model.tune_()" - ] + } }, { + "cell_type": "markdown", "source": [ "Finally, in the cell below, the converted and tuned MDNN is benchmarked using the MNIST test data set. \n", "*Note: This cell may take a considerable amount of time to run.*" ], - "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, + "source": [ + "print(test(patched_model, test_loader))" + ], + "outputs": [], "metadata": { "colab": {}, "colab_type": "code", "id": "U5F0muXPv5FK" - }, - "outputs": [], - "source": [ - "print(test(patched_model, test_loader))" - ] + } }, { "cell_type": "markdown", + "source": [ + "## 4. Modeling Non-Ideal Device Characteristics" + ], "metadata": { "colab_type": "text", "id": "eV8IJSH6v5FN" - }, - "source": [ - "## 4. Modeling Non-Ideal Device Characteristics" - ] + } }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "R4H9f4d248V7" - }, "source": [ "\n", "Non-ideal device characteristics can either be encapsulated within device specific memristive models, or introduced to base (generic) models after conversion, using `memtorch.bh.nonideality.NonIdeality.apply_nonidealities`. Currently, the following non-ideal device characteristics are supported:\n", @@ -365,23 +362,27 @@ "* `memtorch.bh.nonideality.NonLinear`\n", "\n", "Stochastic parameters, used to model process variances, can be defined using `memtorch.bh.StochaticParameter`. The introduction of each type of non ideal device characteristic is demonstrated below.\n" - ] + ], + "metadata": { + "colab_type": "text", + "id": "R4H9f4d248V7" + } }, { + "cell_type": "markdown", "source": [ - "### 4.1 Modeling Device Faults\n", - "\n", - "Memristive devices are susceptible to failure, by either failing to eletroform at a pristine state, or becoming stuck at high or low resistance states. MemTorch incorporates a specific function for accounting for device failure, `memtorch.bh.nonideality.DeviceFaults`. \n", - "\n", - "In the cell below:\n", - "* The original patched model is copied using `copy.deepcopy`.\n", - "* `lrs_proportion` is set to 0.25, so that 25% of devices are assumed to fail to a low resistance state.\n", - "* `hrs_proportion` is set to 0.10, so that 15% of devices are assumed to fail to a high resistance state.\n", - "\n", - "It is assumed that the total proportion of devices set to a high resistance state is equal to the proportion of devices that fail to eletroform at pristine states plus the proportion of devices stuck at a high resistance state.\n", - "\n" + "### 4.1 Modeling Device Faults\r\n", + "\r\n", + "Memristive devices are susceptible to failure, by either failing to eletroform at a pristine state, or becoming stuck at high or low resistance states. MemTorch incorporates a specific function for accounting for device failure, `memtorch.bh.nonideality.DeviceFaults`. \r\n", + "\r\n", + "In the cell below:\r\n", + "* The original patched model is copied using `copy.deepcopy`.\r\n", + "* `lrs_proportion` is set to 0.25, so that 25% of devices are assumed to fail to a low resistance state.\r\n", + "* `hrs_proportion` is set to 0.10, so that 15% of devices are assumed to fail to a high resistance state.\r\n", + "\r\n", + "It is assumed that the total proportion of devices set to a high resistance state is equal to the proportion of devices that fail to eletroform at pristine states plus the proportion of devices stuck at a high resistance state.\r\n", + "\r\n" ], - "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -395,33 +396,29 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\n", - "\n", - "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\n", - " non_idealities=[memtorch.bh.nonideality.NonIdeality.DeviceFaults],\n", - " lrs_proportion=0.25,\n", - " hrs_proportion=0.10,\n", + "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\r\n", + "\r\n", + "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\r\n", + " non_idealities=[memtorch.bh.nonideality.NonIdeality.DeviceFaults],\r\n", + " lrs_proportion=0.25,\r\n", + " hrs_proportion=0.10,\r\n", " electroform_proportion=0)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "print(test(patched_model_, test_loader))" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "D2Vcpuuw5D_S" - }, "source": [ "### 4.2 Modeling Device Endurance and Retention\n", "\n", @@ -434,27 +431,33 @@ "* `x`, the number of SET-RESET cycles is set to be equal to 10,000.\n", "* Endurance characteristics are accounted for using `memtorch.bh.nonideality.NonIdeality.Endurance` and `memtorch.bh.nonideality.endurance_retention_models.model_endurance_retention`.\n", "* `operation_mode` within `endurance_model_kwargs` is set to `sudden`, so that sudden failure is modeled, and various other model arguments are set.\n" - ] + ], + "metadata": { + "colab_type": "text", + "id": "D2Vcpuuw5D_S" + } }, { + "cell_type": "code", + "execution_count": null, "source": [ - "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\n", - "\n", - "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\n", - " non_idealities=[memtorch.bh.nonideality.NonIdeality.Endurance],\n", - " x=1e4,\n", - " endurance_model=memtorch.bh.nonideality.endurance_retention_models.model_endurance_retention,\n", - " endurance_model_kwargs={\n", - " \"operation_mode\": memtorch.bh.nonideality.endurance_retention_models.OperationMode.sudden,\n", - " \"p_lrs\": [1, 0, 0, 0],\n", - " \"stable_resistance_lrs\": 100,\n", - " \"p_hrs\": [1, 0, 0, 0],\n", - " \"stable_resistance_hrs\": 1000,\n", - " \"cell_size\": 10,\n", - " \"temperature\": 350,\n", + "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\r\n", + "\r\n", + "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\r\n", + " non_idealities=[memtorch.bh.nonideality.NonIdeality.Endurance],\r\n", + " x=1e4,\r\n", + " endurance_model=memtorch.bh.nonideality.endurance_retention_models.model_endurance_retention,\r\n", + " endurance_model_kwargs={\r\n", + " \"operation_mode\": memtorch.bh.nonideality.endurance_retention_models.OperationMode.sudden,\r\n", + " \"p_lrs\": [1, 0, 0, 0],\r\n", + " \"stable_resistance_lrs\": 100,\r\n", + " \"p_hrs\": [1, 0, 0, 0],\r\n", + " \"stable_resistance_hrs\": 1000,\r\n", + " \"cell_size\": 10,\r\n", + " \"temperature\": 350,\r\n", " })" ], - "cell_type": "code", + "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -463,20 +466,19 @@ "colab_type": "code", "id": "4thsaVDEv5FS", "outputId": "e46021ed-b57f-4be7-f29e-15bd17ec2473" - }, - "execution_count": null, - "outputs": [] + } }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "print(test(patched_model_, test_loader))" - ] + ], + "outputs": [], + "metadata": {} }, { + "cell_type": "markdown", "source": [ "In the cell below:\n", "* The original patched model is copied using `copy.deepcopy`.\n", @@ -485,37 +487,37 @@ "* `initial_time` within `retention_model_kwargs`, the initial time, is set to be equal to 1s.\n", "* `drift_coefficient` within `retention_model_kwargs` is set to be equal to 0.1." ], - "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\n", - "\n", - "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\n", - " non_idealities=[memtorch.bh.nonideality.NonIdeality.Retention],\n", - " time=1e3,\n", - " retention_model=memtorch.bh.nonideality.endurance_retention_models.model_conductance_drift,\n", - " retention_model_kwargs={\n", - " \"initial_time\": 1,\n", - " \"drift_coefficient\": 0.1,\n", + "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\r\n", + "\r\n", + "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\r\n", + " non_idealities=[memtorch.bh.nonideality.NonIdeality.Retention],\r\n", + " time=1e3,\r\n", + " retention_model=memtorch.bh.nonideality.endurance_retention_models.model_conductance_drift,\r\n", + " retention_model_kwargs={\r\n", + " \"initial_time\": 1,\r\n", + " \"drift_coefficient\": 0.1,\r\n", " })" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "print(test(patched_model_, test_loader))" - ] + ], + "outputs": [], + "metadata": {} }, { + "cell_type": "markdown", "source": [ "### 4.3 Modeling a Finite Number of Conductance States\n", "\n", @@ -526,32 +528,32 @@ "* A finite number of conductance states are accounted for using `memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates`.\n", "* `conductance_states` is set to be equal to 5, to model 5 evenly-distributed conductance states." ], - "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\n", - "\n", - "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\n", - " non_idealities=[memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates],\n", + "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\r\n", + "\r\n", + "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\r\n", + " non_idealities=[memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates],\r\n", " conductance_states=5) " - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "print(test(patched_model_, test_loader))" - ] + ], + "outputs": [], + "metadata": {} }, { + "cell_type": "markdown", "source": [ "### 4.4 Modeling Non-Linear Device Characteristics\n", "\n", @@ -571,32 +573,32 @@ "\n", "\n" ], - "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\n", - "\n", - "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\n", - " non_idealities=[memtorch.bh.nonideality.NonIdeality.NonLinear],\n", + "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\r\n", + "\r\n", + "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\r\n", + " non_idealities=[memtorch.bh.nonideality.NonIdeality.NonLinear],\r\n", " simulate=True)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "print(test(patched_model_, test_loader))" - ] + ], + "outputs": [], + "metadata": {} }, { + "cell_type": "markdown", "source": [ "In the cell below:\n", "* The original patched model is copied using `copy.deepcopy`.\n", @@ -606,41 +608,41 @@ "* `sweep_voltage_signal_amplitude` is set to be equal to 1V.\n", "* `sweep_voltage_signal_frequency` is set to be equal to 0.5Hz.\n" ], - "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\n", - "\n", - "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\n", - " non_idealities=[memtorch.bh.nonideality.NonIdeality.NonLinear],\n", - " sweep_duration=2,\n", - " sweep_voltage_signal_amplitude=1,\n", + "from memtorch.bh.nonideality.NonIdeality import apply_nonidealities\r\n", + "\r\n", + "patched_model_ = apply_nonidealities(copy.deepcopy(patched_model),\r\n", + " non_idealities=[memtorch.bh.nonideality.NonIdeality.NonLinear],\r\n", + " sweep_duration=2,\r\n", + " sweep_voltage_signal_amplitude=1,\r\n", " sweep_voltage_signal_frequency=0.5)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "print(test(patched_model_, test_loader))" - ] + ], + "outputs": [], + "metadata": {} }, { + "cell_type": "markdown", "source": [ "### 4.5 Modeling Stochastic Parameters" ], - "cell_type": "markdown", "metadata": {} }, { + "cell_type": "markdown", "source": [ "MemTorch supports the usage of stochastic parameters for higher flexibility to simply account for process variances using `memtorch.bh.StochasticParameter.StochasticParameter`. Stochastic parameters can be used when defining device characteristics. \n", "\n", @@ -650,35 +652,34 @@ "\n", "Each time the memristor object is instantiated, stochastic parameters will be resampled.\n" ], - "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ - "import memtorch\n", - "\n", - "reference_memristor = memtorch.bh.memristor.VTEAM\n", - "reference_memristor_params = {'time_series_resolution': 1e-10, \n", - " 'r_off': memtorch.bh.StochasticParameter(loc=1000, scale=200, min=2),\n", - " 'r_on': memtorch.bh.StochasticParameter(loc=5000, scale=sigma, min=1)}\n", - "\n", - "memristor = reference_memristor(**reference_memristor_params)\n", - "memristor.plot_hysteresis_loop()\n", + "import memtorch\r\n", + "\r\n", + "reference_memristor = memtorch.bh.memristor.VTEAM\r\n", + "reference_memristor_params = {'time_series_resolution': 1e-10, \r\n", + " 'r_off': memtorch.bh.StochasticParameter(loc=1000, scale=200, min=2),\r\n", + " 'r_on': memtorch.bh.StochasticParameter(loc=5000, scale=sigma, min=1)}\r\n", + "\r\n", + "memristor = reference_memristor(**reference_memristor_params)\r\n", + "memristor.plot_hysteresis_loop()\r\n", "memristor.plot_bipolar_switching_behaviour()" - ] + ], + "outputs": [], + "metadata": {} }, { + "cell_type": "markdown", "source": [ "## Final Remarks\n", "A complete API is avaliable [here](https://memtorch.readthedocs.io/). To learn how to use MemTorch, and to reproduce results of ‘_MemTorch: An Open-source Simulation Framework for Memristive Deep Learning Systems_’, we provide numerous tutorials in the form of Jupyter notebooks [here](https://memtorch.readthedocs.io/en/latest/tutorials.html).\n", "\n", "Current issues, feature requests and improvements are welcome, and are tracked using: https://github.com/coreylammie/MemTorch/projects/1. These should be reported [here](https://github.com/coreylammie/MemTorch/issues)." ], - "cell_type": "markdown", "metadata": {} } ], diff --git a/memtorch/map/Input.py b/memtorch/map/Input.py index 81e0c1f2..e0544520 100644 --- a/memtorch/map/Input.py +++ b/memtorch/map/Input.py @@ -8,6 +8,22 @@ def naive_scale(module, input, force_scale=False): + """Naive method to encode input values as bit-line voltages. + + Parameters + ---------- + module : torch.nn.Module + Memristive layer to tune. + input : torch.tensor + Input tensor to encode. + force_scale : bool, optional + Used to determine if inputs are scaled (True) or not (False) if they no not exceed max_input_voltage. + + Returns + ------- + torch.Tensor + Encoded voltages. + """ if module.max_input_voltage is not None: assert ( type(module.max_input_voltage) == int diff --git a/memtorch/map/Module.py b/memtorch/map/Module.py index 8afd3146..cef9b2c6 100644 --- a/memtorch/map/Module.py +++ b/memtorch/map/Module.py @@ -16,9 +16,9 @@ def naive_tune(module, input_shape, verbose=True): ---------- module : torch.nn.Module Memristive layer to tune. - input_shape : (int, int) + input_shape : int, int Shape of the randomly generated input used to tune a crossbar. - verbose : bool + verbose : bool, optional Used to determine if verbose output is enabled (True) or disabled (False). Returns diff --git a/memtorch/map/Parameter.py b/memtorch/map/Parameter.py index c7b73727..de274817 100644 --- a/memtorch/map/Parameter.py +++ b/memtorch/map/Parameter.py @@ -20,7 +20,7 @@ def naive_map(weight, r_on, r_off, scheme, p_l=None): High resistance state. scheme: memtorch.bh.crossbar.Scheme Weight representation scheme. - p_l: float + p_l: float, optional If not None, the proportion of weights to retain. Returns diff --git a/memtorch/mn/Conv1d.py b/memtorch/mn/Conv1d.py index 4945b776..ae0474c3 100644 --- a/memtorch/mn/Conv1d.py +++ b/memtorch/mn/Conv1d.py @@ -6,7 +6,7 @@ import memtorch from memtorch.bh.crossbar.Crossbar import init_crossbar, simulate_matmul -from memtorch.bh.crossbar.Tile import gen_tiles, tile_matmul +from memtorch.bh.crossbar.Tile import tiled_inference from memtorch.map.Input import naive_scale from memtorch.map.Module import naive_tune from memtorch.map.Parameter import naive_map @@ -223,35 +223,7 @@ def forward(self, input): ) else: if self.tile_shape is not None: - ( - unfolded_batch_input_tiles, - unfolded_batch_input_tiles_map, - ) = gen_tiles( - unfolded_batch_input, - self.tile_shape, - input=True, - use_bindings=self.use_bindings, - ) - crossbar_shape = ( - self.crossbars[0].rows, - self.crossbars[0].columns, - ) - tiles_map = self.crossbars[0].tiles_map - out_ = tile_matmul( - unfolded_batch_input_tiles, - unfolded_batch_input_tiles_map, - unfolded_batch_input_shape, - self.crossbar_operation( - self.crossbars, - lambda crossbar: crossbar.conductance_matrix, - ), - tiles_map, - crossbar_shape, - self.ADC_resolution, - self.ADC_overflow_rate, - self.quant_method, - use_bindings=self.use_bindings, - ).T + out_ = tiled_inference(unfolded_batch_input, self).T else: out_ = torch.matmul( unfolded_batch_input, diff --git a/memtorch/mn/Conv2d.py b/memtorch/mn/Conv2d.py index 388522ca..a9661555 100644 --- a/memtorch/mn/Conv2d.py +++ b/memtorch/mn/Conv2d.py @@ -6,7 +6,7 @@ import memtorch from memtorch.bh.crossbar.Crossbar import init_crossbar, simulate_matmul -from memtorch.bh.crossbar.Tile import gen_tiles, tile_matmul +from memtorch.bh.crossbar.Tile import tiled_inference from memtorch.map.Input import naive_scale from memtorch.map.Module import naive_tune from memtorch.map.Parameter import naive_map @@ -245,35 +245,7 @@ def forward(self, input): ) else: if self.tile_shape is not None: - ( - unfolded_batch_input_tiles, - unfolded_batch_input_tiles_map, - ) = gen_tiles( - unfolded_batch_input, - self.tile_shape, - input=True, - use_bindings=self.use_bindings, - ) - crossbar_shape = ( - self.crossbars[0].rows, - self.crossbars[0].columns, - ) - tiles_map = self.crossbars[0].tiles_map - out_ = tile_matmul( - unfolded_batch_input_tiles, - unfolded_batch_input_tiles_map, - unfolded_batch_input_shape, - self.crossbar_operation( - self.crossbars, - lambda crossbar: crossbar.conductance_matrix, - ), - tiles_map, - crossbar_shape, - self.ADC_resolution, - self.ADC_overflow_rate, - self.quant_method, - use_bindings=self.use_bindings, - ).T + out_ = tiled_inference(unfolded_batch_input, self).T else: out_ = torch.matmul( unfolded_batch_input, diff --git a/memtorch/mn/Conv3d.py b/memtorch/mn/Conv3d.py index 8930d3c5..ad265362 100644 --- a/memtorch/mn/Conv3d.py +++ b/memtorch/mn/Conv3d.py @@ -6,7 +6,7 @@ import memtorch from memtorch.bh.crossbar.Crossbar import init_crossbar, simulate_matmul -from memtorch.bh.crossbar.Tile import gen_tiles, tile_matmul +from memtorch.bh.crossbar.Tile import tiled_inference from memtorch.map.Input import naive_scale from memtorch.map.Module import naive_tune from memtorch.map.Parameter import naive_map @@ -265,35 +265,7 @@ def forward(self, input): ) else: if self.tile_shape is not None: - ( - unfolded_batch_input_tiles, - unfolded_batch_input_tiles_map, - ) = gen_tiles( - unfolded_batch_input, - self.tile_shape, - input=True, - use_bindings=self.use_bindings, - ) - crossbar_shape = ( - self.crossbars[0].rows, - self.crossbars[0].columns, - ) - tiles_map = self.crossbars[0].tiles_map - out_ = tile_matmul( - unfolded_batch_input_tiles, - unfolded_batch_input_tiles_map, - unfolded_batch_input_shape, - self.crossbar_operation( - self.crossbars, - lambda crossbar: crossbar.conductance_matrix, - ), - tiles_map, - crossbar_shape, - self.ADC_resolution, - self.ADC_overflow_rate, - self.quant_method, - use_bindings=self.use_bindings, - ).T + out_ = tiled_inference(unfolded_batch_input, self).T else: out_ = torch.matmul( unfolded_batch_input, diff --git a/memtorch/mn/Linear.py b/memtorch/mn/Linear.py index 19bb93fd..569d5d63 100644 --- a/memtorch/mn/Linear.py +++ b/memtorch/mn/Linear.py @@ -6,7 +6,7 @@ import memtorch from memtorch.bh.crossbar.Crossbar import init_crossbar, simulate_matmul -from memtorch.bh.crossbar.Tile import gen_tiles, tile_matmul +from memtorch.bh.crossbar.Tile import tiled_inference from memtorch.map.Input import naive_scale from memtorch.map.Module import naive_tune from memtorch.map.Parameter import naive_map @@ -193,28 +193,7 @@ def forward(self, input): ).to(self.device) else: if self.tile_shape is not None: - (input_tiles, input_tiles_map) = gen_tiles( - input, - self.tile_shape, - input=True, - use_bindings=self.use_bindings, - ) - crossbar_shape = (self.crossbars[0].rows, self.crossbars[0].columns) - tiles_map = self.crossbars[0].tiles_map - out_ = tile_matmul( - input_tiles, - input_tiles_map, - input_shape, - self.crossbar_operation( - self.crossbars, lambda crossbar: crossbar.conductance_matrix - ), - tiles_map, - crossbar_shape, - self.ADC_resolution, - self.ADC_overflow_rate, - self.quant_method, - use_bindings=self.use_bindings, - ) + out_ = tiled_inference(input, self) else: out_ = torch.matmul( input.to(self.device), diff --git a/memtorch/version.py b/memtorch/version.py index f3b319c9..0b2f79db 100644 --- a/memtorch/version.py +++ b/memtorch/version.py @@ -1 +1 @@ -__version__ = "1.1.2-cpu" +__version__ = "1.1.3" diff --git a/setup.py b/setup.py index 4cb8d817..2f03869d 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import find_packages, setup from torch.utils.cpp_extension import include_paths -version = "1.1.2" +version = "1.1.3" CUDA = False