24 lines
732 B
Python
24 lines
732 B
Python
import torch.nn as nn
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
"""
|
|
The IrisClassificationModel torch module. This is the computation graph that was used to
|
|
train the model. Refer to:
|
|
https://github.com/ritual-net/simple-ml-models/tree/main/iris_classification
|
|
"""
|
|
|
|
|
|
class IrisClassificationModel(nn.Module):
    """Fully connected classifier that outputs a probability per class.

    The network maps ``input_dim`` features through two 50-unit hidden layers
    with ReLU activations to 3 output scores, which are normalized with
    softmax so that they sum to 1 for each sample.
    """

    def __init__(self, input_dim: int) -> None:
        """Create the three linear layers of the network.

        Args:
            input_dim: Number of features in each input sample.
        """
        super().__init__()
        # Attribute names and creation order are kept stable so that existing
        # state dicts (keyed "layer1.*", "layer2.*", "layer3.*") still load.
        self.layer1 = nn.Linear(input_dim, 50)
        self.layer2 = nn.Linear(50, 50)
        self.layer3 = nn.Linear(50, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass and return class probabilities.

        Args:
            x: Input tensor whose last dimension has size ``input_dim``.
                Both a batch ``(batch, input_dim)`` and a single unbatched
                sample ``(input_dim,)`` are accepted.

        Returns:
            A tensor with the same leading dimensions as ``x`` and a last
            dimension of size 3 holding softmax probabilities.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # nn.Linear acts on the last dimension, so the class scores always
        # live there. Normalizing over dim=-1 is identical to dim=1 for 2-D
        # batches and additionally works for unbatched 1-D input, where
        # dim=1 would raise an IndexError.
        return F.softmax(self.layer3(x), dim=-1)
|