# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
import logging | |
from typing import List, Tuple, Any | |
from ..graph import IllegalGraphError, Edge, Graph, Node, Model | |
_logger = logging.getLogger(__name__) | |
def model_to_pytorch_script(model: Model, placement=None) -> str:
    """
    Render a full PyTorch script for ``model``.

    Every graph (cell) in the model is converted to class code; the extra
    packages those cells need are collected and emitted as import lines at
    the top of the script.
    """
    cell_codes = []
    required_pkgs = set()
    for cell_name, cell in model.graphs.items():
        pkgs, cell_code = graph_to_pytorch_model(cell_name, cell, placement=placement)
        cell_codes.append(cell_code)
        required_pkgs.update(pkgs)
    import_lines = '\n'.join('import {}'.format(pkg) for pkg in required_pkgs)
    return _PyTorchScriptTemplate.format(import_lines, '\n\n'.join(cell_codes)).strip()
def _sorted_incoming_edges(node: Node) -> List[Edge]:
    """
    Return the edges feeding ``node``, ordered by input slot.

    Two layouts are accepted: every ``tail_slot`` is ``None`` (a single
    unindexed input group, returned in graph order), or the slots are
    exactly the integers ``0..n-1`` (returned sorted by slot).  Anything
    else raises ``IllegalGraphError``.
    """
    incoming = [e for e in node.graph.edges if e.tail is node]
    _logger.debug('sorted_incoming_edges: %s', str(incoming))
    if not incoming:
        return []
    _logger.debug('all tail_slots are None: %s', str([e.tail_slot for e in incoming]))
    slots = [e.tail_slot for e in incoming]
    if all(slot is None for slot in slots):
        return incoming
    if all(isinstance(slot, int) for slot in slots):
        incoming.sort(key=lambda e: e.tail_slot)
        # slots must form a complete 0-based range with no gaps/duplicates
        if [e.tail_slot for e in incoming] == list(range(len(incoming))):
            return incoming
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
    """
    Format the inputs of a given node

    Parameters
    ----------
    node : Node
        a graph node, get and format its inputs

    Returns
    -------
    list
        the list of input names
    list
        the list of input values, if an input is simple type, record its value,
        otherwise the value is None
    """
    names = []
    values = []
    for edge in _sorted_incoming_edges(node):
        head = edge.head
        if head.name == '_inputs':
            assert isinstance(edge.head_slot, int)
            if head.operation.io_names is not None:
                # named graph inputs, e.g., forward(self, tensor1, tensor2, another_one)
                names.append(head.operation.io_names[edge.head_slot])
            else:
                # anonymous graph inputs, e.g., forward(*_inputs)
                names.append('_inputs[{}]'.format(edge.head_slot))
            values.append(None)
        elif edge.head_slot is None:
            # producer is a single-output operator: reference it by name
            names.append('{}'.format(head.name))
            op = head.operation
            if op.type in ('prim::Constant', 'prim::GetAttr') and 'value' in op.parameters:
                # simple constant: record the concrete value
                values.append(op.parameters['value'])
            else:
                values.append(None)
        else:
            # producer is a multi-output operator: index the specific output
            names.append('{}[{}]'.format(head.name, edge.head_slot))
            values.append(None)
    return names, values
def _remove_prefix(names, graph_name): | |
""" | |
variables name (full name space) is too long, | |
shorten the name by removing the prefix ```graph_name``` | |
""" | |
if isinstance(names, list): | |
converted_names = [] | |
for name in names: | |
if name.startswith(graph_name): | |
converted_names.append(name[len(graph_name):]) | |
else: | |
converted_names.append(name) | |
return converted_names | |
else: | |
return names[len(graph_name):] if names.startswith(graph_name) else names | |
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
    """
    Generate the ``nn.Module`` class code for a single graph.

    Parameters
    ----------
    graph_name : str
        name of the graph; also the namespace prefix stripped from node names
    placement : dict, optional
        maps nodes to a placement object with a ``.device`` attribute; when a
        module node is placed, ``.to('<device>')`` is appended to its init code

    Returns
    -------
    tuple
        the set of packages that must be imported, and the generated class code
        (note: the previous ``-> str`` annotation was wrong — a tuple is returned)
    """
    # topo-sorted node list is needed for both __init__ and forward;
    # compute it once instead of calling topo_sort() twice
    nodes = graph.topo_sort()

    # handle module node and function node differently
    # only need to generate code for module here
    import_pkgs = set()
    node_codes = []
    for node in nodes:
        if node.operation:
            # 'shared' nodes reuse another node's submodule, so no init code
            if node.operation.type == 'shared':
                continue
            pkg_name = node.operation.get_import_pkg()
            if pkg_name is not None:
                import_pkgs.add(pkg_name)
            node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
            if node_code is not None:
                if placement and node in placement and len(node_code) > 0:
                    node_codes.append(f"{node_code}.to('{placement[node].device}')")
                else:
                    node_codes.append(node_code)

    if graph.input_node.operation.io_names is None:
        # when input has no name, e.g., forward(*_inputs)
        input_code = '*_inputs'
    else:
        for name in graph.input_node.operation.io_names:
            assert not name.startswith(graph_name)
        input_code = ', '.join(graph.input_node.operation.io_names)

    edge_codes = []
    for node in nodes:
        if node.operation:
            inputs, inputs_value = _format_inputs(node)
            inputs = _remove_prefix(inputs, graph_name)
            node_name = _remove_prefix(node.name, graph_name)
            submodule_name = node_name
            if node.operation.type == 'shared':
                # forward through the submodule this shared node references
                submodule_name = _remove_prefix(node.operation.parameters['reference'], graph_name)
            edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))

    output_names, _ = _format_inputs(graph.output_node)
    output_names = _remove_prefix(output_names, graph_name)
    if not output_names:
        raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
    output_code = ', '.join(output_names)

    # matches the 8-space body indentation of _PyTorchModelTemplate
    linebreak = '\n        '
    return import_pkgs, _PyTorchModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )
# TODO: handle imports

# Skeleton of the emitted script: fixed torch/nni imports first, then the
# dynamically collected package imports (first placeholder), then the
# generated model class definitions (second placeholder).
_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nni.retiarii.nn.pytorch
{}
{}
'''
# Per-graph nn.Module skeleton: {nodes} are submodule init statements,
# {edges} are the forward-pass statements, joined by the 8-space
# linebreak in graph_to_pytorch_model so indentation stays valid.
_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
    def __init__(self):
        super().__init__()
        {nodes}
    def forward(self, {inputs}):
        {edges}
        return {outputs}
'''