Skip to Content
Docs › Installation

Installation

To get started, save the following Shader component to your project.

"use client";

import React, { useRef, useEffect, useCallback, useState } from "react";

/**
 * A value that can be uploaded to a float uniform.
 *
 * - `number` is sent as a single float.
 * - `number[]` of length 2, 3 or 4 is sent as a `vec2` / `vec3` / `vec4`;
 *   any other length is sent as a float array.
 * - `{ value }` is a wrapper whose `value` is unwrapped and handled as above.
 */
type UniformValue = number | number[] | { value: number | number[] };

/** Props accepted by the {@link Shader} component. */
export type ShaderProps = {
  /**
   * Custom uniforms keyed by their GLSL name. `u_time` and `u_resolution`
   * are looked up and set by the component itself on every frame.
   */
  uniforms?: Record<string, UniformValue>;
  /** GLSL source of the fragment shader (required). */
  fragmentShader: string;
  /** GLSL source of the vertex shader; a pass-through default is used when omitted. */
  vertexShader?: string;
  /** Class name applied to the `<canvas>` element. */
  className?: string;
  /** Inline styles merged over the canvas' default full-size block styles. */
  style?: React.CSSProperties;
};

/**
 * Vertex shader used when the caller does not supply one.
 *
 * It writes `a_position` straight to `gl_Position` (z = 0, w = 1) and forwards
 * `a_texCoord` to the fragment stage through the `v_texCoord` varying.
 */
const defaultVertexShader = `
  attribute vec2 a_position;
  attribute vec2 a_texCoord;
  varying vec2 v_texCoord;
  
  void main() {
    gl_Position = vec4(a_position, 0.0, 1.0);
    v_texCoord = a_texCoord;
  }
`;

/**
 * Renders a fragment shader onto a canvas that fills its parent element and
 * redraws it on every animation frame.
 *
 * On each frame the component sets two built-in uniforms when the program
 * declares them — `u_time` (seconds since mount) and `u_resolution` (canvas
 * size in device pixels) — followed by every entry of `uniforms`.
 *
 * Implementation notes:
 * - Exactly one `requestAnimationFrame` loop is kept alive; any call to the
 *   frame renderer (e.g. after a resize) cancels the pending frame first.
 * - The GL program is rebuilt only when the shader sources change. Changing
 *   `uniforms` (even on every render) only updates the values read by the
 *   running loop.
 * - Program, buffer and shader objects are released when they are replaced
 *   and when the component unmounts.
 *
 * @param props.uniforms       Custom uniforms keyed by GLSL name.
 * @param props.fragmentShader GLSL fragment shader source.
 * @param props.vertexShader   Optional GLSL vertex shader source.
 * @param props.className      Class name for the canvas element.
 * @param props.style          Inline styles merged over the canvas defaults.
 */
export function Shader({
  uniforms,
  fragmentShader,
  vertexShader,
  className,
  style,
}: ShaderProps) {
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const glRef = useRef<WebGLRenderingContext | null>(null);
  const programRef = useRef<WebGLProgram | null>(null);
  const bufferRef = useRef<WebGLBuffer | null>(null);
  // Cache of uniform locations for the current program; filled lazily.
  const uniformLocationsRef = useRef<
    Record<string, WebGLUniformLocation | null>
  >({});
  // Latest `uniforms` prop, read by the render loop without re-creating it.
  const uniformsRef = useRef<ShaderProps["uniforms"]>(uniforms);
  const animationIdRef = useRef<number | undefined>(undefined);
  const startTimeRef = useRef<number>(Date.now());
  const resizeTimeoutRef = useRef<ReturnType<typeof setTimeout> | undefined>(
    undefined,
  );

  /** Compiles a shader of the given type; logs and returns null on failure. */
  const createShader = useCallback(
    (
      gl: WebGLRenderingContext,
      type: number,
      source: string,
    ): WebGLShader | null => {
      const shader = gl.createShader(type);
      if (!shader) return null;

      gl.shaderSource(shader, source);
      gl.compileShader(shader);

      if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
        console.error("Shader compilation error:", gl.getShaderInfoLog(shader));
        gl.deleteShader(shader);
        return null;
      }

      return shader;
    },
    [],
  );

  /** Links the two shaders into a program; logs and returns null on failure. */
  const createProgram = useCallback(
    (
      gl: WebGLRenderingContext,
      vertexShader: WebGLShader,
      fragmentShader: WebGLShader,
    ): WebGLProgram | null => {
      const program = gl.createProgram();
      if (!program) return null;

      gl.attachShader(program, vertexShader);
      gl.attachShader(program, fragmentShader);
      gl.linkProgram(program);

      if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
        console.error("Program linking error:", gl.getProgramInfoLog(program));
        gl.deleteProgram(program);
        return null;
      }

      return program;
    },
    [],
  );

  /** Uploads a single uniform, picking the GL call from the value's shape. */
  const setUniform = useCallback(
    (
      gl: WebGLRenderingContext,
      location: WebGLUniformLocation,
      value: UniformValue,
    ) => {
      // Unwrap the `{ value }` form; arrays are objects too, hence the `in` test.
      const actualValue =
        typeof value === "object" && "value" in value ? value.value : value;

      if (typeof actualValue === "number") {
        gl.uniform1f(location, actualValue);
      } else if (Array.isArray(actualValue)) {
        switch (actualValue.length) {
          case 2:
            gl.uniform2fv(location, actualValue);
            break;
          case 3:
            gl.uniform3fv(location, actualValue);
            break;
          case 4:
            gl.uniform4fv(location, actualValue);
            break;
          default:
            gl.uniform1fv(location, actualValue);
        }
      }
    },
    [],
  );

  /** Deletes the current program and vertex buffer and clears cached state. */
  const releaseResources = useCallback(() => {
    const gl = glRef.current;
    if (gl) {
      // Both calls accept null and ignore it.
      gl.deleteProgram(programRef.current);
      gl.deleteBuffer(bufferRef.current);
    }
    programRef.current = null;
    bufferRef.current = null;
    uniformLocationsRef.current = {};
  }, []);

  /** Acquires the GL context and (re)builds the program and quad geometry. */
  const initWebGL = useCallback(() => {
    const canvas = canvasRef.current;
    if (!canvas) return;

    const contextAttributes: WebGLContextAttributes = {
      alpha: false,
      antialias: true,
      depth: false,
      stencil: false,
      preserveDrawingBuffer: false,
      powerPreference: "high-performance",
      desynchronized: false,
    };

    const gl =
      canvas.getContext("webgl", contextAttributes) ||
      canvas.getContext("experimental-webgl", contextAttributes);
    if (!gl || !(gl instanceof WebGLRenderingContext)) {
      console.error("WebGL not supported");
      return;
    }

    glRef.current = gl;

    // Free whatever a previous initialisation created before replacing it.
    releaseResources();

    const vertexShaderObj = createShader(
      gl,
      gl.VERTEX_SHADER,
      vertexShader || defaultVertexShader,
    );
    const fragmentShaderObj = createShader(
      gl,
      gl.FRAGMENT_SHADER,
      fragmentShader,
    );

    if (!vertexShaderObj || !fragmentShaderObj) {
      // One of the two may have compiled; deleteShader(null) is a no-op.
      gl.deleteShader(vertexShaderObj);
      gl.deleteShader(fragmentShaderObj);
      console.error("Failed to create shaders");
      return;
    }

    const program = createProgram(gl, vertexShaderObj, fragmentShaderObj);
    // The shader objects are no longer needed once linking has been attempted;
    // attached shaders are only flagged and freed together with the program.
    gl.deleteShader(vertexShaderObj);
    gl.deleteShader(fragmentShaderObj);

    if (!program) {
      console.error("Failed to create program");
      return;
    }

    programRef.current = program;

    // Built-in uniforms are resolved eagerly; custom ones are resolved lazily
    // by the render loop so that changing `uniforms` never forces a rebuild.
    uniformLocationsRef.current = {
      u_time: gl.getUniformLocation(program, "u_time"),
      u_resolution: gl.getUniformLocation(program, "u_resolution"),
    };

    // Two triangles covering clip space, interleaved as (x, y, u, v).
    const positions = new Float32Array([
      -1, -1, 0, 0, 1, -1, 1, 0, -1, 1, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, 1, 1, 1,
      1,
    ]);

    const buffer = gl.createBuffer();
    bufferRef.current = buffer;
    gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
    gl.bufferData(gl.ARRAY_BUFFER, positions, gl.STATIC_DRAW);

    const positionLocation = gl.getAttribLocation(program, "a_position");
    const texCoordLocation = gl.getAttribLocation(program, "a_texCoord");

    // getAttribLocation returns -1 for attributes the linked program does not
    // use (or a custom vertex shader does not declare); enabling index -1
    // would raise INVALID_VALUE.
    if (positionLocation >= 0) {
      gl.enableVertexAttribArray(positionLocation);
      gl.vertexAttribPointer(positionLocation, 2, gl.FLOAT, false, 16, 0);
    }
    if (texCoordLocation >= 0) {
      gl.enableVertexAttribArray(texCoordLocation);
      gl.vertexAttribPointer(texCoordLocation, 2, gl.FLOAT, false, 16, 8);
    }
  }, [
    fragmentShader,
    vertexShader,
    createShader,
    createProgram,
    releaseResources,
  ]);

  /** Draws one frame and schedules the next; safe to call at any time. */
  const renderFrame = useCallback(() => {
    // Ensure a single loop: when invoked outside the rAF callback (e.g. from
    // a resize) a frame is already pending and must not be duplicated.
    if (animationIdRef.current !== undefined) {
      cancelAnimationFrame(animationIdRef.current);
      animationIdRef.current = undefined;
    }

    const canvas = canvasRef.current;
    const gl = glRef.current;
    const program = programRef.current;

    if (!canvas || !gl || !program) return;

    // biome-ignore lint/correctness/useHookAtTopLevel: gl.useProgram is a WebGL API call, not a React hook
    gl.useProgram(program);

    gl.viewport(0, 0, canvas.width, canvas.height);

    gl.clearColor(0, 0, 0, 1);
    gl.clear(gl.COLOR_BUFFER_BIT);

    const locations = uniformLocationsRef.current;

    const shaderTime = Math.max(0, (Date.now() - startTimeRef.current) / 1000);
    const timeLocation = locations.u_time;
    if (timeLocation) {
      gl.uniform1f(timeLocation, shaderTime);
    }

    const resolutionLocation = locations.u_resolution;
    if (resolutionLocation) {
      gl.uniform2f(resolutionLocation, canvas.width, canvas.height);
    }

    const currentUniforms = uniformsRef.current;
    if (currentUniforms) {
      for (const [name, value] of Object.entries(currentUniforms)) {
        // Resolve and cache the location the first time a name is seen.
        if (!(name in locations)) {
          locations[name] = gl.getUniformLocation(program, name);
        }
        const location = locations[name];
        if (location) {
          setUniform(gl, location, value);
        }
      }
    }

    gl.drawArrays(gl.TRIANGLES, 0, 6);

    animationIdRef.current = requestAnimationFrame(renderFrame);
  }, [setUniform]);

  // Handle canvas resize with debouncing
  const handleResize = useCallback(() => {
    if (!canvasRef.current) return;

    if (resizeTimeoutRef.current !== undefined) {
      clearTimeout(resizeTimeoutRef.current);
    }

    resizeTimeoutRef.current = setTimeout(() => {
      resizeTimeoutRef.current = undefined;

      const canvas = canvasRef.current;
      const parent = canvas?.parentElement;
      if (!canvas || !parent) return;

      const parentRect = parent.getBoundingClientRect();
      const devicePixelRatio = window.devicePixelRatio || 1;

      // Backing store in device pixels, CSS box in layout pixels.
      canvas.width = parentRect.width * devicePixelRatio;
      canvas.height = parentRect.height * devicePixelRatio;

      canvas.style.width = `${parentRect.width}px`;
      canvas.style.height = `${parentRect.height}px`;
      canvas.style.display = "block";

      // Reading a layout property; presumably kept to force a synchronous reflow.
      canvas.offsetHeight;

      // Read the context now rather than when the resize was scheduled, so a
      // resize queued before initialisation still sees the live context.
      const gl = glRef.current;
      if (gl) {
        gl.viewport(0, 0, canvas.width, canvas.height);

        renderFrame();
      }
    }, 16);
  }, [renderFrame]);

  // Keep the value read by the render loop in sync with the latest prop.
  useEffect(() => {
    uniformsRef.current = uniforms;
  }, [uniforms]);

  useEffect(() => {
    initWebGL();
    handleResize();
    renderFrame();

    const resizeObserver = new ResizeObserver(handleResize);
    if (canvasRef.current?.parentElement) {
      resizeObserver.observe(canvasRef.current.parentElement);
    }

    const handleWindowResize = () => {
      handleResize();
    };
    window.addEventListener("resize", handleWindowResize);

    return () => {
      if (animationIdRef.current !== undefined) {
        cancelAnimationFrame(animationIdRef.current);
        animationIdRef.current = undefined;
      }
      if (resizeTimeoutRef.current !== undefined) {
        clearTimeout(resizeTimeoutRef.current);
        resizeTimeoutRef.current = undefined;
      }
      resizeObserver.disconnect();
      window.removeEventListener("resize", handleWindowResize);
      releaseResources();
    };
  }, [initWebGL, renderFrame, handleResize, releaseResources]);

  return (
    <div style={{ position: "relative", width: "100%", height: "100%" }}>
      <canvas
        ref={canvasRef}
        className={className}
        style={{
          width: "100%",
          height: "100%",
          display: "block",
          ...style,
        }}
      />
    </div>
  );
}
Last updated on