// Installation
// To get started, save the Shader component below to your project.
"use client";
import React, { useRef, useEffect, useCallback, useState } from "react";
/**
 * A value for a user-supplied shader uniform.
 *
 * - `number` is uploaded with `uniform1f`.
 * - `number[]` of length 2/3/4 is uploaded as a vec2/vec3/vec4; any other
 *   non-empty length is uploaded as a float array via `uniform1fv`.
 * - `{ value }` wraps either of the above (e.g. a mutable holder object).
 *
 * Exported so callers can type the `uniforms` prop of `ShaderProps`.
 */
export type UniformValue = number | number[] | { value: number | number[] };
/** Props accepted by the `Shader` component. */
export type ShaderProps = {
  /** Fragment shader source (GLSL ES 1.00, since a WebGL 1 context is used). Required. */
  fragmentShader: string;
  /**
   * Optional vertex shader source. When omitted (or empty), the component's
   * built-in pass-through vertex shader is used instead.
   */
  vertexShader?: string;
  /** Extra uniforms to upload each frame, keyed by their GLSL uniform name. */
  uniforms?: Record<string, UniformValue>;
  /** CSS class forwarded to the underlying `<canvas>` element. */
  className?: string;
  /** Inline styles spread over the canvas's default styles. */
  style?: React.CSSProperties;
};
// Fallback vertex shader, used when no `vertexShader` prop is supplied.
// It passes `a_position` straight through to clip space (z = 0.0, w = 1.0)
// and forwards `a_texCoord` to the fragment stage as the `v_texCoord` varying.
const defaultVertexShader = `
attribute vec2 a_position;
attribute vec2 a_texCoord;
varying vec2 v_texCoord;
void main() {
gl_Position = vec4(a_position, 0.0, 1.0);
v_texCoord = a_texCoord;
}
`;
/** Full-canvas quad drawn as two triangles, interleaved as [x, y, u, v] per vertex. */
const QUAD_VERTICES = new Float32Array([
  -1, -1, 0, 0,
  1, -1, 1, 0,
  -1, 1, 0, 1,
  -1, 1, 0, 1,
  1, -1, 1, 0,
  1, 1, 1, 1,
]);

/** Attributes requested when the WebGL context is first created for a canvas. */
const CONTEXT_ATTRIBUTES: WebGLContextAttributes = {
  alpha: false,
  antialias: true,
  depth: false,
  stencil: false,
  preserveDrawingBuffer: false,
  powerPreference: "high-performance",
  desynchronized: false,
};

/** Compile one shader stage; logs the info log and returns null on failure. */
function compileShader(
  gl: WebGLRenderingContext,
  type: number,
  source: string,
): WebGLShader | null {
  const shader = gl.createShader(type);
  if (!shader) return null;
  gl.shaderSource(shader, source);
  gl.compileShader(shader);
  if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
    console.error("Shader compilation error:", gl.getShaderInfoLog(shader));
    gl.deleteShader(shader);
    return null;
  }
  return shader;
}

/** Link a program from two compiled stages; logs and returns null on failure. */
function linkProgram(
  gl: WebGLRenderingContext,
  vertexShader: WebGLShader,
  fragmentShader: WebGLShader,
): WebGLProgram | null {
  const program = gl.createProgram();
  if (!program) return null;
  gl.attachShader(program, vertexShader);
  gl.attachShader(program, fragmentShader);
  gl.linkProgram(program);
  if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
    console.error("Program linking error:", gl.getProgramInfoLog(program));
    gl.deleteProgram(program);
    return null;
  }
  return program;
}

/**
 * Upload one user uniform. Numbers go through `uniform1f`; arrays of length
 * 2/3/4 become vec2/vec3/vec4; other non-empty arrays use `uniform1fv`.
 * Empty arrays are skipped (uploading them would raise a GL error).
 */
function applyUniform(
  gl: WebGLRenderingContext,
  location: WebGLUniformLocation,
  value: UniformValue,
): void {
  const actual =
    typeof value === "object" && "value" in value ? value.value : value;
  if (typeof actual === "number") {
    gl.uniform1f(location, actual);
    return;
  }
  if (!Array.isArray(actual) || actual.length === 0) return;
  switch (actual.length) {
    case 2:
      gl.uniform2fv(location, actual);
      return;
    case 3:
      gl.uniform3fv(location, actual);
      return;
    case 4:
      gl.uniform4fv(location, actual);
      return;
    default:
      gl.uniform1fv(location, actual);
  }
}

/**
 * Return the location of `name` in `program`, querying GL only on first use.
 * Missing uniforms are cached as null so they are not re-queried every frame.
 */
function lookupUniform(
  gl: WebGLRenderingContext,
  program: WebGLProgram,
  cache: Map<string, WebGLUniformLocation | null>,
  name: string,
): WebGLUniformLocation | null {
  let location = cache.get(name);
  if (location === undefined) {
    location = gl.getUniformLocation(program, name);
    cache.set(name, location);
  }
  return location;
}

/** GL objects owned by one shader setup, plus a function that frees them. */
type GLResources = {
  gl: WebGLRenderingContext;
  program: WebGLProgram;
  dispose: () => void;
};

/**
 * Acquire the canvas's WebGL context, build the program and the quad buffer,
 * and wire up the `a_position` / `a_texCoord` attributes. Returns null (after
 * logging, and without leaking partially-created objects) if any step fails.
 */
function setupWebGL(
  canvas: HTMLCanvasElement,
  vertexSource: string,
  fragmentSource: string,
): GLResources | null {
  const gl =
    canvas.getContext("webgl", CONTEXT_ATTRIBUTES) ||
    canvas.getContext("experimental-webgl", CONTEXT_ATTRIBUTES);
  if (!gl || !(gl instanceof WebGLRenderingContext)) {
    console.error("WebGL not supported");
    return null;
  }

  const vertexShaderObj = compileShader(gl, gl.VERTEX_SHADER, vertexSource);
  const fragmentShaderObj = compileShader(
    gl,
    gl.FRAGMENT_SHADER,
    fragmentSource,
  );
  if (!vertexShaderObj || !fragmentShaderObj) {
    // Free whichever stage did compile so a failed setup leaks nothing.
    if (vertexShaderObj) gl.deleteShader(vertexShaderObj);
    if (fragmentShaderObj) gl.deleteShader(fragmentShaderObj);
    console.error("Failed to create shaders");
    return null;
  }

  const program = linkProgram(gl, vertexShaderObj, fragmentShaderObj);
  // The stages are no longer needed after linking. Deleting them only flags
  // them; GL frees them once the program that holds them is deleted.
  gl.deleteShader(vertexShaderObj);
  gl.deleteShader(fragmentShaderObj);
  if (!program) {
    console.error("Failed to create program");
    return null;
  }

  const buffer = gl.createBuffer();
  gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
  gl.bufferData(gl.ARRAY_BUFFER, QUAD_VERTICES, gl.STATIC_DRAW);

  const floatBytes = Float32Array.BYTES_PER_ELEMENT;
  const stride = 4 * floatBytes;
  const enabledAttributes: number[] = [];
  const bindAttribute = (name: string, offset: number) => {
    const location = gl.getAttribLocation(program, name);
    // -1 means the shader does not use the attribute (or it was optimized out);
    // enabling location -1 would raise INVALID_VALUE.
    if (location < 0) return;
    gl.enableVertexAttribArray(location);
    gl.vertexAttribPointer(location, 2, gl.FLOAT, false, stride, offset);
    enabledAttributes.push(location);
  };
  bindAttribute("a_position", 0);
  bindAttribute("a_texCoord", 2 * floatBytes);

  return {
    gl,
    program,
    dispose: () => {
      // Vertex attribute state is global in WebGL 1; disable what we enabled
      // so a later program does not read from the deleted buffer.
      for (const location of enabledAttributes) {
        gl.disableVertexAttribArray(location);
      }
      gl.bindBuffer(gl.ARRAY_BUFFER, null);
      gl.deleteBuffer(buffer);
      gl.deleteProgram(program);
    },
  };
}

/**
 * Renders a continuously animated, full-canvas WebGL 1 shader.
 *
 * Built-in uniforms, set every frame when the shader declares them:
 * - `u_time` (float): seconds elapsed since the component first mounted.
 * - `u_resolution` (vec2): drawing-buffer size in device pixels.
 *
 * Entries in `uniforms` are uploaded after the built-ins, so they can override
 * them. Changing `uniforms` does not recompile anything. Changing either shader
 * source rebuilds the program and frees the previous one.
 *
 * The quad supplies `a_position` (clip space, -1..1) and `a_texCoord` (0..1).
 * A custom vertex shader may omit either attribute.
 */
export function Shader({
  uniforms,
  fragmentShader,
  vertexShader,
  className,
  style,
}: ShaderProps) {
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const glRef = useRef<WebGLRenderingContext | null>(null);
  const programRef = useRef<WebGLProgram | null>(null);
  const uniformLocationsRef = useRef(
    new Map<string, WebGLUniformLocation | null>(),
  );
  // Latest uniforms are read through a ref so that a new `uniforms` object on
  // every parent render does not tear down and rebuild the GL pipeline.
  const uniformsRef = useRef(uniforms);
  const startTimeRef = useRef<number | null>(null);

  useEffect(() => {
    uniformsRef.current = uniforms;
  }, [uniforms]);

  /** Draw a single frame; never schedules another one. */
  const drawFrame = useCallback((timestamp: number) => {
    const canvas = canvasRef.current;
    const gl = glRef.current;
    const program = programRef.current;
    if (!canvas || !gl || !program) return;
    // biome-ignore lint/correctness/useHookAtTopLevel: gl.useProgram is a WebGL API call, not a React hook
    gl.useProgram(program);
    gl.viewport(0, 0, canvas.width, canvas.height);
    gl.clearColor(0, 0, 0, 1);
    gl.clear(gl.COLOR_BUFFER_BIT);

    const cache = uniformLocationsRef.current;
    const startTime = startTimeRef.current ?? timestamp;
    const shaderTime = Math.max(0, (timestamp - startTime) / 1000);
    const timeLocation = lookupUniform(gl, program, cache, "u_time");
    if (timeLocation) {
      gl.uniform1f(timeLocation, shaderTime);
    }
    const resolutionLocation = lookupUniform(
      gl,
      program,
      cache,
      "u_resolution",
    );
    if (resolutionLocation) {
      gl.uniform2f(resolutionLocation, canvas.width, canvas.height);
    }

    const currentUniforms = uniformsRef.current;
    if (currentUniforms) {
      for (const [name, value] of Object.entries(currentUniforms)) {
        const location = lookupUniform(gl, program, cache, name);
        if (location) {
          applyUniform(gl, location, value);
        }
      }
    }

    gl.drawArrays(gl.TRIANGLES, 0, 6);
  }, []);

  useEffect(() => {
    const canvas = canvasRef.current;
    if (!canvas) return;
    if (startTimeRef.current === null) {
      // Time origin shared with requestAnimationFrame timestamps (monotonic).
      startTimeRef.current = performance.now();
    }

    // An empty string also falls back to the default vertex shader.
    const resources = setupWebGL(
      canvas,
      vertexShader || defaultVertexShader,
      fragmentShader,
    );
    glRef.current = resources?.gl ?? null;
    programRef.current = resources?.program ?? null;
    uniformLocationsRef.current.clear();

    let frameId: number | undefined;
    let resizeTimer: ReturnType<typeof setTimeout> | undefined;

    // Match the canvas drawing buffer to its parent's size in device pixels.
    const applySize = () => {
      resizeTimer = undefined;
      const parent = canvas.parentElement;
      if (!parent) return;
      const rect = parent.getBoundingClientRect();
      const dpr = window.devicePixelRatio || 1;
      const width = Math.round(rect.width * dpr);
      const height = Math.round(rect.height * dpr);
      canvas.style.width = `${rect.width}px`;
      canvas.style.height = `${rect.height}px`;
      canvas.style.display = "block";
      // Assigning width/height clears the buffer even when unchanged, so only
      // touch them (and redraw immediately) on a real size change.
      if (canvas.width !== width || canvas.height !== height) {
        canvas.width = width;
        canvas.height = height;
        drawFrame(performance.now());
      }
    };

    // Debounce bursts of resize notifications into one resize per ~frame.
    const scheduleResize = () => {
      if (resizeTimer !== undefined) clearTimeout(resizeTimer);
      resizeTimer = setTimeout(applySize, 16);
    };

    scheduleResize();
    const resizeObserver =
      typeof ResizeObserver === "undefined"
        ? null
        : new ResizeObserver(scheduleResize);
    if (canvas.parentElement) {
      resizeObserver?.observe(canvas.parentElement);
    }
    window.addEventListener("resize", scheduleResize);

    // Exactly one animation loop per effect run.
    if (resources) {
      const tick = (timestamp: number) => {
        drawFrame(timestamp);
        frameId = requestAnimationFrame(tick);
      };
      tick(performance.now());
    }

    return () => {
      if (frameId !== undefined) cancelAnimationFrame(frameId);
      if (resizeTimer !== undefined) clearTimeout(resizeTimer);
      resizeObserver?.disconnect();
      window.removeEventListener("resize", scheduleResize);
      resources?.dispose();
      glRef.current = null;
      programRef.current = null;
      uniformLocationsRef.current.clear();
    };
  }, [vertexShader, fragmentShader, drawFrame]);

  return (
    <div style={{ position: "relative", width: "100%", height: "100%" }}>
      <canvas
        ref={canvasRef}
        className={className}
        style={{
          width: "100%",
          height: "100%",
          display: "block",
          ...style,
        }}
      />
    </div>
  );
}
// Last updated on: (date not captured in this copy)