# model.py
"""Small convolutional image classifier built from reusable Conv-BN-ReLU blocks."""

from typing import Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def _same_padding(
    kernel_size: Union[int, Sequence[int]],
) -> Union[int, Tuple[int, ...]]:
    """Return the padding that keeps spatial size for odd kernels at stride 1."""
    if isinstance(kernel_size, int):
        return kernel_size // 2
    return tuple(k // 2 for k in kernel_size)


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, optionally followed by 2x2 max pooling.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size, an int or one int per spatial
            dimension. Padding is derived as ``kernel_size // 2`` so that odd
            kernels keep the spatial size (an even kernel grows it by one).
        pool: When true, halve the spatial size with ``nn.MaxPool2d(2)``;
            otherwise the pooling stage is an ``nn.Identity``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]] = 3,
        pool: bool = True,
    ) -> None:
        super().__init__()
        # Padding follows the kernel size; a fixed padding of 1 is only
        # size-preserving for 3x3 kernels.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            padding=_same_padding(kernel_size),
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2) if pool else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch norm, ReLU and the pooling stage in order."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool(x)
        return x


class FeatureExtractor(nn.Module):
    """Stack of three pooled ``ConvBlock`` layers (in_channels -> 32 -> 64 -> 128).

    Each block halves the spatial size, so the output is 1/8 of the input
    resolution (rounded down at every stage).

    Args:
        in_channels: Number of channels of the input image tensor.
    """

    def __init__(self, in_channels: int = 3) -> None:
        super().__init__()
        self.layer1 = ConvBlock(in_channels, 32)
        self.layer2 = ConvBlock(32, 64)
        self.layer3 = ConvBlock(64, 128)
        # Exposed so downstream heads do not have to hard-code the width.
        self.out_channels = 128

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through the three blocks and return the feature map."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


class Classifier(nn.Module):
    """Global-average-pool a feature map and project it to class logits.

    Args:
        num_classes: Number of output logits.
        in_features: Channel count of the incoming feature map.
    """

    def __init__(self, num_classes: int = 10, in_features: int = 128) -> None:
        super().__init__()
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, num_classes)``."""
        x = self.global_pool(x)
        # Collapse the (C, 1, 1) pooled map to a flat C-vector per sample.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


class VisionModel(nn.Module):
    """Image classifier: ``FeatureExtractor`` backbone plus ``Classifier`` head.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels of the input image tensor.
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3) -> None:
        super().__init__()
        self.backbone = FeatureExtractor(in_channels)
        # Take the head width from the backbone so the two cannot drift apart.
        self.classifier = Classifier(num_classes, self.backbone.out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        x = self.backbone(x)
        x = self.classifier(x)
        return x


if __name__ == "__main__":
    # Smoke test: a single random 224x224 RGB image through the default model.
    model = VisionModel()
    dummy_input = torch.randn(1, 3, 224, 224)
    output = model(dummy_input)
    print("Output shape:", output.shape)