2024-03-29 17:50:13 +03:00
// SPDX-License-Identifier: BSD-3-Clause-Clear
pragma solidity ^ 0 . 8 . 13 ;
import { console2 } from " forge-std/console2.sol " ;
import { CallbackConsumer } from " infernet-sdk/consumer/Callback.sol " ;
/// @title IrisClassifier
/// @notice Callback consumer that submits one hard-coded iris sample through the
///         inherited `_requestCompute` and logs the three class scores it gets
///         back in `_receiveCompute`.
contract IrisClassifier is CallbackConsumer {
    /// @dev Banner logged at the start of every `_receiveCompute` call.
    ///      Declared `constant` because it is never written after deployment;
    ///      this keeps it out of contract storage.
    // NOTE(review): the literal contents are reproduced exactly as found. The
    // padding spaces inside the quotes look like an extraction artifact --
    // confirm against the upstream file before relying on the rendered banner.
    string private constant EXTREMELY_COOL_BANNER = " \n \n "
        " _____ _____ _______ _ _ _ \n "
        " | __ \\ |_ _|__ __| | | | / \\ | | \n "
        " | |__) | | | | | | | | | / \\ | | \n "
        " | _ / | | | | | | | |/ / \\ \\ | | \n "
        " | | \\ \\ _| |_ | | | |__| / ____ \\ | |____ \n "
        " |_| \\ _ \\ _____| |_| \\ ____/_/ \\ _ \\ ______| \n \n ";

    /// @param registry Address forwarded unchanged to the `CallbackConsumer` constructor.
    constructor(address registry) CallbackConsumer(registry) {}

    /// @notice Builds a fixed four-feature iris vector, ABI-encodes it and passes
    ///         it to `_requestCompute`.
    function classifyIris() public {
        // Iris data is in the following format:
        //   [sepal_length, sepal_width, petal_length, petal_width]
        // The vector below corresponds to the following properties:
        //   "sepal_length": 5.5cm
        //   "sepal_width":  2.4cm
        //   "petal_length": 3.8cm
        //   "petal_width":  1.1cm
        // The data is normalized & scaled. Refer to
        // [this function in the model's repository](https://github.com/ritual-net/simple-ml-models/blob/03ebc6fb15d33efe20b7782505b1a65ce3975222/iris_classification/iris_inference_pytorch.py#L13)
        // for more info on normalization.
        // The data is adjusted by 6 decimals.
        uint256[] memory iris_data = new uint256[](4);
        iris_data[0] = 1_038_004;
        iris_data[1] = 558_610;
        iris_data[2] = 1_103_782;
        iris_data[3] = 1_712_096;

        // NOTE(review): the first argument is kept byte-for-byte, including the
        // spaces inside the quotes -- confirm the intended container identifier.
        _requestCompute(
            " onnx-iris ",
            abi.encode(iris_data),
            1, // redundancy
            address(0), // paymentToken
            0, // paymentAmount
            address(0), // wallet
            address(0) // prover
        );
    }

    /// @notice Override of the `CallbackConsumer` hook: logs the banner, decodes
    ///         `output` and logs the first three decoded values.
    /// @dev Only `output` is read; every other parameter is unused in this
    ///      implementation. `output` is decoded as `(bytes, bytes)`, and the
    ///      first element is decoded again as `uint256[]`. Reverts if that array
    ///      has fewer than three entries.
    /// @param output ABI-encoded `(bytes, bytes)` pair whose first element is an
    ///        ABI-encoded `uint256[]` of class scores.
    function _receiveCompute(
        uint32 subscriptionId,
        uint32 interval,
        uint16 redundancy,
        address node,
        bytes calldata input,
        bytes calldata output,
        bytes calldata proof,
        bytes32 containerId,
        uint256 index
    ) internal override {
        console2.log(EXTREMELY_COOL_BANNER);

        // The second tuple element is never used, so its slot is left empty.
        (bytes memory raw_output, ) = abi.decode(output, (bytes, bytes));
        (uint256[] memory classes) = abi.decode(raw_output, (uint256[]));

        // Fail with a readable reason instead of an out-of-bounds panic when the
        // decoded array is too short to index below.
        require(classes.length >= 3, "IrisClassifier: expected at least 3 class scores");

        uint256 setosa = classes[0];
        uint256 versicolor = classes[1];
        uint256 virginica = classes[2];

        console2.log(" predictions: (adjusted by 6 decimals, 1_000_000 = 100%, 1_000 = 0.1%) ");
        console2.log(" Setosa: ", setosa);
        console2.log(" Versicolor: ", versicolor);
        console2.log(" Virginica: ", virginica);
    }
}