feat: publishing infernet-container-starter v0.2.0
This commit is contained in:
		
							
								
								
									
										110
									
								
								projects/torch-iris/container/src/app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								projects/torch-iris/container/src/app.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,110 @@ | ||||
| import logging | ||||
| from typing import Any, cast, List | ||||
|  | ||||
| from eth_abi import decode, encode  # type: ignore | ||||
| from infernet_ml.utils.model_loader import ModelSource | ||||
| from infernet_ml.utils.service_models import InfernetInput, InfernetInputSource | ||||
| from infernet_ml.workflows.inference.torch_inference_workflow import ( | ||||
|     TorchInferenceWorkflow, | ||||
| ) | ||||
| from quart import Quart, request | ||||
|  | ||||
| # Note: the IrisClassificationModel needs to be imported in this file for it to exist | ||||
| # in the classpath. This is because pytorch requires the model to be in the classpath. | ||||
| # Simply downloading the weights and model from the hub is not enough. | ||||
| from iris_classification_model import IrisClassificationModel | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| def create_app() -> Quart: | ||||
|     app = Quart(__name__) | ||||
|     # we are downloading the model from the hub. | ||||
|     # model repo is located at: https://huggingface.co/Ritual-Net/iris-dataset | ||||
|     model_source = ModelSource.HUGGINGFACE_HUB | ||||
|     model_args = {"repo_id": "Ritual-Net/iris-dataset", "filename": "iris.torch"} | ||||
|  | ||||
|     workflow = TorchInferenceWorkflow(model_source=model_source, model_args=model_args) | ||||
|     workflow.setup() | ||||
|  | ||||
|     @app.route("/") | ||||
|     def index() -> str: | ||||
|         """ | ||||
|         Utility endpoint to check if the service is running. | ||||
|         """ | ||||
|         return ( | ||||
|             f"Torch Iris Classifier Example Program: {IrisClassificationModel.__name__}" | ||||
|         ) | ||||
|  | ||||
|     @app.route("/service_output", methods=["POST"]) | ||||
|     async def inference() -> dict[str, Any]: | ||||
|         req_data = await request.get_json() | ||||
|         """ | ||||
|         InfernetInput has the format: | ||||
|             source: (0 on-chain, 1 off-chain) | ||||
|             data: dict[str, Any] | ||||
|         """ | ||||
|         infernet_input: InfernetInput = InfernetInput(**req_data) | ||||
|  | ||||
|         if infernet_input.source == InfernetInputSource.OFFCHAIN: | ||||
|             web2_input = cast(dict[str, Any], infernet_input.data) | ||||
|             values = cast(List[List[float]], web2_input["input"]) | ||||
|         else: | ||||
|             # On-chain requests are sent as a generalized hex-string which we will | ||||
|             # decode to the appropriate format. | ||||
|             web3_input: List[int] = decode( | ||||
|                 ["uint256[]"], bytes.fromhex(cast(str, infernet_input.data)) | ||||
|             )[0] | ||||
|             values = [[float(v) / 1e6 for v in web3_input]] | ||||
|  | ||||
|         """ | ||||
|         The input to the torch inference workflow needs to conform to this format: | ||||
|  | ||||
|         { | ||||
|             "dtype": str, | ||||
|             "values": list[Any] | ||||
|         } | ||||
|  | ||||
|         For more information refer to: | ||||
|         https://docs.ritual.net/ml-workflows/inference-workflows/torch_inference_workflow | ||||
|  | ||||
|         """ | ||||
|         inference_result = workflow.inference({"dtype": "float", "values": values}) | ||||
|  | ||||
|         result = [o.detach().numpy().reshape([-1]).tolist() for o in inference_result] | ||||
|  | ||||
|         if infernet_input.source == InfernetInputSource.OFFCHAIN: | ||||
|             """ | ||||
|             In case of an off-chain request, the result is returned as is. | ||||
|             """ | ||||
|             return {"result": result} | ||||
|         else: | ||||
|             """ | ||||
|             In case of an on-chain request, the result is returned in the format: | ||||
|             { | ||||
|                 "raw_input": str, | ||||
|                 "processed_input": str, | ||||
|                 "raw_output": str, | ||||
|                 "processed_output": str, | ||||
|                 "proof": str, | ||||
|             } | ||||
|             refer to: https://docs.ritual.net/infernet/node/containers for more info. | ||||
|             """ | ||||
|             predictions = cast(List[List[float]], result) | ||||
|             predictions_normalized = [int(p * 1e6) for p in predictions[0]] | ||||
|             return { | ||||
|                 "raw_input": "", | ||||
|                 "processed_input": "", | ||||
|                 "raw_output": encode(["uint256[]"], [predictions_normalized]).hex(), | ||||
|                 "processed_output": "", | ||||
|                 "proof": "", | ||||
|             } | ||||
|  | ||||
|     return app | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     """ | ||||
|     Utility to run the app locally. For development purposes only. | ||||
|     """ | ||||
|     create_app().run(port=3000) | ||||
| @ -0,0 +1,23 @@ | ||||
| import torch.nn as nn | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| """ | ||||
| The IrisClassificationModel torch module. This is the computation graph that was used to | ||||
| train the model. Refer to: | ||||
| https://github.com/ritual-net/simple-ml-models/tree/main/iris_classification | ||||
| """ | ||||
|  | ||||
|  | ||||
| class IrisClassificationModel(nn.Module): | ||||
|     def __init__(self, input_dim: int) -> None: | ||||
|         super(IrisClassificationModel, self).__init__() | ||||
|         self.layer1 = nn.Linear(input_dim, 50) | ||||
|         self.layer2 = nn.Linear(50, 50) | ||||
|         self.layer3 = nn.Linear(50, 3) | ||||
|  | ||||
|     def forward(self, x: torch.Tensor) -> torch.Tensor: | ||||
|         x = F.relu(self.layer1(x)) | ||||
|         x = F.relu(self.layer2(x)) | ||||
|         x = F.softmax(self.layer3(x), dim=1) | ||||
|         return x | ||||
							
								
								
									
										7
									
								
								projects/torch-iris/container/src/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								projects/torch-iris/container/src/requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,7 @@ | ||||
| quart==0.19.4 | ||||
| infernet_ml==0.1.0 | ||||
| PyArweave @ git+https://github.com/ritual-net/pyarweave.git | ||||
| huggingface-hub==0.17.3 | ||||
| sk2torch==1.2.0 | ||||
| torch==2.1.2 | ||||
| web3==6.15.0 | ||||
		Reference in New Issue
	
	Block a user