File size: 33,821 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import re

import torch

from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, Placeholder
from ..operation import Cell, Operation
from ..serializer import get_init_parameters_or_fail
from ..utils import get_importable_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import _convert_name, build_full_name


class GraphConverter:
    def __init__(self):
        self.global_seq = 0
        self.global_graph_id = 0

    def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
        """Resolve the (source ir node, output slot) pair that produces ``_input``.

        ``_input`` may originate from (1) a node recorded in ``output_remap``
        (an in-place ``aten::append`` whose output is a remapped tensor),
        (2) one of the script graph's own inputs, or (3) an ordinary
        predecessor node already converted into ``node_index``.
        A slot of ``None`` means "the producer's single output".
        """
        if _input in output_remap:
            # case 1: in-place op, the remapped node acts as the producer
            producer = output_remap[_input]
            assert producer.kind() == 'aten::append'
            assert producer in node_index, 'predecessor node: {}'.format(producer)
            src_node = node_index[producer]
            assert isinstance(src_node, Node)
            return src_node, None
        if _input in graph_inputs:
            # case 2: fed directly by the graph's input node
            return ir_graph.input_node, graph_inputs.index(_input)
        # case 3: ordinary predecessor; locate which of its outputs is _input
        producer = _input.node()
        assert producer in node_index, 'predecessor node: {}'.format(producer)
        producer_outputs = list(producer.outputs())
        slot = producer_outputs.index(_input) if len(producer_outputs) > 1 else None
        src_node = node_index[producer]
        assert isinstance(src_node, Node)
        return src_node, slot

    def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
        """
        Create one ir edge per input of ``node``, ending at ``new_node``.

        Parameters
        ----------
        ir_graph : Graph
        node : torch._C.Node
        graph_inputs : List[torch._C.Value]
            a list of a script graph's inputs
        node_index : Dict
        new_node : Node
            newly created ir node corresponding to `node`
        output_remap : Dict
        ignore_first : bool
            if it is true, skip the first input
        """
        inputs = list(node.inputs())
        if ignore_first:
            inputs = inputs[1:]
        # a node with exactly one considered input uses slot None rather than 0
        is_single_input = len(inputs) == 1
        for input_idx, _input in enumerate(inputs):
            # resolve which ir node (and which of its output slots) produces _input
            src_node, src_node_idx = self._add_edge_handle_source_node(
                _input, graph_inputs, ir_graph, output_remap, node_index)
            dst_node_idx = None if is_single_input else input_idx
            ir_graph.add_edge(head=(src_node, src_node_idx), tail=(new_node, dst_node_idx))

    def create_prim_constant_node(self, ir_graph, node, module_name):
        """Create an ir node for a TorchScript ``prim::Constant`` node."""
        # NOTE: compare with string not type, because the type is defined in pytorch C code.
        # `.kind()` can also be used here
        output = node.outputsAt(0)
        type_str = output.type().str()
        if type_str == 'None':
            attrs = {'type': 'None'}
        else:
            attrs = {'type': type_str, 'value': output.toIValue()}
        self.global_seq += 1
        node_name = build_full_name(module_name, OpTypeName.Constant, self.global_seq)
        return ir_graph.add_node(node_name, node.kind(), attrs)

    def handle_prim_attr_node(self, node, module):
        """Extract kind and attributes from a ``prim::GetAttr`` node.

        The attribute's value is only recorded when it is read off ``self``
        and is a simple scalar type; otherwise it stays ``None``.
        """
        assert node.hasAttribute('name')
        attr_name = node.s('name')
        input_name = node.inputsAt(0).debugName()
        value = None
        if input_name == 'self':
            _val = getattr(module, attr_name)
            # TODO: serialize complex data type, and output proper error message
            if isinstance(_val, (int, float, str, bool)):
                value = _val
        return node.kind(), {'name': attr_name, 'input': input_name, 'value': value}

    def _remove_mangle(self, module_type_str):
        return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)

    def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
        """
        Remove hidden nodes whose output feeds no edge (fanout == 0).

        Parameters
        ----------
        ir_graph : Graph
            our ir graph representation
        targeted_type : str
            nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
            ```None``` means removing all the nodes whose fanout is 0.
        """
        # ids of nodes that feed at least one edge (fanout > 0);
        # a set comprehension replaces the redundant membership-check-then-add loop
        node_fanout = {edge.head.id for edge in ir_graph.edges}

        to_removes = []
        for hidden_node in ir_graph.hidden_nodes:
            if hidden_node.id in node_fanout:
                continue
            assert isinstance(hidden_node, Node)
            if targeted_type is None or hidden_node.operation.type == targeted_type:
                to_removes.append(hidden_node)

        # remove after the scan so we never mutate the list while iterating it
        for hidden_node in to_removes:
            hidden_node.remove()

    def handle_graph_nodes(self, script_module, sm_graph,
                           module, module_name,
                           ir_model, ir_graph,
                           shared_module_index=None):
        """
        Convert torch script node to our node ir, and build our graph ir

        Parameters
        ----------
        script_module : torch.jit.RecursiveScriptModule
            the torch script of ```module```
        sm_graph : torch._C.Graph
            the graph in torch script
        module : nn.Module
            the targeted pytorch module
        module_name : str
            ```module```'s name
        ir_model : Model
            the whole graph ir
        ir_graph : Graph
            the graph ir of ```module```
        shared_module_index : dict
            it is used for knowing which module has been created an ir node,
            if created and invoked again, then the new ir node can simply reference that ir node.
            this way we can identify shared modules (i.e., one module invoked multiple times in `forward` function)

        Returns
        -------
        dict
            the mapping from graph node to our graph ir node
        """
        # handle inputs
        graph_inputs = []
        for _input in sm_graph.inputs():
            if _input.debugName() == 'self':
                assert _input.unique() == 0
                continue
            graph_inputs.append(_input)
            # TODO: add scope name
            ir_graph._add_input(_convert_name(_input.debugName()))

        node_index = {}  # graph node to graph ir node
        if shared_module_index is None:
            shared_module_index = {}

        # some node does not have output but it modifies a variable, for example aten::append
        # %17 : Tensor[] = aten::append(%out.1, %16)
        # %out.1 is updated, and %17 is None
        # we add output to this type of node and connect it to the following node which uses %out.1
        # key: tensor (%out.1), value: node (this node)
        output_remap = {}

        # ===================handle control flow: if===================
        def handle_if_condition(cond_tensor):
            """
            to calculate the condition, we only deal with the following op types by tracing back
            `prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`

            generate the expression using recursive calls

            NOTE: do not support dynamic graph
            """
            def _generate_expr(tensor):
                if tensor.node().kind() == 'prim::GetAttr':
                    return f'({getattr(module, tensor.node().s("name"))})'
                elif tensor.node().kind() == 'aten::__getitem__':
                    t = _generate_expr(tensor.node().inputsAt(0))
                    idx = _generate_expr(tensor.node().inputsAt(1))
                    return f'({t}[{idx}])'
                elif tensor.node().kind() == 'prim::Constant':
                    return f'{tensor.toIValue()}'
                elif tensor.node().kind() == 'aten::eq':
                    left = _generate_expr(tensor.node().inputsAt(0))
                    right = _generate_expr(tensor.node().inputsAt(1))
                    return f'({left} == {right})'
                elif tensor.node().kind() == 'aten::le':
                    left = _generate_expr(tensor.node().inputsAt(0))
                    right = _generate_expr(tensor.node().inputsAt(1))
                    return f'({left} <= {right})'
                elif tensor.node().kind() == 'aten::ge':
                    left = _generate_expr(tensor.node().inputsAt(0))
                    right = _generate_expr(tensor.node().inputsAt(1))
                    return f'({left} >= {right})'
                elif tensor.node().kind() == 'aten::__not__':
                    value = _generate_expr(tensor.node().inputsAt(0))
                    return f'(not {value})'
                elif tensor.node().kind() == 'aten::Bool':
                    value = _generate_expr(tensor.node().inputsAt(0))
                    return f'bool({value})'
                elif tensor.node().kind() == 'aten::__is__':
                    left = _generate_expr(tensor.node().inputsAt(0))
                    right = _generate_expr(tensor.node().inputsAt(1))
                    return f'({left} is {right})'
                elif tensor.node().kind() == 'aten::__isnot__':
                    left = _generate_expr(tensor.node().inputsAt(0))
                    right = _generate_expr(tensor.node().inputsAt(1))
                    return f'({left} is not {right})'
                elif tensor.node().kind() == 'aten::ne':
                    left = _generate_expr(tensor.node().inputsAt(0))
                    right = _generate_expr(tensor.node().inputsAt(1))
                    return f'({left} != {right})'
                elif tensor.node().kind() == 'aten::gt':
                    left = _generate_expr(tensor.node().inputsAt(0))
                    right = _generate_expr(tensor.node().inputsAt(1))
                    return f'({left} > {right})'
                elif tensor.node().kind() == 'aten::lt':
                    left = _generate_expr(tensor.node().inputsAt(0))
                    right = _generate_expr(tensor.node().inputsAt(1))
                    return f'({left} < {right})'
                elif tensor.node().kind() == 'prim::If':
                    raise RuntimeError('Have not supported `if A and/or B`, please use two `if` statements instead.')
                else:
                    raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
                                        'you are suggested to decorate the corresponding class with "@basic_unit".')
            expr = _generate_expr(cond_tensor)
            # NOTE(review): eval of a generated expression; operands come from the
            # traced model's own attributes/constants, not external input, but any
            # module attribute value is interpolated verbatim into the expression.
            return eval(expr)

        def handle_if_node(node):
            """
            Parameters
            ----------
            node : torch._C.Node
                the node from TorchScript graph

            Returns
            -------
            Node
                the created node ir
            """
            # only deal with input of prim::If is constant or attribute for now
            # will support constant expression in future
            inputs = [i for i in node.inputs()]
            assert len(inputs) == 1
            cond = handle_if_condition(inputs[0])
            chosen_block = 0 if cond else 1
            blocks = [block for block in node.blocks()]
            assert len(blocks) == 2
            last_block_node = None
            # NOTE: this loop variable shadows the enclosing ``node`` parameter
            for node in blocks[chosen_block].nodes():
                last_block_node = handle_single_node(node)
            self.global_seq += 1
            new_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
            self._add_edge(ir_graph, blocks[chosen_block].returnNode(), graph_inputs, node_index, new_node, output_remap)
            last_block_node = new_node
            return last_block_node

        # ===================handle function call===================
        def handle_function_callmethod(node):
            # get and handle the first input, which should be an nn.Module
            assert node.hasAttribute('name')
            # NOTE: "forward__0" is hacky, LSTM instance is parsed to call forward__0 in torchscript
            if node.s('name') in ['forward', 'forward__0']:
                # node.inputsAt(0).type() is <class 'torch._C.ClassType'>
                submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
                submodule = node.inputsAt(0).node()
                assert submodule.kind() == 'prim::GetAttr'
                assert submodule.hasAttribute('name')
                submodule_name = submodule.s('name')

                if submodule.inputsAt(0).debugName() == 'self':
                    # module is usually instantiated in __init__.
                    # when calling a module in forward,
                    # prim::GetAttr is used to obtain the module in torch script.
                    # therefore, we do this check for a module. example below:
                    # %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
                    # %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
                    assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
                        submodule_name, script_module._modules.keys())

                    submodule_full_name = build_full_name(module_name, submodule_name)
                    submodule_obj = getattr(module, submodule_name)
                    subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
                                                                submodule_obj,
                                                                submodule_full_name, ir_model)
                else:
                    # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
                    # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
                    # %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
                    if submodule.inputsAt(0).type().name() == 'ModuleList':
                        # handle ModuleList
                        predecessor = submodule.inputsAt(0).node()
                        module_name_space = [submodule_name]
                        while predecessor.inputsAt(0).debugName() != 'self':
                            # this is for dealing with nested ModuleList. below is an example
                            # %3 : __torch__.torch.nn.modules.container.___torch_mangle_0.ModuleList = prim::GetAttr[name="ops"](%self)
                            # %5 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="0"](%3)
                            # %7 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="1"](%3)
                            # %9 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="2"](%3)
                            # %11 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="3"](%3)
                            # %14 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="0"](%5)
                            # %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="1"](%5)
                            # %state.2 : Tensor = prim::CallMethod[name="forward"](%14, %x.1) # modulelist.py:18:24
                            # %state.4 : Tensor = prim::CallMethod[name="forward"](%16, %state.2) # modulelist.py:18:24
                            assert predecessor.kind() == 'prim::GetAttr'
                            module_name_space.append(predecessor.s('name'))
                            predecessor = predecessor.inputsAt(0).node()
                        assert predecessor.kind() == 'prim::GetAttr'
                        assert predecessor.hasAttribute('name')
                        module_name_space.append(predecessor.s('name'))
                        submodule_full_name = build_full_name(module_name, list(reversed(module_name_space)))
                        submodule_obj = module
                        script_submodule = script_module
                        # walk the attribute chain outermost-first to reach the real submodule
                        for each_name in list(reversed(module_name_space)):
                            submodule_obj = getattr(submodule_obj, each_name)
                            script_submodule = script_submodule._modules[each_name]
                        subgraph, sub_m_attrs = self.convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model)
                    else:
                        raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))

                if submodule_full_name in shared_module_index:
                    # this module is invoked more than once, the ir node has already been created
                    # create a reference node for it.
                    # example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
                    self.global_seq += 1
                    shared_node_name = build_full_name(submodule_full_name, '', self.global_seq)
                    shared_type_operation = Operation.new('shared', {'reference': submodule_full_name})
                    subcell = ir_graph.add_node(shared_node_name, shared_type_operation)
                else:
                    # this module is processed for the first time, build cell for it
                    if subgraph is None:
                        # if we do not parse this module's graph, we create Node for this module
                        subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
                        if isinstance(submodule_obj, Placeholder):
                            subcell.update_label(submodule_obj.label)
                        elif isinstance(submodule_obj, InputChoice):
                            subcell.update_label(sub_m_attrs['label'])
                    else:
                        # Graph already created, create Cell for it
                        new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
                        subcell = ir_graph.add_node(submodule_full_name, new_cell)
                    shared_module_index[submodule_full_name] = subcell
                node_index[node] = subcell
                # connect the cell into graph
                self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
            else:
                # handle normal member function
                assert hasattr(script_module, node.s('name'))
                # TODO: support non member functions
                assert node.inputsAt(0).debugName() == 'self'
                script_method = getattr(script_module, node.s('name')) # <class 'torch._C.ScriptMethod'>

                # step #1: generate graph ir for this method
                # graph_id=-100 marks a throwaway graph that is merged into ir_graph below
                method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
                method_node_index = self.handle_graph_nodes(script_module, script_method.graph, module,
                                                    module_name, ir_model, method_ir_graph, shared_module_index)
                for _output in script_method.graph.outputs():
                    method_ir_graph._add_output(_convert_name(_output.debugName()))
                    predecessor_node_outputs = [o for o in _output.node().outputs()]
                    if len(predecessor_node_outputs) == 1:
                        src_node_idx = None
                    else:
                        src_node_idx = predecessor_node_outputs.index(_output)
                    method_ir_graph.add_edge(head=(method_node_index[_output.node()], src_node_idx),
                                    tail=(method_ir_graph.output_node, None))
                self.refine_graph(method_ir_graph)

                # step #2: merge this graph to its module graph
                for h_node in method_ir_graph.hidden_nodes:
                    h_node.graph = ir_graph
                    ir_graph.hidden_nodes.append(h_node)
                for edge in method_ir_graph.edges:
                    edge.graph = ir_graph
                    if edge.head == method_ir_graph.input_node:
                        # this is a member method, 'self' is the first argument, thus +1
                        _input = node.inputsAt(edge.head_slot + 1)
                        src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
                        edge.head = src_node
                        edge.head_slot = src_node_idx
                    if edge.tail == method_ir_graph.output_node:
                        # since the following nodes have not been created, skip this edge
                        # edge.head is the output node of this method
                        # TODO: check whether there could be multiple output nodes???
                        node_index[node] = edge.head
                        continue
                    ir_graph.edges.append(edge)

        # ===================handle each single node===================
        def handle_single_node(node):
            """
            Parameters
            ----------
            node : torch._C.Node
                the node from TorchScript graph

            Returns
            -------
            Node
                the created node ir
            """
            if node.kind() == 'prim::CallMethod':
                handle_function_callmethod(node)
            elif node.kind() == 'prim::CallFunction':
                func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
                func = node.inputsAt(0).node()
                assert func.kind() == 'prim::Constant'
                assert func.hasAttribute('name')
                func_name = func.s('name')
                # create node for func
                self.global_seq += 1
                func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
                                              '{}.{}'.format(func_type_str, func_name))
                node_index[node] = func_node
                self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
            elif node.kind() == 'prim::Constant':
                new_node = self.create_prim_constant_node(ir_graph, node, module_name)
                node_index[node] = new_node
            elif node.kind() in ['prim::ListConstruct', 'prim::ListUnpack', 'prim::TupleConstruct', 'prim::TupleUnpack']:
                self.global_seq += 1
                prim_op_name = node.kind().split('::')[-1]
                new_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
                node_index[node] = new_node
                self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
            elif node.kind() == 'prim::GetAttr':
                node_type, attrs = self.handle_prim_attr_node(node, module)
                self.global_seq += 1
                new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
                                             node_type, attrs)
                node_index[node] = new_node
            elif node.kind() == 'prim::If':
                last_block_node = handle_if_node(node)
                # last_block_node is None means no node in the branch block
                node_index[node] = last_block_node
            elif node.kind() == 'prim::Loop':
                # refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
                raise RuntimeError('Loop has not been supported yet!')
            elif node.kind().startswith('prim::'):
                self.global_seq += 1
                prim_op_name = node.kind().replace('::', '__')
                prim_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
                node_index[node] = prim_node
                self._add_edge(ir_graph, node, graph_inputs, node_index, prim_node, output_remap)
            elif node.kind() == 'aten::append':
                self.global_seq += 1
                aten_op_name = node.kind().replace('::', '__')
                aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
                node_index[node] = aten_node
                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
                # record that later readers of the mutated list should connect to this node
                output_remap[node.inputsAt(0)] = node
            elif node.kind().startswith('aten::'):
                # handle aten::XXX
                self.global_seq += 1
                aten_op_name = node.kind().replace('::', '__')
                aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
                node_index[node] = aten_node
                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
            else:
                raise RuntimeError('Unsupported kind: {}'.format(node.kind()))

            return node_index[node]

        for node in sm_graph.nodes():
            handle_single_node(node)

        return node_index

    def merge_aten_slices(self, ir_graph):
        """
        if there is aten::slice node, merge the consecutive ones together.
        ```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
        each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
        """
        # a "head" slice node is one whose (non-constant) predecessor is not
        # itself a slice, i.e. the first slice of a consecutive chain
        head_slice_nodes = []
        has_slice_node = False
        for node in ir_graph.hidden_nodes:
            if node.operation.type == 'aten::slice':
                has_slice_node = True
                for pred in node.predecessors:
                    if pred.operation.type not in ['aten::slice', 'prim::Constant']:
                        head_slice_nodes.append(node)
                        break
        if has_slice_node:
            assert head_slice_nodes

        for head_node in head_slice_nodes:
            slot = 0
            new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
            if len(head_node.incoming_edges) == 4:
                # when slice is for one dimension list, there are only 4 inputs, thus merge is not needed.
                # BUGFIX: was `break`, which silently skipped all remaining head
                # slice nodes once a 4-input one was encountered; each head node
                # is independent, so use `continue` instead.
                for edge in head_node.incoming_edges:
                    edge.tail = new_slice_node
                for edge in head_node.outgoing_edges:
                    edge.head = new_slice_node
                ir_graph.hidden_nodes.remove(head_node)
                continue
            assert len(head_node.incoming_edges) == 5
            for edge in head_node.incoming_edges:
                edge.tail = new_slice_node
            slot += 5
            node = head_node
            # fold every consecutive downstream slice into the merged node;
            # tail_slot 0 is the sliced tensor itself and is dropped for non-head slices
            while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
                suc_node = node.successors[0]
                assert len(suc_node.incoming_edges) == 5
                for edge in suc_node.incoming_edges:
                    if edge.tail_slot == 0:
                        edge.remove()
                    else:
                        edge.tail = new_slice_node
                        edge.tail_slot = slot + edge.tail_slot - 1
                slot += 4
                ir_graph.hidden_nodes.remove(node)
                node = suc_node

            for edge in node.outgoing_edges:
                edge.head = new_slice_node
            ir_graph.hidden_nodes.remove(node)

    def refine_graph(self, ir_graph):
        """
        Simplify the graph in place.

        Two cleanup passes are applied:
        1. remove unconnected constant nodes
        2. remove unconnected getattr nodes
        then consecutive ``aten::slice`` nodes are merged.
        """
        # some constants are never consumed, e.g., a function name lowered to prim::Constant
        for dangling_type in ('prim::Constant', 'prim::GetAttr'):
            self.remove_unconnected_nodes(ir_graph, targeted_type=dangling_type)
        self.merge_aten_slices(ir_graph)

    def _handle_inputchoice(self, module):
        return {
            'n_candidates': module.n_candidates,
            'n_chosen': module.n_chosen,
            'reduction': module.reduction,
            'label': module.label
        }

    def _handle_valuechoice(self, module):
        return {
            'candidates': module.candidates,
            'label': module.label,
            'accessor': module._accessor
        }

    def convert_module(self, script_module, module, module_name, ir_model):
        """
        Convert a module to its graph ir (i.e., Graph) along with its input arguments.

        Special NAS modules (LayerChoice, InputChoice, ValueChoice, Placeholder) and
        basic ``torch.nn`` modules are not parsed further; only their attributes are
        extracted. Everything else is parsed from its TorchScript graph.

        Parameters
        ----------
        script_module : torch.jit.RecursiveScriptModule
            the script module of ```module``` obtained with torch.jit.script
        module : nn.Module
            the targeted module instance
        module_name : str
            the constructed name space of ```module```
        ir_model : Model
            the whole graph ir

        Returns
        -------
        Graph
            the built graph ir from module, ```None``` means do not further parse the module
        dict
            the input arguments of this module
        """

        # NOTE: have not supported nested LayerChoice, i.e., a candidate module
        # also has LayerChoice or InputChoice or ValueChoice
        original_type_name = script_module.original_name
        # m_attrs stays None only when the module's TorchScript graph must be parsed
        m_attrs = None
        if original_type_name in MODULE_EXCEPT_LIST:
            pass  # do nothing
        elif original_type_name == OpTypeName.LayerChoice:
            # a LayerChoice becomes a graph holding one node per candidate module
            graph = Graph(ir_model, -100, module_name, _internal=True)  # graph_id is not used now
            candidate_name_list = [f'layerchoice_{module.label}_{cand_name}' for cand_name in module.names]
            for cand_name, cand in zip(candidate_name_list, module):
                cand_type = '__torch__.' + get_importable_name(cand.__class__)
                graph.add_node(cand_name, cand_type, get_init_parameters_or_fail(cand))
            graph._register()
            return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
        elif original_type_name == OpTypeName.InputChoice:
            m_attrs = self._handle_inputchoice(module)
        elif original_type_name == OpTypeName.ValueChoice:
            m_attrs = self._handle_valuechoice(module)
        elif original_type_name == OpTypeName.Placeholder:
            m_attrs = get_init_parameters_or_fail(module)
        elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
            # this is a basic module from pytorch, no need to parse its graph
            m_attrs = get_init_parameters_or_fail(module)
        else:
            # this module is marked as serialize, won't continue to parse
            m_attrs = get_init_parameters_or_fail(module, silently=True)
        if m_attrs is not None:
            # leaf module: return attributes only, no graph
            return None, m_attrs

        # handle TorchScript graph
        sm_graph = script_module.graph
        self.global_graph_id += 1
        ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)

        # handle graph nodes
        node_index = self.handle_graph_nodes(script_module, sm_graph, module,
                                             module_name, ir_model, ir_graph)

        # handle graph outputs: connect each TorchScript output to the IR output node
        for _output in sm_graph.outputs():
            ir_graph._add_output(_convert_name(_output.debugName()))
            predecessor_node_outputs = [o for o in _output.node().outputs()]
            if len(predecessor_node_outputs) == 1:
                # single-output producer: slot index is unnecessary
                src_node_idx = None
            else:
                src_node_idx = predecessor_node_outputs.index(_output)
            ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
                              tail=(ir_graph.output_node, None))

        self.refine_graph(ir_graph)

        ir_graph._register()

        # add mutation signal for special modules
        if original_type_name == OpTypeName.Repeat:
            attrs = {
                'mutation': 'repeat',
                'label': module.label,
                'min_depth': module.min_depth,
                'max_depth': module.max_depth
            }
            return ir_graph, attrs

        return ir_graph, {}


def convert_to_graph(script_module, module):
    """
    Convert module to our graph ir, i.e., build a ```Model``` type.

    Parameters
    ----------
    script_module : torch.jit.RecursiveScriptModule
        the script module obtained with torch.jit.script
    module : nn.Module
        the targeted module instance

    Returns
    -------
    Model
        the constructed IR model
    """
    ir_model = Model(_internal=True)
    # '_model' is the root namespace for the converted module hierarchy
    GraphConverter().convert_module(script_module, module, '_model', ir_model)
    return ir_model