feat: publishing infernet-container-starter v0.2.0
This commit is contained in:
25
projects/prompt-to-nft/stablediffusion/src/app.py
Normal file
25
projects/prompt-to-nft/stablediffusion/src/app.py
Normal file
@ -0,0 +1,25 @@
|
||||
from quart import Quart, request, Response
|
||||
|
||||
from stable_diffusion_workflow import StableDiffusionWorkflow
|
||||
|
||||
|
||||
def create_app() -> Quart:
|
||||
app = Quart(__name__)
|
||||
workflow = StableDiffusionWorkflow()
|
||||
workflow.setup()
|
||||
|
||||
@app.get("/")
|
||||
async def hello():
|
||||
return "Hello, World! I'm running stable diffusion"
|
||||
|
||||
@app.post("/service_output")
|
||||
async def service_output():
|
||||
req_data = await request.get_json()
|
||||
image_bytes = workflow.inference(req_data)
|
||||
return Response(image_bytes, mimetype="image/png")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_app().run(host="0.0.0.0", port=3002)
|
10
projects/prompt-to-nft/stablediffusion/src/requirements.txt
Normal file
10
projects/prompt-to-nft/stablediffusion/src/requirements.txt
Normal file
@ -0,0 +1,10 @@
|
||||
diffusers~=0.19
|
||||
invisible_watermark~=0.1
|
||||
transformers==4.36
|
||||
accelerate~=0.21
|
||||
safetensors~=0.3
|
||||
Quart==0.19.4
|
||||
jmespath==1.0.1
|
||||
huggingface-hub==0.20.3
|
||||
infernet_ml==0.1.0
|
||||
PyArweave @ git+https://github.com/ritual-net/pyarweave.git
|
@ -0,0 +1,86 @@
|
||||
import io
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from huggingface_hub import snapshot_download
|
||||
from infernet_ml.workflows.inference.base_inference_workflow import (
|
||||
BaseInferenceWorkflow,
|
||||
)
|
||||
|
||||
|
||||
class StableDiffusionWorkflow(BaseInferenceWorkflow):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def do_setup(self) -> Any:
|
||||
ignore = [
|
||||
"*.bin",
|
||||
"*.onnx_data",
|
||||
"*/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
snapshot_download(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", ignore_patterns=ignore
|
||||
)
|
||||
snapshot_download(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
ignore_patterns=ignore,
|
||||
)
|
||||
|
||||
load_options = dict(
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# Load base model
|
||||
self.base = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", **load_options
|
||||
)
|
||||
|
||||
# Load refiner model
|
||||
self.refiner = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
text_encoder_2=self.base.text_encoder_2,
|
||||
vae=self.base.vae,
|
||||
**load_options,
|
||||
)
|
||||
|
||||
def do_preprocessing(self, input_data: dict[str, Any]) -> dict[str, Any]:
|
||||
return input_data
|
||||
|
||||
def do_run_model(self, input: dict[str, Any]) -> bytes:
|
||||
negative_prompt = input.get("negative_prompt", "disfigured, ugly, deformed")
|
||||
prompt = input["prompt"]
|
||||
n_steps = input.get("n_steps", 24)
|
||||
high_noise_frac = input.get("high_noise_frac", 0.8)
|
||||
|
||||
image = self.base(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=n_steps,
|
||||
denoising_end=high_noise_frac,
|
||||
output_type="latent",
|
||||
).images
|
||||
|
||||
image = self.refiner(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=n_steps,
|
||||
denoising_start=high_noise_frac,
|
||||
image=image,
|
||||
).images[0]
|
||||
|
||||
byte_stream = io.BytesIO()
|
||||
image.save(byte_stream, format="PNG")
|
||||
image_bytes = byte_stream.getvalue()
|
||||
|
||||
return image_bytes
|
||||
|
||||
def do_postprocessing(self, input: Any, output: Any) -> Any:
|
||||
return output
|
Reference in New Issue
Block a user