# Gemma3 / app.py
# --- Commented-out torch.compile debugging demo (kept for reference) ---
# import torch
#
# # Utility to log tensor info
# def log_tensor(name, x):
#     print(f"--- {name} ---")
#     print(f"shape: {x.shape}, dtype: {x.dtype}, device: {x.device}")
#     print(f"min: {x.min().item():.6f}, max: {x.max().item():.6f}, mean: {x.mean().item():.6f}, sum: {x.sum().item():.6f}")
#     print(f"full tensor:\n{x}\n")
#
# # Simple function called from inside the compiled function
# def g(x, y):
#     z = x + y
#     return z
#
# # Compiled function; backend="eager" captures the graph with TorchDynamo
# # but still executes ops eagerly, which is useful for debugging
# @torch.compile(backend="eager")
# def f(x):
#     x = torch.sin(x)
#     x = g(x, x)
#     return x
#
# # --- Initialization / run once ---
# def init_and_run_once():
#     # Example input
#     x = torch.ones(3, 3, dtype=torch.float32)
#     print("=== INITIAL INPUT ===")
#     log_tensor("original input x", x)
#     # Run compiled function once
#     out = f(x)
#     print("=== OUTPUT AFTER FIRST RUN ===")
#     log_tensor("final output", out)
#     # Return output if needed
#     return out
#
# # Run once at import / init
# if __name__ == "__main__":
#     init_and_run_once()
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Flatten 28x28 images into 784-long vectors (batch dim is kept)
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
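
# Illustrative sanity check (an addition, not part of the original script):
# nn.Flatten keeps the batch dimension and collapses 28x28 -> 784, which is
# why the first Linear layer expects 28*28 input features.
_probe = torch.rand(2, 28, 28)
assert nn.Flatten()(_probe).shape == (2, 28 * 28)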
# We create an instance of NeuralNetwork, move it to the device, and print its structure.
model = NeuralNetwork().to(device)
print(model)
# To use the model, we pass it the input data. This executes the model's forward pass, along with some background operations.
X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")