| using System; |
| using System.Collections.Generic; |
| using System.IO; |
| using Unity.InferenceEngine; |
| using UnityEngine; |
| using UnityEngine.UI; |
| using UnityEngine.Video; |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| public class RunYOLO : MonoBehaviour |
| { |
| [Tooltip("Drag a YOLO model .onnx file here")] |
| public ModelAsset modelAsset; |
|
|
| [Tooltip("Drag the classes.txt here")] |
| public TextAsset classesAsset; |
|
|
| [Tooltip("Create a Raw Image in the scene and link it here")] |
| public RawImage displayImage; |
|
|
| [Tooltip("Drag a border box texture here")] |
| public Texture2D borderTexture; |
|
|
| [Tooltip("Select an appropriate font for the labels")] |
| public Font font; |
|
|
| [Tooltip("Change this to the name of the video you put in the Assets/StreamingAssets folder")] |
| public string videoFilename = "giraffes.mp4"; |
|
|
| const BackendType backend = BackendType.GPUCompute; |
|
|
| private Transform displayLocation; |
| private Worker worker; |
| private string[] labels; |
| private RenderTexture targetRT; |
| private Sprite borderSprite; |
|
|
| |
| private const int imageWidth = 640; |
| private const int imageHeight = 640; |
|
|
| private VideoPlayer video; |
|
|
| List<GameObject> boxPool = new(); |
|
|
| [Tooltip("Intersection over union threshold used for non-maximum suppression")] |
| [SerializeField, Range(0, 1)] |
| float iouThreshold = 0.5f; |
|
|
| [Tooltip("Confidence score threshold used for non-maximum suppression")] |
| [SerializeField, Range(0, 1)] |
| float scoreThreshold = 0.5f; |
|
|
| Tensor<float> centersToCorners; |
| |
| public struct BoundingBox |
| { |
| public float centerX; |
| public float centerY; |
| public float width; |
| public float height; |
| public string label; |
| } |
|
|
| void Start() |
| { |
| Application.targetFrameRate = 60; |
| Screen.orientation = ScreenOrientation.LandscapeLeft; |
|
|
| |
| labels = classesAsset.text.Split('\n'); |
|
|
| LoadModel(); |
|
|
| targetRT = new RenderTexture(imageWidth, imageHeight, 0); |
|
|
| |
| displayLocation = displayImage.transform; |
|
|
| SetupInput(); |
|
|
| borderSprite = Sprite.Create(borderTexture, new Rect(0, 0, borderTexture.width, borderTexture.height), new Vector2(borderTexture.width / 2, borderTexture.height / 2)); |
| } |
| void LoadModel() |
| { |
| |
| var model1 = ModelLoader.Load(modelAsset); |
|
|
| centersToCorners = new Tensor<float>(new TensorShape(4, 4), |
| new float[] |
| { |
| 1, 0, 1, 0, |
| 0, 1, 0, 1, |
| -0.5f, 0, 0.5f, 0, |
| 0, -0.5f, 0, 0.5f |
| }); |
|
|
| |
| var graph = new FunctionalGraph(); |
| var inputs = graph.AddInputs(model1); |
| var modelOutput = Functional.Forward(model1, inputs)[0]; |
| var boxCoords = modelOutput[0, 0..4, ..].Transpose(0, 1); |
| var allScores = modelOutput[0, 4.., ..]; |
| var scores = Functional.ReduceMax(allScores, 0); |
| var classIDs = Functional.ArgMax(allScores, 0); |
| var boxCorners = Functional.MatMul(boxCoords, Functional.Constant(centersToCorners)); |
| var indices = Functional.NMS(boxCorners, scores, iouThreshold, scoreThreshold); |
| var coords = Functional.IndexSelect(boxCoords, 0, indices); |
| var labelIDs = Functional.IndexSelect(classIDs, 0, indices); |
|
|
| |
| worker = new Worker(graph.Compile(coords, labelIDs), backend); |
| } |
|
|
| void SetupInput() |
| { |
| video = gameObject.AddComponent<VideoPlayer>(); |
| video.renderMode = VideoRenderMode.APIOnly; |
| video.source = VideoSource.Url; |
| video.url = Path.Join(Application.streamingAssetsPath, videoFilename); |
| video.isLooping = true; |
| video.Play(); |
| } |
|
|
| private void Update() |
| { |
| ExecuteML(); |
|
|
| if (Input.GetKeyDown(KeyCode.Escape)) |
| { |
| Application.Quit(); |
| } |
| } |
|
|
| public void ExecuteML() |
| { |
| ClearAnnotations(); |
|
|
| if (video && video.texture) |
| { |
| float aspect = video.width * 1f / video.height; |
| Graphics.Blit(video.texture, targetRT, new Vector2(1f / aspect, 1), new Vector2(0, 0)); |
| displayImage.texture = targetRT; |
| } |
| else return; |
|
|
| using Tensor<float> inputTensor = new Tensor<float>(new TensorShape(1, 3, imageHeight, imageWidth)); |
| TextureConverter.ToTensor(targetRT, inputTensor, default); |
| worker.Schedule(inputTensor); |
|
|
| using var output = (worker.PeekOutput("output_0") as Tensor<float>).ReadbackAndClone(); |
| using var labelIDs = (worker.PeekOutput("output_1") as Tensor<int>).ReadbackAndClone(); |
|
|
| float displayWidth = displayImage.rectTransform.rect.width; |
| float displayHeight = displayImage.rectTransform.rect.height; |
|
|
| float scaleX = displayWidth / imageWidth; |
| float scaleY = displayHeight / imageHeight; |
|
|
| int boxesFound = output.shape[0]; |
| |
| for (int n = 0; n < Mathf.Min(boxesFound, 200); n++) |
| { |
| var box = new BoundingBox |
| { |
| centerX = output[n, 0] * scaleX - displayWidth / 2, |
| centerY = output[n, 1] * scaleY - displayHeight / 2, |
| width = output[n, 2] * scaleX, |
| height = output[n, 3] * scaleY, |
| label = labels[labelIDs[n]], |
| }; |
| DrawBox(box, n, displayHeight * 0.05f); |
| } |
| } |
|
|
| public void DrawBox(BoundingBox box, int id, float fontSize) |
| { |
| |
| GameObject panel; |
| if (id < boxPool.Count) |
| { |
| panel = boxPool[id]; |
| panel.SetActive(true); |
| } |
| else |
| { |
| panel = CreateNewBox(Color.yellow); |
| } |
| |
| panel.transform.localPosition = new Vector3(box.centerX, -box.centerY); |
|
|
| |
| RectTransform rt = panel.GetComponent<RectTransform>(); |
| rt.sizeDelta = new Vector2(box.width, box.height); |
|
|
| |
| var label = panel.GetComponentInChildren<Text>(); |
| label.text = box.label; |
| label.fontSize = (int)fontSize; |
| } |
|
|
| public GameObject CreateNewBox(Color color) |
| { |
| |
|
|
| var panel = new GameObject("ObjectBox"); |
| panel.AddComponent<CanvasRenderer>(); |
| Image img = panel.AddComponent<Image>(); |
| img.color = color; |
| img.sprite = borderSprite; |
| img.type = Image.Type.Sliced; |
| panel.transform.SetParent(displayLocation, false); |
|
|
| |
|
|
| var text = new GameObject("ObjectLabel"); |
| text.AddComponent<CanvasRenderer>(); |
| text.transform.SetParent(panel.transform, false); |
| Text txt = text.AddComponent<Text>(); |
| txt.font = font; |
| txt.color = color; |
| txt.fontSize = 40; |
| txt.horizontalOverflow = HorizontalWrapMode.Overflow; |
|
|
| RectTransform rt2 = text.GetComponent<RectTransform>(); |
| rt2.offsetMin = new Vector2(20, rt2.offsetMin.y); |
| rt2.offsetMax = new Vector2(0, rt2.offsetMax.y); |
| rt2.offsetMin = new Vector2(rt2.offsetMin.x, 0); |
| rt2.offsetMax = new Vector2(rt2.offsetMax.x, 30); |
| rt2.anchorMin = new Vector2(0, 0); |
| rt2.anchorMax = new Vector2(1, 1); |
|
|
| boxPool.Add(panel); |
| return panel; |
| } |
|
|
| public void ClearAnnotations() |
| { |
| foreach (var box in boxPool) |
| { |
| box.SetActive(false); |
| } |
| } |
|
|
| void OnDestroy() |
| { |
| centersToCorners?.Dispose(); |
| worker?.Dispose(); |
| } |
| } |
|
|