File size: 2,157 Bytes
6c3a422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bad48d
6c3a422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e2085e
6c3a422
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# import torch

# # Utility to log tensor info
# def log_tensor(name, x):
#     print(f"--- {name} ---")
#     print(f"shape: {x.shape}, dtype: {x.dtype}, device: {x.device}")
#     print(f"min: {x.min().item():.6f}, max: {x.max().item():.6f}, mean: {x.mean().item():.6f}, sum: {x.sum().item():.6f}")
#     print(f"full tensor:\n{x}\n")

# # Simple function
# def g(x, y):
#     z = x + y
#     return z

# # Compiled function
# @torch.compile(backend="eager")
# def f(x):
#     x = torch.sin(x)
#     x = g(x, x)
#     return x

# # --- Initialization / run once ---
# def init_and_run_once():
#     # Example input
#     x = torch.ones(3, 3, dtype=torch.float32)
#     print("=== INITIAL INPUT ===")
#     log_tensor("original input x", x)

#     # Run compiled function once
#     out = f(x)

#     print("=== OUTPUT AFTER FIRST RUN ===")
#     log_tensor("final output", out)

#     # Return output if needed
#     return out

# # Run once at import / init
# if __name__ == "__main__":
#     init_and_run_once()

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

class NeuralNetwork(nn.Module):
    """MLP classifier for 28x28 images (e.g. FashionMNIST/MNIST).

    Flattens each image to a 784-vector, then applies two ReLU hidden
    layers of width 512 and a final linear layer producing 10 raw
    class scores (logits — no softmax inside the model).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        hidden = 512
        # Positional Sequential keeps the same submodule indices (and
        # therefore the same state_dict keys) as listing layers inline.
        layers = [
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-class logits of shape (batch, 10) for input images x."""
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)


#We create an instance of NeuralNetwork, and move it to the device, and print its structure.
model = NeuralNetwork().to(device)
print(model)

#To use the model, we pass it the input data. This executes the model’s forward, along with some background operations. 

X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")