Using background tasks to make predictions on ML model

Hey guys,

Appreciate the work done by this community.

I’m using Sanic for a project that runs PyTorch predictions on uploaded audio. I’ve used background tasks before, but I wanted to ask how much impact this would have on my API’s performance: I want the predictions to run without blocking the regular API that the mobile app consumes.

I currently deploy my server using Docker and NGINX and was considering running another container alongside the regular server which would be responsible for making predictions.

My question is really about best practices, so what would you all recommend? Do background tasks run in a separate thread that won’t affect server performance, so I can just add the code to my current server, or is it better to offload these tasks to a completely separate Python container?

Thanks in advance!

Hasan

Thanks 😎

I’d recommend running this outside of the main worker process. To keep it simple and avoid running multiple services, you can use a variant of the pattern here: Pushing work to the background of your Sanic app

In short: start a managed process and push the work from your endpoint. Keeps your app free and your deployments tidy.
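For reference, a minimal sketch of that pattern (assuming Sanic 22.9+ for shared_ctx and the process manager; run_prediction and the request payload fields are placeholders you would swap for your own code):

from multiprocessing import Queue
from queue import Empty

from sanic import Sanic
from sanic.response import json

app = Sanic("MyApp")

def prediction_worker(queue):
    # Runs in its own managed process, so the web workers never block.
    # Load the PyTorch model here, once, before entering the loop.
    while True:
        try:
            job = queue.get(timeout=1)
        except Empty:
            continue
        run_prediction(job)  # placeholder for your actual inference code

@app.main_process_start
async def setup_shared_queue(app, _):
    # Objects placed on shared_ctx in the main process are shared with all workers
    app.shared_ctx.queue = Queue()

@app.main_process_ready
async def start_worker(app, _):
    app.manager.manage("PredictionWorker", prediction_worker, {"queue": app.shared_ctx.queue})

@app.post("/predict")
async def predict(request):
    # Push the job onto the queue and return immediately
    request.app.shared_ctx.queue.put({"audio_path": request.json["audio_path"]})
    return json({"status": "queued"}, status=202)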


@ahopkins

Thanks, man, really appreciate you!


Happy to know that I was helpful.

I’ve simply loaded my model as a global variable (in @app.before_server_start) and made (blocking) calls to it in a normal (websocket) handler. This is not as good as running a separate thread or process, but it works well enough for a small number of concurrent users, especially if your server doesn’t serve many requests other than those to the model (which, due to GPU limitations such as the amount of VRAM, cannot really run multiple requests in parallel anyway), and especially when a single prediction doesn’t take very long, so the server can still respond in reasonable time.
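If the blocking call ever becomes an issue, a small step up (a rough, untested sketch; run_model stands in for whatever call you make to your model) is to push it onto the default thread pool, so the event loop stays free while the GPU works:

import asyncio
from functools import partial

@app.websocket("/predict")
async def predict(request, ws):
    model = request.app.ctx.model
    loop = asyncio.get_running_loop()
    while True:
        data = await ws.recv()
        # Run the blocking model call in a worker thread; other requests
        # keep being served while this one waits for the result.
        result = await loop.run_in_executor(None, partial(run_model, model, data))
        await ws.send(result)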

As a complete example (of a quick hack, all public domain), here’s an app that records audio from the browser and does near-realtime transcription from any language. I hope this helps, in particular with the audio streaming part, which I found a bit difficult (it needs recorder2.js as an alternative worklet to support Firefox):

babel.py:

from sanic import Sanic
import numpy as np
import whisper
import json

app = Sanic("babel")

app.static('/', 'static/')
app.static('/', 'static/index.html', name='index')

def unicode_flag(lang):
    offset = ord('🇦') - ord('a')
    lang = dict(
        en="gb",
        fa="ir",  # Farsi, Iran
    ).get(lang, lang).lower()
    return "".join(chr(ord(ch) + offset) for ch in lang)

@app.before_server_start
def load_model(app, _):
    # Load the Whisper model once per worker and keep it on the app context
    app.ctx.model = whisper.load_model("medium", device="cuda")

@app.websocket("/sock")
async def ws(req, ws):
    model = app.ctx.model
    buf = np.zeros(whisper.audio.N_SAMPLES, dtype="f4")  # 30 s of 16 kHz audio, the window Whisper expects
    pos = 0
    await ws.send("[listening]")
    silence = 0
    language = None
    task = "translate"
    while True:
        msg = await ws.recv()
        if isinstance(msg, str):
            m = json.loads(msg)
            print(m)
            language = m.get("language") or None  # Empty string also maps to None
            task = "translate" if m.get("translate") else "transcribe"
            print(language, task)
            continue
        d = np.frombuffer(msg, "i2") / 32767.0
        if abs(d).mean() < 1e-2:
            silence += 1
        else:
            silence = 0
        # If there is less than a second of audio in the buffer and it is mostly silence, skip it
        if silence > 100 and pos < 16000:
            buf[:pos] = 0.0
            pos = 0
            continue
        # Process if the buffer is full, or if there are at least 4.5 seconds (72000 samples) followed by a brief silence
        if pos + len(d) > len(buf) or pos > 72000 and silence > 100:
            #await ws.send(f"[processing {pos / 16000.0:.0f} s]")
            # make log-Mel spectrogram and move to the same device as the model
            mel = whisper.log_mel_spectrogram(buf).to(model.device)

            # detect the spoken language
            #_, probs = model.detect_language(mel)
            #print(f"Detected language: {max(probs, key=probs.get)}")

            # decode the audio
            options = whisper.DecodingOptions(without_timestamps=True, task=task, language=language)
            result = whisper.decode(model, mel, options)
            print(result)

            if result.no_speech_prob < 0.8 and result.avg_logprob > -1.5 and result.compression_ratio < 10.0:
                # print the recognized text
                await ws.send(unicode_flag(result.language) + " " + result.text)
            else:
                await ws.send("[can't hear anything]")

            # Clear the buffer and start over
            buf[:] = 0.0
            pos = 0


        buf[pos: pos + len(d)] = d
        pos += len(d)

static/index.html:

<!DOCTYPE html>
<html lang="en">
<meta charset="UTF-8"/>
<title>Babel</title>
<style>
    * { padding: 0; margin: 0; }
    html { background: black; color: #888; height: 100%; font-size: 32px; }
    button { width: 4rem; height: 4rem; border-radius: 2em; font-size: 2.5rem; background: #ddd; border: none; }
    #text div { margin: 1em; text-align: justify; hyphens: auto; font-size: 2rem; }
    #text { height: calc(100vh - 12rem); display: flex; flex-direction: column-reverse; overflow-y: auto; }
    footer { height: 4rem; display: flex; justify-content: space-evenly; position: absolute; bottom: 0; }
    select { font-size: 1rem; margin-right: 1rem; }
    input[type="checkbox"] { width: 1rem; height: 1rem; margin: 0; }
    .hidden { display: none; }
</style>
<script src="client.js"></script>
<div id="text"></div>
<footer>
    <div>
        <button id="start" onclick="start()">🎙️</button>
        <button id="stop" class="hidden" onclick="stop()">⏹️</button>
    </div>
    <div id="settings">
        <label for="language">Language: </label>
        <select id="language">
            <option value="">Auto</option>
        </select>
        <input id="translate" type="checkbox" name="translate" checked>
        <label for="translate">to English</label>
    </div>
</footer>

static/client.js:

const languages = {
  "en": "english",
  "zh": "chinese",
  "de": "german",
  "es": "spanish",
  "ru": "russian",
  "ko": "korean",
  "fr": "french",
  "ja": "japanese",
  "pt": "portuguese",
  "tr": "turkish",
  "pl": "polish",
  "ca": "catalan",
  "nl": "dutch",
  "ar": "arabic",
  "sv": "swedish",
  "it": "italian",
  "id": "indonesian",
  "hi": "hindi",
  "fi": "finnish",
  "vi": "vietnamese",
  "iw": "hebrew",
  "uk": "ukrainian",
  "el": "greek",
  "ms": "malay",
  "cs": "czech",
  "ro": "romanian",
  "da": "danish",
  "hu": "hungarian",
  "ta": "tamil",
  "no": "norwegian",
  "th": "thai",
  "ur": "urdu",
  "hr": "croatian",
  "bg": "bulgarian",
  "lt": "lithuanian",
  "la": "latin",
  "mi": "maori",
  "ml": "malayalam",
  "cy": "welsh",
  "sk": "slovak",
  "te": "telugu",
  "fa": "persian",
  "lv": "latvian",
  "bn": "bengali",
  "sr": "serbian",
  "az": "azerbaijani",
  "sl": "slovenian",
  "kn": "kannada",
  "et": "estonian",
  "mk": "macedonian",
  "br": "breton",
  "eu": "basque",
  "is": "icelandic",
  "hy": "armenian",
  "ne": "nepali",
  "mn": "mongolian",
  "bs": "bosnian",
  "kk": "kazakh",
  "sq": "albanian",
  "sw": "swahili",
  "gl": "galician",
  "mr": "marathi",
  "pa": "punjabi",
  "si": "sinhala",
  "km": "khmer",
  "sn": "shona",
  "yo": "yoruba",
  "so": "somali",
  "af": "afrikaans",
  "oc": "occitan",
  "ka": "georgian",
  "be": "belarusian",
  "tg": "tajik",
  "sd": "sindhi",
  "gu": "gujarati",
  "am": "amharic",
  "yi": "yiddish",
  "lo": "lao",
  "uz": "uzbek",
  "fo": "faroese",
  "ht": "haitian creole",
  "ps": "pashto",
  "tk": "turkmen",
  "nn": "nynorsk",
  "mt": "maltese",
  "sa": "sanskrit",
  "lb": "luxembourgish",
  "my": "myanmar",
  "bo": "tibetan",
  "tl": "tagalog",
  "mg": "malagasy",
  "as": "assamese",
  "tt": "tatar",
  "haw": "hawaiian",
  "ln": "lingala",
  "ha": "hausa",
  "ba": "bashkir",
  "jw": "javanese",
  "su": "sundanese",
}

let recorder, ws

const message = msg => {
  const el = document.createElement("div")
  el.innerText = msg
  const parent = document.getElementById("text")
  parent.prepend(el)
  setTimeout(() => parent.removeChild(el), 20000)
}

const connect = async () => {
  console.log("Connecting websocket")
  ws = new WebSocket(location.href.replace(/^http/, "ws") + "sock")
  ws.binaryType = "arraybuffer"
  ws.onmessage = ev => message(ev.data)
  ws.onopen = async (event) => {
    console.log("Websocket connected")
    recorder.port.onmessage = async (ev) => await ws.send(ev.data)
    await settings()
  }
  ws.onclose = reconnect
}

const reconnect = () => {
  recorder.port.onmessage = undefined
  ws = undefined
  console.log("Websocket disconnected (attempting to reconnect)")
  message("⚠️ Attempting to reconnect server")
  setTimeout(connect, 500)
}

const start = async() => {
  let stream
  try {
    stream = await navigator.mediaDevices.getUserMedia({ audio: true, video: false })
  } catch (e) {
    message("⚠️ Unable to access microphone")
    return
  }
  try {
    let context, source
    try {
      // Ask the browser to produce a 16 kHz mic capture directly (Chrome, Safari, etc.)
      context = new AudioContext({ sampleRate: 16000, latencyHint: "interactive" })
      source = context.createMediaStreamSource(stream)
      await context.audioWorklet.addModule('recorder.js')
    } catch (e) {
      // Fallback with 48 kHz and downsampling (Firefox)
      context = new AudioContext({ sampleRate: 48000, latencyHint: "interactive" })
      source = context.createMediaStreamSource(stream)
      await context.audioWorklet.addModule('recorder2.js')
    }
    recorder = new AudioWorkletNode(context, 'recorder')
    source.connect(recorder)
  } catch (e) {
    message(`⚠️ Microphone: ${e}`)
    throw e
  }
  document.getElementById("start").classList.add("hidden")
  document.getElementById("stop").classList.remove("hidden")

  console.log("Recording started")
  // Connect websocket for streaming
  setTimeout(connect, 10)
}

const stop = () => location.reload()

const settings = async () => {
  const translate = document.getElementById("translate")
  const language = document.getElementById("language")
  const s = JSON.stringify({language: language.value, translate: translate.checked})
  localStorage['settings'] = s
  if (ws) await ws.send(s)
}

addEventListener("load", () => {
  const translate = document.getElementById("translate")
  const language = document.getElementById("language")
  for (const key in languages) {
    let name = languages[key]
    name = name[0].toUpperCase() + name.substring(1)
    const opt = document.createElement("option")
    opt.value = key
    opt.innerText = name
    language.append(opt)
  }
  language.addEventListener("change", settings)
  translate.addEventListener("change", settings)
  const s = JSON.parse(localStorage['settings'] || "{}")
  if (s.language !== undefined) {
    language.value = s.language
    translate.checked = s.translate
  }
})

static/recorder.js: (normal browsers)


class RecorderWorkletProcessor extends AudioWorkletProcessor {
    process(inputs, outputs, parameters) {
        const f32 = inputs[0][0]
        const i16 = new Int16Array(f32.length)
        for (const i in f32) i16[i] = f32[i] * 32767
        this.port.postMessage(i16)
        return true
    }
}

registerProcessor("recorder", RecorderWorkletProcessor)

static/recorder2.js: (Firefox resampling hack)


class RecorderWorkletProcessor extends AudioWorkletProcessor {
    process(inputs, outputs, parameters) {
        const f32 = inputs[0][0]
        const i16 = new Int16Array(f32.length / 3)
        const s = 32767.0 / 3.0
        const l = f32.length
        for (const i in i16) i16[i] = s * (f32[3 * i] + f32[3 * i + (3 * i + 1 < l ? 1 : 0)] + f32[3 * i + (3 * i + 2 < l ? 2 : 0)])
        this.port.postMessage(i16)
        return true
    }
}

registerProcessor("recorder", RecorderWorkletProcessor)
1 Like

Hey @Tronic, thanks for the example! I ended up going with background workers though, especially because it takes about 2 minutes to predict on a 3-minute audio file for what I’m trying to do. But I’ll keep this in mind for quick prototypes and POCs down the road. Appreciate the help!
