infernet-1.0.0 update
This commit is contained in:
		| @ -2,10 +2,11 @@ import logging | ||||
| import os | ||||
| from typing import Any, cast | ||||
|  | ||||
| from eth_abi import decode, encode  # type: ignore | ||||
| from infernet_ml.utils.service_models import InfernetInput, InfernetInputSource | ||||
| from eth_abi.abi import decode, encode | ||||
| from infernet_ml.utils.service_models import InfernetInput, JobLocation | ||||
| from infernet_ml.workflows.inference.tgi_client_inference_workflow import ( | ||||
|     TGIClientInferenceWorkflow, | ||||
|     TgiInferenceRequest, | ||||
| ) | ||||
| from quart import Quart, request | ||||
|  | ||||
| @ -16,7 +17,7 @@ def create_app() -> Quart: | ||||
|     app = Quart(__name__) | ||||
|  | ||||
|     workflow = TGIClientInferenceWorkflow( | ||||
|         server_url=cast(str, os.environ.get("TGI_SERVICE_URL")) | ||||
|         server_url=os.environ["TGI_SERVICE_URL"], | ||||
|     ) | ||||
|  | ||||
|     workflow.setup() | ||||
| @ -38,42 +39,51 @@ def create_app() -> Quart: | ||||
|         """ | ||||
|         infernet_input: InfernetInput = InfernetInput(**req_data) | ||||
|  | ||||
|         if infernet_input.source == InfernetInputSource.OFFCHAIN: | ||||
|             prompt = cast(dict[str, Any], infernet_input.data).get("prompt") | ||||
|         else: | ||||
|             # On-chain requests are sent as a generalized hex-string which we will | ||||
|             # decode to the appropriate format. | ||||
|             (prompt,) = decode( | ||||
|                 ["string"], bytes.fromhex(cast(str, infernet_input.data)) | ||||
|             ) | ||||
|         match infernet_input: | ||||
|             case InfernetInput(source=JobLocation.OFFCHAIN): | ||||
|                 prompt = cast(dict[str, Any], infernet_input.data).get("prompt") | ||||
|             case InfernetInput(source=JobLocation.ONCHAIN): | ||||
|                 # On-chain requests are sent as a generalized hex-string which we will | ||||
|                 # decode to the appropriate format. | ||||
|                 (prompt,) = decode( | ||||
|                     ["string"], bytes.fromhex(cast(str, infernet_input.data)) | ||||
|                 ) | ||||
|             case _: | ||||
|                 raise ValueError("Invalid source") | ||||
|  | ||||
|         result: dict[str, Any] = workflow.inference({"text": prompt}) | ||||
|         result: dict[str, Any] = workflow.inference( | ||||
|             TgiInferenceRequest(text=cast(str, prompt)) | ||||
|         ) | ||||
|  | ||||
|         if infernet_input.source == InfernetInputSource.OFFCHAIN: | ||||
|             """ | ||||
|             In case of an off-chain request, the result is returned as a dict. The | ||||
|             infernet node expects a dict format. | ||||
|             """ | ||||
|             return {"data": result} | ||||
|         else: | ||||
|             """ | ||||
|             In case of an on-chain request, the result is returned in the format: | ||||
|             { | ||||
|                 "raw_input": str, | ||||
|                 "processed_input": str, | ||||
|                 "raw_output": str, | ||||
|                 "processed_output": str, | ||||
|                 "proof": str, | ||||
|             } | ||||
|             refer to: https://docs.ritual.net/infernet/node/containers for more info. | ||||
|             """ | ||||
|             return { | ||||
|                 "raw_input": "", | ||||
|                 "processed_input": "", | ||||
|                 "raw_output": encode(["string"], [result]).hex(), | ||||
|                 "processed_output": "", | ||||
|                 "proof": "", | ||||
|             } | ||||
|         match infernet_input: | ||||
|             case InfernetInput(destination=JobLocation.OFFCHAIN): | ||||
|                 """ | ||||
|                 In case of an off-chain request, the result is returned as a dict. The | ||||
|                 infernet node expects a dict format. | ||||
|                 """ | ||||
|                 return {"data": result} | ||||
|             case InfernetInput(destination=JobLocation.ONCHAIN): | ||||
|                 """ | ||||
|                 In case of an on-chain request, the result is returned in the format: | ||||
|                 { | ||||
|                     "raw_input": str, | ||||
|                     "processed_input": str, | ||||
|                     "raw_output": str, | ||||
|                     "processed_output": str, | ||||
|                     "proof": str, | ||||
|                 } | ||||
|                 refer to: https://docs.ritual.net/infernet/node/containers for more | ||||
|                 info. | ||||
|                 """ | ||||
|                 return { | ||||
|                     "raw_input": "", | ||||
|                     "processed_input": "", | ||||
|                     "raw_output": encode(["string"], [result]).hex(), | ||||
|                     "processed_output": "", | ||||
|                     "proof": "", | ||||
|                 } | ||||
|             case _: | ||||
|                 raise ValueError("Invalid destination") | ||||
|  | ||||
|     return app | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| quart==0.19.4 | ||||
| infernet_ml==0.1.0 | ||||
| PyArweave @ git+https://github.com/ritual-net/pyarweave.git | ||||
| infernet-ml==1.0.0 | ||||
| infernet-ml[tgi_inference]==1.0.0 | ||||
| web3==6.15.0 | ||||
| retry2==0.9.5 | ||||
| text-generation==0.6.1 | ||||
|  | ||||
		Reference in New Issue
	
	Block a user