ritual/projects/torch-iris/container/src/iris_classification_model.py

24 lines
732 B
Python

import torch.nn as nn
import torch
import torch.nn.functional as F
"""
The IrisClassificationModel torch module. This is the computation graph that was used to
train the model. Refer to:
https://github.com/ritual-net/simple-ml-models/tree/main/iris_classification
"""
class IrisClassificationModel(nn.Module):
    """Three-layer fully connected classifier for Iris-style tabular input.

    Architecture: ``input_dim -> hidden_dim -> hidden_dim -> num_classes``,
    with ReLU after each of the first two linear layers and a softmax over
    the class axis at the output.

    The submodule attribute names ``layer1`` / ``layer2`` / ``layer3`` are
    kept unchanged because ``state_dict`` keys are derived from them.

    Args:
        input_dim: Number of input features (size of the last axis of the
            input tensor).
        hidden_dim: Width of the two hidden layers. Defaults to 50.
        num_classes: Number of output classes. Defaults to 3.
    """

    def __init__(
        self,
        input_dim: int,
        *,
        hidden_dim: int = 50,
        num_classes: int = 3,
    ) -> None:
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities for ``x``.

        Args:
            x: Tensor whose last axis has ``input_dim`` features, e.g. a
                batch ``(batch, input_dim)`` or a single sample
                ``(input_dim,)``.

        Returns:
            Tensor with the same leading shape as ``x`` and a last axis of
            size ``num_classes``; values along that axis sum to 1.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # Normalize over the last axis, which is the class axis produced by
        # layer3. For 2-D (batch, features) input this is identical to the
        # previous dim=1, and it also handles unbatched 1-D input, where
        # dim=1 would be out of range.
        # NOTE(review): this returns probabilities, not logits; confirm the
        # training loss expected that (nn.CrossEntropyLoss takes raw logits).
        return F.softmax(self.layer3(x), dim=-1)