infernet-1.0.0 update
This commit is contained in:
		| @ -1,11 +1,16 @@ | ||||
| import logging | ||||
| from typing import Any, cast, List | ||||
| from infernet_ml.utils.common_types import TensorInput | ||||
|  | ||||
| from eth_abi import decode, encode  # type: ignore | ||||
| from infernet_ml.utils.model_loader import ( | ||||
|     HFLoadArgs, | ||||
| ) | ||||
| from infernet_ml.utils.model_loader import ModelSource | ||||
| from infernet_ml.utils.service_models import InfernetInput, InfernetInputSource | ||||
| from infernet_ml.utils.service_models import InfernetInput, JobLocation | ||||
| from infernet_ml.workflows.inference.torch_inference_workflow import ( | ||||
|     TorchInferenceWorkflow, | ||||
|     TorchInferenceInput, | ||||
| ) | ||||
| from quart import Quart, request | ||||
|  | ||||
| @ -21,10 +26,10 @@ def create_app() -> Quart: | ||||
|     app = Quart(__name__) | ||||
|     # we are downloading the model from the hub. | ||||
|     # model repo is located at: https://huggingface.co/Ritual-Net/iris-dataset | ||||
|     model_source = ModelSource.HUGGINGFACE_HUB | ||||
|     model_args = {"repo_id": "Ritual-Net/iris-dataset", "filename": "iris.torch"} | ||||
|  | ||||
|     workflow = TorchInferenceWorkflow(model_source=model_source, model_args=model_args) | ||||
|     workflow = TorchInferenceWorkflow( | ||||
|         model_source=ModelSource.HUGGINGFACE_HUB, | ||||
|         load_args=HFLoadArgs(repo_id="Ritual-Net/iris-dataset", filename="iris.torch"), | ||||
|     ) | ||||
|     workflow.setup() | ||||
|  | ||||
|     @app.route("/") | ||||
| @ -46,16 +51,17 @@ def create_app() -> Quart: | ||||
|         """ | ||||
|         infernet_input: InfernetInput = InfernetInput(**req_data) | ||||
|  | ||||
|         if infernet_input.source == InfernetInputSource.OFFCHAIN: | ||||
|             web2_input = cast(dict[str, Any], infernet_input.data) | ||||
|             values = cast(List[List[float]], web2_input["input"]) | ||||
|         else: | ||||
|             # On-chain requests are sent as a generalized hex-string which we will | ||||
|             # decode to the appropriate format. | ||||
|             web3_input: List[int] = decode( | ||||
|                 ["uint256[]"], bytes.fromhex(cast(str, infernet_input.data)) | ||||
|             )[0] | ||||
|             values = [[float(v) / 1e6 for v in web3_input]] | ||||
|         match infernet_input: | ||||
|             case InfernetInput(source=JobLocation.OFFCHAIN): | ||||
|                 web2_input = cast(dict[str, Any], infernet_input.data) | ||||
|                 values = cast(List[List[float]], web2_input["input"]) | ||||
|             case InfernetInput(source=JobLocation.ONCHAIN): | ||||
|                 web3_input: List[int] = decode( | ||||
|                     ["uint256[]"], bytes.fromhex(cast(str, infernet_input.data)) | ||||
|                 )[0] | ||||
|                 values = [[float(v) / 1e6 for v in web3_input]] | ||||
|             case _: | ||||
|                 raise ValueError("Invalid source") | ||||
|  | ||||
|         """ | ||||
|         The input to the torch inference workflow needs to conform to this format: | ||||
| @ -66,39 +72,52 @@ def create_app() -> Quart: | ||||
|         } | ||||
|  | ||||
|         For more information refer to: | ||||
|         https://docs.ritual.net/ml-workflows/inference-workflows/torch_inference_workflow | ||||
|         https://infernet-ml.docs.ritual.net/reference/infernet_ml/workflows/inference/torch_inference_workflow/?h=torch | ||||
|  | ||||
|         """ | ||||
|         inference_result = workflow.inference({"dtype": "float", "values": values}) | ||||
|         """  # noqa: E501 | ||||
|         log.info("Input values: %s", values) | ||||
|  | ||||
|         result = [o.detach().numpy().reshape([-1]).tolist() for o in inference_result] | ||||
|         _input = TensorInput( | ||||
|             dtype="float", | ||||
|             shape=(1, 4), | ||||
|             values=values, | ||||
|         ) | ||||
|  | ||||
|         if infernet_input.source == InfernetInputSource.OFFCHAIN: | ||||
|             """ | ||||
|             In case of an off-chain request, the result is returned as is. | ||||
|             """ | ||||
|             return {"result": result} | ||||
|         else: | ||||
|             """ | ||||
|             In case of an on-chain request, the result is returned in the format: | ||||
|             { | ||||
|                 "raw_input": str, | ||||
|                 "processed_input": str, | ||||
|                 "raw_output": str, | ||||
|                 "processed_output": str, | ||||
|                 "proof": str, | ||||
|             } | ||||
|             refer to: https://docs.ritual.net/infernet/node/containers for more info. | ||||
|             """ | ||||
|             predictions = cast(List[List[float]], result) | ||||
|             predictions_normalized = [int(p * 1e6) for p in predictions[0]] | ||||
|             return { | ||||
|                 "raw_input": "", | ||||
|                 "processed_input": "", | ||||
|                 "raw_output": encode(["uint256[]"], [predictions_normalized]).hex(), | ||||
|                 "processed_output": "", | ||||
|                 "proof": "", | ||||
|             } | ||||
|         iris_inference_input = TorchInferenceInput(input=_input) | ||||
|  | ||||
|         inference_result = workflow.inference(iris_inference_input) | ||||
|  | ||||
|         result = inference_result.outputs | ||||
|  | ||||
|         match infernet_input: | ||||
|             case InfernetInput(destination=JobLocation.OFFCHAIN): | ||||
|                 """ | ||||
|                 In case of an off-chain request, the result is returned as is. | ||||
|                 """ | ||||
|                 return {"result": result} | ||||
|             case InfernetInput(destination=JobLocation.ONCHAIN): | ||||
|                 """ | ||||
|                 In case of an on-chain request, the result is returned in the format: | ||||
|                 { | ||||
|                     "raw_input": str, | ||||
|                     "processed_input": str, | ||||
|                     "raw_output": str, | ||||
|                     "processed_output": str, | ||||
|                     "proof": str, | ||||
|                 } | ||||
|                 refer to: https://docs.ritual.net/infernet/node/containers for more | ||||
|                 info. | ||||
|                 """ | ||||
|                 predictions_normalized = [int(p * 1e6) for p in result] | ||||
|                 return { | ||||
|                     "raw_input": "", | ||||
|                     "processed_input": "", | ||||
|                     "raw_output": encode(["uint256[]"], [predictions_normalized]).hex(), | ||||
|                     "processed_output": "", | ||||
|                     "proof": "", | ||||
|                 } | ||||
|             case _: | ||||
|                 raise ValueError("Invalid destination") | ||||
|  | ||||
|     return app | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| quart==0.19.4 | ||||
| infernet_ml==0.1.0 | ||||
| PyArweave @ git+https://github.com/ritual-net/pyarweave.git | ||||
| infernet-ml==1.0.0 | ||||
| infernet-ml[torch_inference]==1.0.0 | ||||
| huggingface-hub==0.17.3 | ||||
| sk2torch==1.2.0 | ||||
| torch==2.1.2 | ||||
|  | ||||
		Reference in New Issue
	
	Block a user