64 lines
1.7 KiB
Python
64 lines
1.7 KiB
Python
# model.py
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, optionally followed by 2x2 max pooling.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size, either an int or a
            per-dimension sequence of ints.
        pool: When true, ``nn.MaxPool2d(2)`` is applied after the
            activation; otherwise the pooling stage is ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pool=True):
        super().__init__()
        # Derive the padding from the kernel instead of hard-coding 1, so odd
        # kernels other than 3 also leave height and width unchanged before
        # pooling. The default kernel_size=3 still yields padding=1.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, padding=padding
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2) if pool else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm, ReLU and the pooling stage in order."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool(x)
        return x
|
|
|
|
class FeatureExtractor(nn.Module):
    """Convolutional backbone built from three chained ConvBlock stages.

    Channel widths progress 3 -> 32 -> 64 -> 128; every stage is created
    with ConvBlock's default kernel size and pooling setting.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = ConvBlock(in_channels=3, out_channels=32)
        self.layer2 = ConvBlock(in_channels=32, out_channels=64)
        self.layer3 = ConvBlock(in_channels=64, out_channels=128)

    def forward(self, x):
        """Run the input through layer1, layer2 and layer3 in that order."""
        return self.layer3(self.layer2(self.layer1(x)))
|
|
|
|
class Classifier(nn.Module):
    """Global-average-pool a feature map and project it to per-class scores.

    Args:
        num_classes: Output width of the final linear layer.
        in_features: Channel count of the incoming feature map, i.e. the
            input width of the linear layer. Defaults to 128, the value that
            was previously hard-coded.
    """

    def __init__(self, num_classes=10, in_features=128):
        super().__init__()
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Map a ``(N, C, H, W)`` tensor to a ``(N, num_classes)`` tensor."""
        x = self.global_pool(x)  # (N, C, 1, 1)
        x = torch.flatten(x, 1)  # (N, C)
        x = self.fc(x)
        return x
|
|
|
|
class VisionModel(nn.Module):
    """End-to-end model: a FeatureExtractor backbone feeding a Classifier head.

    Args:
        num_classes: Forwarded to the Classifier head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = FeatureExtractor()
        self.classifier = Classifier(num_classes)

    def forward(self, x):
        """Return the classifier output for the backbone features of ``x``."""
        features = self.backbone(x)
        return self.classifier(features)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: push one random 3x224x224 tensor through the model and
    # report the shape of the result.
    model = VisionModel()
    # Inference mode: BatchNorm uses its running statistics instead of the
    # statistics of this single sample, and does not update them.
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224)
    # Only the output shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        output = model(dummy_input)
    print("Output shape:", output.shape)
|
|
|