// Path to the ONNX model file, relative to the page that loads this script.
const MODEL_PATH = '../onnx/model.onnx';
// Model input resolution (width x height): the uploaded image is resized to
// these dimensions before inference.
// NOTE(review): presumably fixed by the exported model — confirm against the
// model's declared input shape.
const INPUT_WIDTH = 1092;
const INPUT_HEIGHT = 546;

// ONNX Runtime inference session; created asynchronously by init().
let session = null;

// Cached DOM references for the UI controls and the input/output canvases.
const statusElement = document.getElementById('status');
const runBtn = document.getElementById('runBtn');
const imageInput = document.getElementById('imageInput');
const inputCanvas = document.getElementById('inputCanvas');
const outputCanvas = document.getElementById('outputCanvas');
const inputCtx = inputCanvas.getContext('2d');
const outputCtx = outputCanvas.getContext('2d');
/**
 * Load the ONNX model, preferring the WebGPU execution provider and
 * falling back to WASM when WebGPU fails. Updates the status line and
 * enables the Run button once a session is ready.
 *
 * Fix: the original wrote the WebGPU error message to the status line and
 * then immediately overwrote it with "trying WASM...", and the final
 * failure message dropped the WebGPU error entirely. Both errors are now
 * logged and surfaced together on total failure.
 */
async function init() {
  // Verbose ORT logging helps diagnose provider/model loading issues.
  ort.env.debug = true;
  ort.env.logLevel = 'verbose';

  statusElement.textContent = 'Loading model... (this may take a while)';
  try {
    session = await ort.InferenceSession.create(MODEL_PATH, {
      executionProviders: ['webgpu'],
    });
    statusElement.textContent = 'Model loaded. Ready.';
    runBtn.disabled = false;
  } catch (webgpuError) {
    console.error('WebGPU provider failed:', webgpuError);
    statusElement.textContent = 'WebGPU failed, trying WASM...';
    try {
      session = await ort.InferenceSession.create(MODEL_PATH, {
        executionProviders: ['wasm'],
      });
      statusElement.textContent = 'Model loaded (WASM). Ready.';
      runBtn.disabled = false;
    } catch (wasmError) {
      console.error('WASM provider failed:', wasmError);
      // Surface both failures so the user can see why neither provider worked.
      statusElement.textContent =
        'Error loading model. WebGPU: ' + webgpuError.message +
        '; WASM: ' + wasmError.message;
    }
  }
}
|
|
// Draw the selected file onto the input canvas (resized to the model's
// input resolution) and clear any previous depth output.
//
// Fix: the original leaked one blob URL per file selection —
// URL.createObjectURL() was never paired with URL.revokeObjectURL().
// Also adds an onerror path so a non-image file reports a message
// instead of failing silently.
imageInput.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;

  const objectUrl = URL.createObjectURL(file);
  const img = new Image();

  img.onload = () => {
    URL.revokeObjectURL(objectUrl); // release the blob once decoded

    inputCanvas.width = INPUT_WIDTH;
    inputCanvas.height = INPUT_HEIGHT;
    inputCtx.drawImage(img, 0, 0, INPUT_WIDTH, INPUT_HEIGHT);

    // Reset the output canvas so stale depth maps don't linger.
    outputCanvas.width = INPUT_WIDTH;
    outputCanvas.height = INPUT_HEIGHT;
    outputCtx.clearRect(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
  };

  img.onerror = () => {
    URL.revokeObjectURL(objectUrl);
    statusElement.textContent = 'Could not load the selected image.';
  };

  img.src = objectUrl;
});
|
|
// Run one inference pass over the input canvas and render the result.
//
// Fix: the original assumed the model's depth output was exactly
// INPUT_WIDTH x INPUT_HEIGHT. The output tensor's own dims (last two
// axes of an N...HW layout) are now used when available, falling back
// to the input size.
runBtn.addEventListener('click', async () => {
  if (!session) return;

  statusElement.textContent = 'Running inference...';
  runBtn.disabled = true;

  try {
    // RGBA pixels from the (already resized) input canvas.
    const imageData = inputCtx.getImageData(0, 0, INPUT_WIDTH, INPUT_HEIGHT);
    const tensor = preprocess(imageData);

    const results = await session.run({ pixel_values: tensor });
    const output = results.predicted_depth;

    // NOTE(review): assumes the depth map's H/W are the last two dims —
    // confirm against the exported model's output layout.
    const dims = output.dims || [];
    const outHeight = dims.length >= 2 ? dims[dims.length - 2] : INPUT_HEIGHT;
    const outWidth = dims.length >= 1 ? dims[dims.length - 1] : INPUT_WIDTH;

    visualize(output.data, outWidth, outHeight);
    statusElement.textContent = 'Done.';
  } catch (e) {
    console.error(e);
    statusElement.textContent = 'Error running inference: ' + e.message;
  } finally {
    runBtn.disabled = false;
  }
});
|
|
/**
 * Convert canvas RGBA pixel data into an NCHW float32 tensor for the model.
 *
 * @param {ImageData} imageData - RGBA pixels from a 2D canvas context.
 * @returns {ort.Tensor} Tensor of shape [1, 3, height, width], values in [0, 1].
 */
function preprocess(imageData) {
  const { data, width, height } = imageData;
  const planeSize = width * height;
  const chw = new Float32Array(3 * planeSize);

  // De-interleave RGBA into planar R, G, B, scaling bytes to [0, 1].
  // The alpha channel is discarded.
  // NOTE(review): no mean/std normalization is applied here — confirm the
  // model expects plain [0, 1] inputs.
  for (let px = 0, src = 0; px < planeSize; px++, src += 4) {
    chw[px] = data[src] / 255.0;
    chw[planeSize + px] = data[src + 1] / 255.0;
    chw[2 * planeSize + px] = data[src + 2] / 255.0;
  }

  return new ort.Tensor('float32', chw, [1, 3, height, width]);
}
|
|
/**
 * Render a flat depth-value array onto the output canvas as inverted
 * grayscale (larger values render darker, after min/max normalization).
 *
 * Fixes: the pixel loop ran over data.length unconditionally, so a
 * data/size mismatch indexed past the ImageData buffer; the canvas is
 * now also sized to the map being drawn so putImageData never clips.
 *
 * @param {Float32Array|number[]} data - Flat depth values, row-major.
 * @param {number} width - Depth-map width in pixels.
 * @param {number} height - Depth-map height in pixels.
 */
function visualize(data, width, height) {
  // Single pass to find the value range for contrast normalization.
  let min = Infinity;
  let max = -Infinity;
  for (let i = 0; i < data.length; i++) {
    if (data[i] < min) min = data[i];
    if (data[i] > max) max = data[i];
  }
  const range = max - min;

  // Match the canvas to the map being drawn.
  outputCanvas.width = width;
  outputCanvas.height = height;
  const imageData = outputCtx.createImageData(width, height);

  // Clamp so a mismatched data length cannot write outside the buffer.
  const pixelCount = Math.min(data.length, width * height);
  for (let i = 0; i < pixelCount; i++) {
    // Normalize to [0, 1]; `range || 1` guards a flat (zero-range) map.
    const val = (data[i] - min) / (range || 1);
    // Invert the gray ramp: larger normalized values render darker.
    const pixelVal = Math.floor((1 - val) * 255);

    imageData.data[i * 4] = pixelVal;
    imageData.data[i * 4 + 1] = pixelVal;
    imageData.data[i * 4 + 2] = pixelVal;
    imageData.data[i * 4 + 3] = 255; // fully opaque
  }

  outputCtx.putImageData(imageData, 0, 0);
}
|
|
| init(); |