# copydata-1/model.py
# 2025-05-22 06:51:51 +00:00
#
# 64 lines
# 1.7 KiB
# Python

# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    """Convolution block: Conv2d -> BatchNorm2d -> ReLU -> optional max pool.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size, an int or an ``(h, w)`` pair.
        pool: If true, finish with ``nn.MaxPool2d(2)``; otherwise the pooling
            stage is ``nn.Identity``.
        padding: Zero-padding passed to the convolution. ``None`` (default)
            derives it as ``kernel_size // 2`` per dimension, which keeps the
            spatial size unchanged for odd kernels at the default stride and
            dilation. For the default ``kernel_size=3`` this evaluates to 1.
            Pass an explicit value to override.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pool=True,
                 padding=None):
        super().__init__()
        if padding is None:
            # A fixed padding of 1 only preserves the spatial size for 3x3
            # kernels, so derive it from the kernel actually requested.
            if isinstance(kernel_size, (tuple, list)):
                padding = tuple(k // 2 for k in kernel_size)
            else:
                padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, padding=padding
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2) if pool else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm, ReLU and the pooling stage to *x*."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool(x)
        return x
class FeatureExtractor(nn.Module):
    """Convolutional backbone made of three ConvBlock stages.

    The stages map channels 3 -> 32 -> 64 -> 128, each built with the
    ConvBlock defaults for kernel size and pooling.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = ConvBlock(3, 32)
        self.layer2 = ConvBlock(32, 64)
        self.layer3 = ConvBlock(64, 128)

    def forward(self, x):
        """Run *x* through the three stages in order and return the result."""
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return x
class Classifier(nn.Module):
    """Classification head: global average pooling, then a linear layer.

    Args:
        num_classes: Number of output logits per sample.
        in_features: Channel count of the incoming feature map, i.e. the
            input width of the linear layer. Defaults to 128, the value that
            was previously hard-coded.
    """

    def __init__(self, num_classes=10, in_features=128):
        super().__init__()
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits for the feature map *x*."""
        x = self.global_pool(x)
        # Flatten everything after the first dimension before the linear layer.
        x = torch.flatten(x, 1)
        return self.fc(x)
class VisionModel(nn.Module):
    """Full model: a FeatureExtractor backbone followed by a Classifier head.

    Args:
        num_classes: Forwarded to the Classifier head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = FeatureExtractor()
        self.classifier = Classifier(num_classes)

    def forward(self, x):
        """Extract features from *x* and return the classifier output."""
        features = self.backbone(x)
        return self.classifier(features)
if __name__ == "__main__":
    # Smoke test: push one random 3x224x224 input through a default model
    # and report the shape of what comes out.
    net = VisionModel()
    sample = torch.randn(1, 3, 224, 224)
    logits = net(sample)
    print("Output shape:", logits.shape)