// Modified from https://github.com/tensorflow/tfjs-models/blob/c5fcabec4cc0335bdfabeb14dc2adfd6ef0cde8c/shared/calculators/render_util.ts#L191
// to fix blurring bug on Safari. Changes are indicated by "MODIFIED" comments below.
// @ts-nocheck
/* eslint-disable valid-jsdoc */

import * as tf from "@tensorflow/tfjs-core";
import {
  Color,
  PixelInput,
  Segmentation,
} from "@tensorflow-models/body-segmentation/dist/shared/calculators/interfaces/common_interfaces";

// MODIFIED - ADDED
import * as StackBlur from "stackblur-canvas";

/**
 * This render_util implementation is based on the body-pix output_rendering_util
 * code found here:
 * https://github.com/tensorflow/tfjs-models/blob/master/body-pix/src/output_rendering_util.ts
 * It has been adapted to work with the generic segmentation interface.
 */

/** Every kind of image source the rendering helpers in this file accept. */
type ImageType = CanvasImageSource | OffscreenCanvas | PixelInput;

/** A canvas that can provide a 2D context: either a DOM canvas or an OffscreenCanvas. */
type Canvas = HTMLCanvasElement | OffscreenCanvas;

/** Identifiers of the reusable scratch canvases stored in `offScreenCanvases`. */
const CANVAS_NAMES = {
  blurred: "blurred",
  blurredMask: "blurred-mask",
  mask: "mask",
  lowresPartMask: "lowres-part-mask",
  drawImage: "draw-image",
};

/** Scratch canvases, created on first use and reused afterwards, keyed by name. */
const offScreenCanvases: Record<string, Canvas> = {};

/**
 * Returns true when the user agent looks like Safari: "safari" appears in it
 * (case-insensitively) before any "chrome" or "android" token.
 */
function isSafari() {
  const safariBeforeChromeOrAndroid = /^((?!chrome|android).)*safari/i;
  const { userAgent } = navigator;
  return safariBeforeChromeOrAndroid.test(userAgent);
}

/**
 * Returns the [height, width] of an image or canvas.
 *
 * The laid-out size (offsetHeight/offsetWidth) is preferred when the element
 * has one and it is non-zero in both dimensions; otherwise the element's own
 * height/width are used.
 *
 * @throws Error when neither a layout size nor height/width are available.
 */
function getSizeFromImageLikeElement(
  input: HTMLImageElement | HTMLCanvasElement | OffscreenCanvas
): [number, number] {
  if (
    "offsetHeight" in input &&
    input.offsetHeight !== 0 &&
    "offsetWidth" in input &&
    input.offsetWidth !== 0
  ) {
    return [input.offsetHeight, input.offsetWidth];
  }
  if (input.height == null || input.width == null) {
    throw new Error(
      `HTMLImageElement must have height and width attributes set.`
    );
  }
  return [input.height, input.width];
}

/**
 * Returns the [height, width] of a video element.
 *
 * Explicit height and width attributes take priority over the stream's
 * videoHeight/videoWidth. The attributes are tested with hasAttribute because
 * the .height/.width properties read as 0 when the attributes are unset.
 */
function getSizeFromVideoElement(input: HTMLVideoElement): [number, number] {
  const hasExplicitSize =
    input.hasAttribute("height") && input.hasAttribute("width");
  return hasExplicitSize
    ? [input.height, input.width]
    : [input.videoHeight, input.videoWidth];
}

/**
 * Returns the [height, width] of any supported image source.
 *
 * Each DOM type is checked only when its constructor exists globally, so the
 * lookup also works where some of those globals are missing. For tensors, the
 * first two dimensions of the shape are taken as height and width.
 *
 * @throws Error if `input` is not one of the recognised types.
 */
function getInputSize(input: ImageType): [number, number] {
  if (
    (typeof HTMLCanvasElement !== "undefined" &&
      input instanceof HTMLCanvasElement) ||
    (typeof OffscreenCanvas !== "undefined" &&
      input instanceof OffscreenCanvas) ||
    (typeof HTMLImageElement !== "undefined" &&
      input instanceof HTMLImageElement)
  ) {
    return getSizeFromImageLikeElement(input);
  }

  if (typeof ImageData !== "undefined" && input instanceof ImageData) {
    return [input.height, input.width];
  }

  if (
    typeof HTMLVideoElement !== "undefined" &&
    input instanceof HTMLVideoElement
  ) {
    return getSizeFromVideoElement(input);
  }

  if (input instanceof tf.Tensor) {
    return [input.shape[0], input.shape[1]];
  }

  throw new Error(`error: Unknown input type: ${input}.`);
}

/**
 * Creates a new scratch canvas.
 *
 * A DOM canvas is used when `document` exists; otherwise a zero-sized
 * OffscreenCanvas is used.
 *
 * @throws Error when neither is available.
 */
function createOffScreenCanvas(): Canvas {
  const hasDocument = typeof document !== "undefined";
  if (hasDocument) return document.createElement("canvas");

  const hasOffscreenCanvas = typeof OffscreenCanvas !== "undefined";
  if (hasOffscreenCanvas) return new OffscreenCanvas(0, 0);

  throw new Error("Cannot create a canvas in this context");
}

/**
 * Returns the cached scratch canvas registered under `id`, creating and
 * caching it on first use.
 */
function ensureOffscreenCanvasCreated(id: string): Canvas {
  let cached = offScreenCanvases[id];
  if (!cached) {
    cached = createOffScreenCanvas();
    offScreenCanvases[id] = cached;
  }
  return cached;
}

/**
 * Resizes `canvas` to match `image`, then writes the raw pixels of `image`
 * into it at the origin.
 */
function renderImageDataToCanvas(image: ImageData, canvas: Canvas) {
  const { width, height } = image;
  canvas.width = width;
  canvas.height = height;
  canvas.getContext("2d").putImageData(image, 0, 0);
}

/**
 * Writes `image` into the cached scratch canvas named `canvasName` and returns
 * that canvas. The canvas is shared, so later calls with the same name
 * overwrite it.
 */
function renderImageDataToOffScreenCanvas(
  image: ImageData,
  canvasName: string
): Canvas {
  const target = ensureOffscreenCanvasCreated(canvasName);
  renderImageDataToCanvas(image, target);
  return target;
}

/**
 * Draws any supported image source onto a 2D context at (dx, dy).
 *
 * Tensors are first converted to ImageData. ImageData is then staged on the
 * shared "draw-image" scratch canvas, because ctx.drawImage does not accept
 * ImageData. When both `dw` and `dh` are given the image is scaled to that
 * size; otherwise it is drawn at its own size.
 */
async function drawImage(
  ctx: CanvasRenderingContext2D | OffscreenCanvasRenderingContext2D,
  image: ImageType,
  dx: number,
  dy: number,
  dw?: number,
  dh?: number
) {
  let source = image;

  if (source instanceof tf.Tensor) {
    const rgba = await tf.browser.toPixels(source);
    const [rows, cols] = getInputSize(source);
    source = new ImageData(rgba, cols, rows);
  }

  if (source instanceof ImageData) {
    source = renderImageDataToOffScreenCanvas(source, CANVAS_NAMES.drawImage);
  }

  const hasExplicitSize = dw != null && dh != null;
  if (hasExplicitSize) {
    ctx.drawImage(source, dx, dy, dw, dh);
  } else {
    ctx.drawImage(source, dx, dy);
  }
}

/**
 * Resizes `canvas` to the size getInputSize reports for `image`, then draws
 * `image` scaled to fill it.
 */
async function renderImageToCanvas(image: ImageType, canvas: Canvas) {
  const [rows, cols] = getInputSize(image);
  canvas.width = cols;
  canvas.height = rows;
  await drawImage(canvas.getContext("2d"), image, 0, 0, cols, rows);
}

/**
 * Mirrors all subsequent drawing on `canvas` left-to-right by updating its 2D
 * context transform. After the call, a point drawn at x lands at
 * canvas.width - x. Callers should wrap this in save()/restore().
 */
function flipCanvasHorizontal(canvas: Canvas) {
  const context = canvas.getContext("2d");
  context.scale(-1, 1); // negate the x axis
  context.translate(-canvas.width, 0); // shift the mirrored origin back into view
}

/**
 * Sets the context's globalCompositeOperation to `compositeOperation`, then
 * draws `image` at the origin with no explicit destination size.
 *
 * The composite operation is left set afterwards; callers are expected to
 * restore the context state themselves.
 */
async function drawWithCompositing(
  ctx: CanvasRenderingContext2D | OffscreenCanvasRenderingContext2D,
  image: ImageType,
  compositeOperation: string
) {
  // `compositeOperation` is a plain string, hence the cast.
  // TODO: narrow to `GlobalCompositeOperation` once on TypeScript >= 4.6.
  // tslint:disable-next-line: no-any
  ctx.globalCompositeOperation = compositeOperation as any;
  const originX = 0;
  const originY = 0;
  await drawImage(ctx, image, originX, originY);
}

// Originally based on bGlur in https://codepen.io/zhaojun/pen/zZmRQe; the blur
// itself is now delegated to StackBlur (see the MODIFIED notes below).
/**
 * Draws `image` onto `canvas` and blurs it on the CPU with StackBlur.
 * drawAndBlurImageOnCanvas uses this on Safari instead of the `ctx.filter` blur.
 *
 * `canvas` must already be sized by the caller (drawAndBlurImageOnCanvas sizes
 * it from getInputSize). The image is scaled to fill it.
 *
 * @param canvas Destination canvas; its current width/height define the output.
 * @param image Source image to draw and blur.
 * @param blur Blur amount in pixels; StackBlur runs with a radius of `blur * 2`.
 */
async function cpuBlur(canvas: Canvas, image: ImageType, blur: number) {
  // MODIFIED - Added the lines below to change this to use the StackBlur library for blurring the image instead.
  // The existing algorithm was buggy, and when fixed, very slow. The existing algorithm works by
  // repeatedly adding the image with varying opacity to the canvas, in a `2*blur` x `2*blur` grid, resulting
  // in a lot of time spent in drawImage calls.
  const ctx = canvas.getContext("2d");

  // MODIFIED - Draw and blur using the canvas dimensions rather than
  // `image.width`/`image.height`. Those are undefined for tensors, which made
  // StackBlur throw. For video/image elements they can differ from the size
  // getInputSize reports: a video's .width is 0 without a width attribute.
  // Scaling to the canvas keeps this path aligned with the ctx.filter path,
  // which also draws at the getInputSize dimensions.
  const { width, height } = canvas;
  await drawImage(ctx, image, 0, 0, width, height);
  StackBlur.canvasRGBA(canvas, 0, 0, width, height, blur * 2);

  // MODIFIED - Existing algorithm commented out below
  // const ctx = canvas.getContext("2d");
  //
  // // MODIFIED - This is part 1 of 2 of the bug fix for the existing algorithm
  // await drawImage(ctx, image, 0, 0);
  //
  // let sum = 0;
  // const delta = 5;
  // const alphaLeft = 1 / (2 * Math.PI * delta * delta);
  // const step = blur < 3 ? 1 : 2;
  // for (let y = -blur; y <= blur; y += step) {
  //   for (let x = -blur; x <= blur; x += step) {
  //     const weight =
  //       alphaLeft * Math.exp(-(x * x + y * y) / (2 * delta * delta));
  //     sum += weight;
  //   }
  // }
  // for (let y = -blur; y <= blur; y += step) {
  //   for (let x = -blur; x <= blur; x += step) {
  //     ctx.globalAlpha =
  //       ((alphaLeft * Math.exp(-(x * x + y * y) / (2 * delta * delta))) / sum) *
  //       blur;
  //     // MODIFIED - This is part 2 of 2 of the bug fix for the existing algorithm
  //     // await drawImage(ctx, image, x, y);
  //     await drawImage(ctx, canvas, x, y);
  //   }
  // }
  // ctx.globalAlpha = 1;
}

/**
 * Resizes `canvas` to `image`'s size (per getInputSize), clears it, and draws a
 * blurred copy of `image` into it.
 *
 * Safari uses the StackBlur-based CPU fallback (cpuBlur). Other browsers use the
 * 2D context's `filter` property. Context state is saved before and restored
 * after drawing, so the filter does not leak into later draws.
 */
async function drawAndBlurImageOnCanvas(
  image: ImageType,
  blurAmount: number,
  canvas: Canvas
) {
  const [rows, cols] = getInputSize(image);
  const context = canvas.getContext("2d");
  canvas.width = cols;
  canvas.height = rows;
  context.clearRect(0, 0, cols, rows);
  context.save();

  if (!isSafari()) {
    // tslint:disable:no-any
    (context as any).filter = `blur(${blurAmount}px)`;
    await drawImage(context, image, 0, 0, cols, rows);
  } else {
    await cpuBlur(canvas, image, blurAmount);
  }

  context.restore();
}

/**
 * Renders `image` into the cached scratch canvas named `offscreenCanvasName`
 * and returns that canvas. With a blur amount of 0 the image is copied as-is;
 * otherwise it is blurred.
 */
async function drawAndBlurImageOnOffScreenCanvas(
  image: ImageType,
  blurAmount: number,
  offscreenCanvasName: string
): Promise<Canvas> {
  const target = ensureOffscreenCanvasCreated(offscreenCanvasName);
  const rendering =
    blurAmount === 0
      ? renderImageToCanvas(image, target)
      : drawAndBlurImageOnCanvas(image, blurAmount, target);
  await rendering;
  return target;
}

/**
 * Paints a square ring of `color` around pixel (row, column) in an RGBA byte
 * buffer laid out row-major with 4 bytes per pixel and `width` pixels per row.
 *
 * Every pixel within `radius` of the centre (in both row and column) is
 * painted, except the centre itself. Pixels that would fall outside the image
 * are skipped rather than wrapping onto an adjacent row.
 *
 * @param bytes RGBA pixel buffer to paint into (modified in place).
 * @param row Row of the centre pixel.
 * @param column Column of the centre pixel.
 * @param width Image width in pixels; the height is derived from `bytes.length`.
 * @param radius Ring radius in pixels.
 * @param color Paint color; defaults to opaque cyan {r:0, g:255, b:255, a:255}.
 */
function drawStroke(
  bytes: Uint8ClampedArray,
  row: number,
  column: number,
  width: number,
  radius: number,
  color: Color = {
    r: 0,
    g: 255,
    b: 255,
    a: 255,
  }
) {
  const height = Math.floor(bytes.length / (4 * width));
  for (let i = -radius; i <= radius; i++) {
    const r = row + i;
    if (r < 0 || r >= height) {
      continue;
    }
    for (let j = -radius; j <= radius; j++) {
      const c = column + j;
      // MODIFIED - Skip only the centre pixel. The upstream `i !== 0 && j !== 0`
      // test also skipped the centre row and column, so only the diagonal
      // neighbours were painted. Out-of-range columns are skipped instead of
      // wrapping into the neighbouring row.
      if ((i === 0 && j === 0) || c < 0 || c >= width) {
        continue;
      }
      const n = r * width + c;
      bytes[4 * n + 0] = color.r;
      bytes[4 * n + 1] = color.g;
      bytes[4 * n + 2] = color.b;
      bytes[4 * n + 3] = color.a;
    }
  }
}

/**
 * Returns true if any pixel within `radius` of (row, column), excluding the
 * pixel itself, is background.
 *
 * A pixel is background when its red-channel id is not marked in
 * `isForegroundId`, or its alpha is below `alphaCutoff`. Neighbours outside the
 * image count as background, so pixels on the image border are always
 * boundaries.
 *
 * @param data RGBA mask buffer, row-major, 4 bytes per pixel, `width` per row.
 * @param row Row of the pixel under test.
 * @param column Column of the pixel under test.
 * @param width Mask width in pixels; the height is derived from `data.length`.
 * @param isForegroundId Lookup from red-channel id to "is foreground".
 * @param alphaCutoff Minimum alpha for a pixel to count as foreground.
 * @param radius Neighbourhood radius; defaults to 1 (the 8 adjacent pixels).
 */
function isSegmentationBoundary(
  data: Uint8ClampedArray,
  row: number,
  column: number,
  width: number,
  isForegroundId: boolean[],
  alphaCutoff: number,
  radius = 1
): boolean {
  const height = Math.floor(data.length / (4 * width));
  for (let i = -radius; i <= radius; i++) {
    for (let j = -radius; j <= radius; j++) {
      // MODIFIED - Skip only the centre pixel. The upstream `i !== 0 && j !== 0`
      // test ignored the horizontal and vertical neighbours entirely.
      if (i === 0 && j === 0) {
        continue;
      }
      const r = row + i;
      const c = column + j;
      // Out-of-image neighbours count as background. Previously, rows past
      // either end read `undefined` (background), while columns wrapped onto
      // the neighbouring row.
      if (r < 0 || r >= height || c < 0 || c >= width) {
        return true;
      }
      const n = r * width + c;
      if (!isForegroundId[data[4 * n]] || data[4 * n + 3] < alphaCutoff) {
        return true;
      }
    }
  }
  return false;
}

/**
 * Given a segmentation or array of segmentations, generates an image with the
 * foreground or background color at each pixel, chosen by the corresponding
 * binary segmentation value at that pixel. In other words, pixels where there
 * is a person are colored with the foreground color and all other pixels with
 * the background color. A pixel is foreground if it is foreground in ANY of the
 * given segmentations.
 *
 * @param segmentation Single segmentation or array of segmentations. The output
 * takes the first mask's width and height; all masks are assumed to share it.
 *
 * @param foreground Defaults to {r:0, g:0, b:0, a:0}. The color (r,g,b,a) used
 * for pixels that belong to people.
 *
 * @param background Defaults to {r:0, g:0, b:0, a:255}. The color (r,g,b,a)
 * used for pixels that don't belong to people.
 *
 * @param drawContour Defaults to false. Whether to draw a contour (with
 * drawStroke's default color) around each person's segmentation mask or body
 * part mask. Only pixels not on the image border can start a contour.
 *
 * @param foregroundThreshold Defaults to 0.5. The minimum probability for a
 * pixel to be colored as foreground rather than background. The alpha channel
 * integer values are taken as the probabilities (see the `Segmentation` type's
 * documentation).
 *
 * @param foregroundMaskValues Defaults to all mask values (0-255). The red
 * channel integer values that represent foreground (see the `Segmentation`
 * type's documentation).
 *
 * @returns An ImageData with the width and height of the first input mask,
 * colored per pixel as described above, or null if `segmentation` is an empty
 * array.
 */
async function toBinaryMask(
  segmentation: Segmentation | Segmentation[],
  foreground: Color = {
    r: 0,
    g: 0,
    b: 0,
    a: 0,
  },
  background: Color = {
    r: 0,
    g: 0,
    b: 0,
    a: 255,
  },
  drawContour = false,
  foregroundThreshold = 0.5,
  foregroundMaskValues = Array.from(Array(256).keys())
) {
  const segmentations = Array.isArray(segmentation)
    ? segmentation
    : [segmentation];

  if (segmentations.length === 0) {
    return null;
  }

  const masks = await Promise.all(
    segmentations.map((s) => s.mask.toImageData())
  );

  const { width, height } = masks[0];
  const pixelCount = width * height;
  const bytes = new Uint8ClampedArray(pixelCount * 4);
  const alphaCutoff = Math.round(255 * foregroundThreshold);
  const isForegroundId: boolean[] = new Array(256).fill(false);
  foregroundMaskValues.forEach((id) => (isForegroundId[id] = true));

  // A mask pixel is foreground when its red-channel id is selected and its
  // alpha (probability) reaches the cutoff.
  const isForegroundIn = (mask: ImageData, n: number): boolean =>
    isForegroundId[mask.data[4 * n]] && mask.data[4 * n + 3] >= alphaCutoff;

  // Pass 1: fill every pixel with the foreground or background color.
  for (let n = 0; n < pixelCount; n++) {
    let isForeground = false;
    for (const mask of masks) {
      if (isForegroundIn(mask, n)) {
        isForeground = true;
        break;
      }
    }
    const color = isForeground ? foreground : background;
    bytes[4 * n] = color.r;
    bytes[4 * n + 1] = color.g;
    bytes[4 * n + 2] = color.b;
    bytes[4 * n + 3] = color.a;
  }

  // Pass 2: draw contours over the finished fill.
  // MODIFIED - Upstream drew strokes in the same loop as the fill. Strokes
  // painted onto neighbours that had not been visited yet were then
  // overwritten when the loop reached them, leaving a partial contour.
  if (drawContour) {
    for (let i = 1; i < height - 1; i++) {
      for (let j = 1; j < width - 1; j++) {
        const n = i * width + j;
        for (const mask of masks) {
          if (
            isForegroundIn(mask, n) &&
            isSegmentationBoundary(
              mask.data,
              i,
              j,
              width,
              isForegroundId,
              alphaCutoff
            )
          ) {
            drawStroke(bytes, i, j, width, 1);
          }
        }
      }
    }
  }

  return new ImageData(bytes, width, height);
}

/**
 * Builds the person mask used for compositing. The mask is opaque black wherever
 * any segmentation is foreground at `foregroundThreshold`, and fully transparent
 * elsewhere.
 *
 * When `edgeBlurAmount` is non-zero the mask is blurred to soften the
 * person/background edge. The returned canvas is one of the module's shared
 * scratch canvases ("mask" or "blurred-mask"), so the next call overwrites it.
 */
async function createPersonMask(
  segmentation: Segmentation | Segmentation[],
  foregroundThreshold: number,
  edgeBlurAmount: number
): Promise<Canvas> {
  const opaque: Color = { r: 0, g: 0, b: 0, a: 255 };
  const transparent: Color = { r: 0, g: 0, b: 0, a: 0 };
  const maskImage = await toBinaryMask(
    segmentation,
    opaque,
    transparent,
    false,
    foregroundThreshold
  );

  const maskCanvas = renderImageDataToOffScreenCanvas(
    maskImage,
    CANVAS_NAMES.mask
  );

  if (edgeBlurAmount !== 0) {
    return drawAndBlurImageOnOffScreenCanvas(
      maskCanvas,
      edgeBlurAmount,
      CANVAS_NAMES.blurredMask
    );
  }
  return maskCanvas;
}

/**
 * Given a segmentation or array of segmentations, and an image, draws the image
 * onto the canvas with its background blurred.
 *
 * The blurred frame is rendered first and sets the output canvas size. The
 * sharp image is then drawn and cut down to the person mask ("destination-in").
 * Finally, the blurred frame is drawn behind it ("destination-over").
 *
 * @param canvas The canvas to draw the background-blurred image onto. It is
 * resized to the blurred image's dimensions.
 *
 * @param image The image to blur the background of and draw.
 *
 * @param segmentation Single segmentation or array of segmentations. An empty
 * array means no person was found, and the whole frame is drawn blurred.
 *
 * @param foregroundThreshold Defaults to 0.5. The minimum probability for a
 * pixel to be treated as foreground rather than background. The alpha channel
 * integer values are taken as the probabilities (see the `Segmentation` type's
 * documentation).
 *
 * @param backgroundBlurAmount How many pixels in the background blend into each
 * other. Defaults to 3. Should be an integer between 1 and 20.
 *
 * @param edgeBlurAmount How many pixels to blur the edge between the person and
 * the background by. Defaults to 3. Should be an integer between 0 and 20.
 *
 * @param flipHorizontal Whether the output should be mirrored horizontally.
 * Defaults to false.
 */
export async function drawBokehEffect(
  canvas: Canvas,
  image: ImageType,
  segmentation: Segmentation | Segmentation[],
  foregroundThreshold = 0.5,
  backgroundBlurAmount = 3,
  edgeBlurAmount = 3,
  flipHorizontal = false
) {
  const blurredImage = await drawAndBlurImageOnOffScreenCanvas(
    image,
    backgroundBlurAmount,
    CANVAS_NAMES.blurred
  );
  canvas.width = blurredImage.width;
  canvas.height = blurredImage.height;

  const ctx = canvas.getContext("2d");

  if (Array.isArray(segmentation) && segmentation.length === 0) {
    // MODIFIED - Honour flipHorizontal here as well. Upstream never mirrored
    // this path, so with flipHorizontal=true the output was mirrored only
    // while a segmentation was present.
    ctx.save();
    if (flipHorizontal) {
      flipCanvasHorizontal(canvas);
    }
    ctx.drawImage(blurredImage, 0, 0);
    ctx.restore();
    return;
  }

  const personMask = await createPersonMask(
    segmentation,
    foregroundThreshold,
    edgeBlurAmount
  );

  ctx.save();
  if (flipHorizontal) {
    flipCanvasHorizontal(canvas);
  }
  // draw the original image on the final canvas
  const [height, width] = getInputSize(image);
  await drawImage(ctx, image, 0, 0, width, height);

  // "destination-in" - "The existing canvas content is kept where both the
  // new shape and existing canvas content overlap. Everything else is made
  // transparent."
  // crop what's not the person using the mask from the original image
  await drawWithCompositing(ctx, personMask, "destination-in");
  // "destination-over" - "New shapes are drawn behind the existing canvas
  // content."
  // draw the blurred background behind the cropped person, so it shows through
  // wherever the person mask made the canvas transparent.
  await drawWithCompositing(ctx, blurredImage, "destination-over");
  ctx.restore();
}
