import {InferenceSession, Tensor, env} from 'onnxruntime-web';
import {guideMapping, infosMapping, keypointsMapping, sidesMapping, viewsMapping} from '../constants/Mapping.Constant';
import {addOrUpdateItem, getItem, StoreName} from "./IndexedDB";

// Override the URL prefix from which onnxruntime-web loads its .wasm binaries.
Object.assign(env.wasm, {wasmPaths: "../model/"});

/**
 * Fetch the raw bytes of an ONNX model, preferring the copy cached in IndexedDB.
 *
 * On a cache miss the model is downloaded and written back to IndexedDB so the next
 * call can skip the network. The cache is best-effort: IndexedDB read/write failures
 * are logged and the network copy is used instead of aborting the load.
 *
 * @param {string} modelPath - Model URL; also used as the IndexedDB key.
 * @returns {Promise<*>} The cached `data` payload, or the freshly downloaded ArrayBuffer.
 * @throws {Error} When the download fails or the server answers with a non-2xx status.
 */
const loadModelBytes = async modelPath => {
		let cached = null;
		try {
				cached = await getItem(modelPath, StoreName.DAT_FORM_FILES);
		} catch (error) {
				console.warn(`Could not read cached model ${modelPath}, downloading it instead:`, error);
		}
		if (cached && cached.data) {
				return cached.data;
		}
		
		const response = await fetch(modelPath);
		if (!response.ok) {
				// Without this check an error page (e.g. a 404 body) would be handed to ONNX Runtime
				// and, worse, stored in IndexedDB and reused on every later load.
				throw new Error(`Failed to download model ${modelPath}: HTTP ${response.status}`);
		}
		const modelBytes = await response.arrayBuffer();
		
		try {
				await addOrUpdateItem(modelPath, {data: modelBytes}, StoreName.DAT_FORM_FILES);
		} catch (error) {
				console.warn(`Could not cache model ${modelPath}:`, error);
		}
		return modelBytes;
};

// Create an ONNX Runtime session on the WASM execution provider with all graph optimisations enabled.
const createWasmSession = modelBytes => InferenceSession.create(modelBytes, {
		executionProviders: ['wasm'],
		graphOptimizationLevel: 'all',
});

/**
 * Load a model (IndexedDB first, network second) and build a WASM session for it,
 * logging the elapsed time in seconds.
 *
 * @param {string} modelPath - Model URL / cache key.
 * @param {string} label - Prefix of the timing log line (e.g. 'DETECTION').
 * @returns {Promise<InferenceSession>} The created session.
 */
const createCachedSession = async (modelPath, label) => {
		const startTime = performance.now();
		const sess = await createWasmSession(await loadModelBytes(modelPath));
		const creation_elapsed = (performance.now() - startTime) / 1000;
		console.info(`${label} MODEL CREATED IN ${modelPath} ${creation_elapsed}`);
		return sess;
};

/**
 * Session for the keypoint detection model.
 *
 * @param {string} detectionModelPath - Model URL; a falsy value short-circuits.
 * @returns {Promise<InferenceSession|null>} The session, or null when no path is given.
 * @throws {Error} When the model cannot be downloaded or the session cannot be created.
 */
export const sessionKeypointDetection = async detectionModelPath => {
		if (!detectionModelPath) return null;
		return createCachedSession(detectionModelPath, 'DETECTION');
};

/**
 * Session for the keypoint checker model.
 *
 * @param {string} checkerModelPath - Model URL; a falsy value short-circuits.
 * @returns {Promise<InferenceSession|null>} The session, or null when no path is given.
 * @throws {Error} When the model cannot be downloaded or the session cannot be created.
 */
export const sessionKeypointChecker = async checkerModelPath => {
		if (!checkerModelPath) return null;
		return createCachedSession(checkerModelPath, 'CHECKER');
};

// In-memory cache of detection/checker session pairs, keyed by both model paths.
// Values are promises so concurrent callers for the same pair share a single load.
const modelCache = new Map();

/**
 * Download a model file and return its bytes.
 *
 * @param {string} modelPath - URL of the model file.
 * @returns {Promise<ArrayBuffer>} The response body.
 * @throws {Error} When the server answers with a non-2xx status.
 */
const downloadModelBuffer = async modelPath => {
		const response = await fetch(modelPath);
		if (!response.ok) {
				throw new Error(`Failed to download model ${modelPath}: HTTP ${response.status}`);
		}
		return response.arrayBuffer();
};

// Fresh options object per session: WASM execution provider, all graph optimisations.
const pairSessionOptions = () => ({
		executionProviders: ['wasm'],
		graphOptimizationLevel: 'all',
});

/**
 * Download both models concurrently, then create one WASM session for each.
 *
 * @param {string} detectionModelPath - Keypoint detection model URL.
 * @param {string} checkerModelPath - Checker model URL.
 * @returns {Promise<{keypoint_session: InferenceSession, checker_session: InferenceSession}>}
 */
const createSessionPair = async (detectionModelPath, checkerModelPath) => {
		const start = new Date().getTime();
		
		const [detectionModelFile, checkerModelFile] = await Promise.all([
				downloadModelBuffer(detectionModelPath),
				downloadModelBuffer(checkerModelPath),
		]);
		
		// Only the downloads overlap; sessions are created one after the other.
		const keypointSession = await InferenceSession.create(detectionModelFile, pairSessionOptions());
		const checkerSession = await InferenceSession.create(checkerModelFile, pairSessionOptions());
		
		const creation_elapsed = new Date().getTime() - start;
		console.log('DETECTION MODEL CREATED IN ', creation_elapsed);
		
		return {keypoint_session: keypointSession, checker_session: checkerSession};
};

/**
 * Create (or reuse) the detection and checker sessions for a pair of model paths.
 *
 * @param {string} detectionModelPath - Keypoint detection model URL.
 * @param {string} checkerModelPath - Checker model URL.
 * @returns {Promise<{keypoint_session: InferenceSession, checker_session: InferenceSession}|undefined>}
 *   The cached pair, or undefined when loading failed (the error is logged).
 *   Failed loads are evicted, so a later call retries.
 */
export const importDetectionModels = async (detectionModelPath, checkerModelPath) => {
		// JSON keeps the two paths distinct; a plain "a-b" join lets ("a-b", "c") collide with ("a", "b-c").
		const cacheKey = JSON.stringify([detectionModelPath, checkerModelPath]);
		
		if (!modelCache.has(cacheKey)) {
				const pending = createSessionPair(detectionModelPath, checkerModelPath);
				modelCache.set(cacheKey, pending);
				// Evict a failed load so the next call retries instead of replaying the rejection.
				pending.catch(() => {
						if (modelCache.get(cacheKey) === pending) modelCache.delete(cacheKey);
				});
		}
		
		try {
				return await modelCache.get(cacheKey);
		} catch (error) {
				console.error('Error creating inference sessions:', error);
				return undefined;
		}
};

/**
 * Convert canvas ImageData (interleaved RGBA bytes, row by row) into a planar
 * uint8 tensor of shape [3, height, width]: all R values, then all G, then all B.
 * The alpha channel is dropped; values are copied unchanged (no normalisation).
 *
 * Extra arguments are ignored (some callers pass a shape array).
 *
 * @param {ImageData} imageData - Source pixels; `data` holds 4 bytes per pixel.
 * @returns {Tensor} uint8 tensor with dims [3, height, width].
 */
export const imageDataToTensor = imageData => {
		const {width, height, data} = imageData;
		const planeSize = width * height;
		const greenOffset = planeSize;
		const blueOffset = 2 * planeSize;
		
		// Write each channel straight into its plane instead of building temporary arrays.
		const chw = new Uint8Array(3 * planeSize);
		for (let pixel = 0; pixel < planeSize; pixel += 1) {
				const rgba = pixel * 4;
				chw[pixel] = data[rgba];
				chw[greenOffset + pixel] = data[rgba + 1];
				chw[blueOffset + pixel] = data[rgba + 2];
				// data[rgba + 3] (alpha) is intentionally skipped.
		}
		return new Tensor("uint8", chw, [3, height, width]);
};

/**
 * Overlay `canvas` exactly on top of `vid`: match its on-screen box (CSS size and
 * position) and set the drawing-buffer size to the same width/height.
 *
 * @param {Element} vid - Element whose bounding client rect is mirrored.
 * @param {HTMLCanvasElement} canvas - Canvas to resize and reposition.
 */
export const resizeCanvas = (vid, canvas) => {
		const {width, height, left, top} = vid.getBoundingClientRect();
		const {style} = canvas;
		
		style.width = `${width}px`;
		style.height = `${height}px`;
		canvas.width = width;
		canvas.height = height;
		style.left = `${left}px`;
		style.top = `${top}px`;
};

/**
 * Run only the keypoint detection model and forward its public result fields.
 * Any failure is logged with an "ERROR IN INFERENCE PART" prefix and re-thrown.
 *
 * @param keypointSession - Detection session passed through to runDetectionInference.
 * @param preprocessedData - Input tensor passed through to runDetectionInference.
 * @returns {Promise<{keypoints, keypoints_pos, classPred, inferenceTime}>}
 */
export const runKeypointDetectionModel = async (
		keypointSession,
		preprocessedData
) => {
		let detection;
		try {
				detection = await runDetectionInference(keypointSession, preprocessedData);
				const {keypoints, keypoints_pos, classPred, inferenceTime} = detection;
				return {keypoints, keypoints_pos, classPred, inferenceTime};
		} catch (err) {
				console.log(`ERROR IN INFERENCE PART :  ${err}`);
				throw err;
		}
};

/**
 * Run the keypoint detection model on one preprocessed frame and decode its outputs.
 *
 * Outputs are read by position from `keypointSession.outputNames`:
 * [0] keypoint coordinates (pairs at 2*i, 2*i+1), [1] one score per keypoint,
 * [2] predicted class index (looked up in `sidesMapping`).
 *
 * @param {InferenceSession} keypointSession - Detection session.
 * @param {Tensor} preprocessedData - Tensor fed to the session's first input.
 * @param {number} [scoreThreshold=0.7] - Keypoints are kept only when their score is strictly greater.
 * @returns {Promise<{keypoints: Array, keypoints_pos: number[], classPred: *, inferenceTime: number}>}
 *   `keypoints_pos` is the flattened coordinate pairs of the kept keypoints; `inferenceTime`
 *   is the wall-clock time of the session run in milliseconds.
 */
export const runDetectionInference = async (
		keypointSession,
		preprocessedData,
		scoreThreshold = 0.7
) => {
		const start = new Date();
		// `inputNames` is the public getter; `handler` is an internal onnxruntime-web field.
		const outputDetection = await keypointSession.run({
				[keypointSession.inputNames[0]]: preprocessedData,
		});
		// Milliseconds spent in the session run only.
		const inferenceTime = new Date().getTime() - start.getTime();
		
		const [keypointName, scoresName, classPredName] = keypointSession.outputNames;
		const keypointArray = outputDetection[keypointName].data;
		const scoresArray = outputDetection[scoresName].data;
		const classesPredArray = outputDetection[classPredName].data;
		
		const keypoints_pos = [];
		const keypoints = [];
		for (let i = 0; i < scoresArray.length; i += 1) {
				if (scoresArray[i] > scoreThreshold) {
						keypoints_pos.push(keypointArray[2 * i], keypointArray[2 * i + 1]);
						// keypointsMapping is keyed from 1.
						keypoints.push(keypointsMapping[i + 1]);
				}
		}
		const classPred = sidesMapping[Number(classesPredArray[0])];
		
		return {keypoints, keypoints_pos, classPred, inferenceTime};
};

/**
 * Run detection followed by the checker model and forward the combined result fields.
 * Any failure is logged with an "ERROR IN INFERENCE PART" prefix and re-thrown.
 *
 * Note the argument order differs from runBothInference (data and view come first).
 *
 * @param preprocessedData - Input tensor for the detection model.
 * @param view - View code forwarded to the checker.
 * @param keypointSession - Detection session.
 * @param checkerSession - Checker session.
 * @param {boolean} [debug=false] - Forwarded to runBothInference.
 * @returns {Promise<{keypoints, keypoints_pos, classPred, checksArray, infosArray, inferenceTime}>}
 */
export const runKeypointCheckerModel = async (
		preprocessedData,
		view,
		keypointSession,
		checkerSession,
		debug = false
) => {
		let combined;
		try {
				combined = await runBothInference(keypointSession, checkerSession, preprocessedData, view, debug);
				const {
						keypoints,
						keypoints_pos,
						classPred,
						checksArray,
						infosArray,
						inferenceTime,
				} = combined;
				return {keypoints, keypoints_pos, classPred, checksArray, infosArray, inferenceTime};
		} catch (err) {
				console.log(`ERROR IN INFERENCE PART :  ${err}`);
				throw err;
		}
};

/**
 * Run the keypoint detection model, feed its raw outputs to the checker model, and decode both.
 *
 * Detection outputs are read by position from `keypointSession.outputNames`:
 * [0] keypoint coordinates, [1] per-keypoint scores, [2] predicted class index and,
 * optionally, [3] class scores. Output [3] is only used for debug logging.
 * The checker receives outputs [0..2] unchanged plus `view` as a scalar int32 tensor, and
 * its first two outputs are returned as `checksArray` and `infosArray`.
 *
 * @param {InferenceSession} keypointSession - Detection session.
 * @param {InferenceSession} checkerSession - Checker session (4 inputs, >= 2 outputs).
 * @param {Tensor} preprocessedData - Tensor fed to the detection session's first input.
 * @param {Int32Array} view - View code data for the checker's 4th input.
 * @param {boolean} [debug=false] - Log raw outputs and per-stage timings.
 * @param {number} [scoreThreshold=0.7] - Keypoints are kept only when their score is strictly greater.
 * @returns {Promise<{keypoints, keypoints_pos, classPred, checksArray, infosArray, inferenceTime}>}
 *   `inferenceTime` is the wall-clock time of both session runs in milliseconds
 *   (logging and decoding excluded).
 */
export const runBothInference = async (
		keypointSession,
		checkerSession,
		preprocessedData,
		view,
		debug = false,
		scoreThreshold = 0.7
) => {
		// Stage 1: keypoint detection. `inputNames` is the public getter; `handler` is internal.
		const startDet = new Date();
		const outputDetection = await keypointSession.run({
				[keypointSession.inputNames[0]]: preprocessedData,
		});
		const detectionTime = new Date().getTime() - startDet.getTime();
		
		const [keypointName, scoresName, classPredName, classesName] = keypointSession.outputNames;
		const keypoint = outputDetection[keypointName];
		const scores = outputDetection[scoresName];
		const class_pred = outputDetection[classPredName];
		const keypointArray = keypoint.data;
		const scoresArray = scores.data;
		const classesPredArray = class_pred.data;
		
		if (debug) {
				// Class scores are only needed here, so a model without that output still works.
				const classes = outputDetection[classesName];
				console.log('keypoint detection output names : ', keypointSession.outputNames);
				console.log('keypoints : ', keypointArray);
				console.log('scores : ', scoresArray);
				console.log('classes scores : ', classes ? classes.data : undefined);
				console.log('class num : ', Number(classesPredArray[0]));
		}
		
		// Stage 2: checker model.
		const checkerInputs = checkerSession.inputNames;
		const startChecker = new Date();
		const outputChecker = await checkerSession.run({
				[checkerInputs[0]]: keypoint,
				[checkerInputs[1]]: scores,
				[checkerInputs[2]]: class_pred,
				[checkerInputs[3]]: new Tensor('int32', view, []),
		});
		const checkerTime = new Date().getTime() - startChecker.getTime();
		
		const checksArray = outputChecker[checkerSession.outputNames[0]].data;
		const infosArray = outputChecker[checkerSession.outputNames[1]].data;
		
		if (debug) {
				console.log('keypoint checker output names : ', checkerSession.outputNames);
				console.log('checks : ', checksArray);
				console.log('infos : ', infosArray);
		}
		
		const keypoints_pos = [];
		const keypoints = [];
		for (let i = 0; i < scoresArray.length; i += 1) {
				if (scoresArray[i] > scoreThreshold) {
						keypoints_pos.push(keypointArray[2 * i], keypointArray[2 * i + 1]);
						// keypointsMapping is keyed from 1.
						keypoints.push(keypointsMapping[i + 1]);
				}
		}
		const classPred = sidesMapping[Number(classesPredArray[0])];
		
		// All timings are milliseconds of session execution only.
		const inferenceTime = detectionTime + checkerTime;
		if (debug) {
				console.log('detetion inference time : ', detectionTime);
				console.log('checker inference time : ', checkerTime);
		}
		
		return {keypoints, keypoints_pos, classPred, checksArray, infosArray, inferenceTime};
};

/**
 * Capture the current video frame, run detection + checker, and draw the result overlay.
 *
 * The frame is grabbed through `ctx` (drawn, read back, then cleared), so after the call
 * the canvas holds only the overlay:
 * - status text lines;
 * - a red marker and label per keypoint;
 * - a 5px border, green when the check passed and red otherwise.
 *
 * @param {HTMLVideoElement} videoRef - Source element; its client size defines the capture/draw area.
 * @param {CanvasRenderingContext2D} ctx - Context used to grab pixels and draw the overlay.
 * @param keypointSession - Detection session (null/undefined aborts).
 * @param checkerSession - Checker session (null/undefined aborts).
 * @param {string} viewName - Key into `viewsMapping`. 'freeView' selects the free-view guide
 *   table and hides the per-side info line.
 * @param {boolean} [debug=false] - Log the composed status string.
 * @returns {Promise<{result: boolean, guide: *, guideResult: *}|number>} The outcome, or -1 when a session is missing.
 */
export const runDetection = async (videoRef, ctx, keypointSession, checkerSession, viewName, debug = false) => {
		const width = videoRef.clientWidth;
		const height = videoRef.clientHeight;
		
		if (keypointSession == null || checkerSession == null) {
				console.log('the sessions are not initialized!!')
				// Same end state as capturing a frame and clearing it, without the wasted pixel work.
				ctx.clearRect(0, 0, width, height);
				return -1
		}
		
		// Grab the current frame through the canvas, then wipe it so only the overlay stays visible.
		ctx.drawImage(videoRef, 0, 0, width, height);
		const inputTensor = imageDataToTensor(ctx.getImageData(0, 0, width, height));
		ctx.clearRect(0, 0, width, height);
		
		const view = new Int32Array([viewsMapping[viewName]]);
		
		const {
				keypoints,
				keypoints_pos,
				classPred,
				checksArray,
				infosArray,
				inferenceTime
		} = await runBothInference(keypointSession, checkerSession, inputTensor, view);
		
		const isFreeView = viewName === 'freeView';
		const guide = checksArray[0];
		// Guide code 0 means the check passed. Compare the first element strictly. The old
		// `checksArray == 0` stringified the whole array and was false whenever it had more
		// than one element. Number() also accepts a BigInt code.
		const result = Number(guide) === 0;
		const guideResult = isFreeView ? guideMapping.freeView[guide] : guideMapping.constraintViews[guide];
		
		const string_output0 = `guide:${guideResult} `;
		const string_output1 = `checked for view :${viewName}, in:${inferenceTime}ms `;
		const string_output2 = `class predicted:${classPred}, num_kpts:${infosArray[1]} `;
		const string_output3 = isFreeView
				? ''
				: `top:${infosMapping.top[infosArray[2]]}, bottom:${infosMapping.bottom[infosArray[3]]}, left:${infosMapping.left[infosArray[4]]}, right:${infosMapping.right[infosArray[5]]}, center:${infosMapping.center[infosArray[6]]}, wrong:${infosMapping.wrong[infosArray[7]]} `;
		const string_output = string_output0 + string_output1 + string_output2 + string_output3;
		
		if (debug) {
				console.log('INFERENCE OUTPUT ', string_output);
		}
		
		// Text uses whatever fillStyle is current; rectangles set their own colour.
		const drawText = (text, fontSize, x, y) => {
				ctx.font = `bold ${fontSize}px sans-serif`;
				ctx.fillText(text, x, y);
		};
		const drawRectangle = (x, y, w, h, color) => {
				ctx.fillStyle = color;
				ctx.fillRect(x, y, w, h);
		};
		
		ctx.fillStyle = '#000000';
		drawText(string_output0, 20, 10, height - 50);
		drawText(string_output1, 10, 10, height - 34);
		drawText(string_output2, 10, 10, height - 23);
		if (!isFreeView) {
				drawText(string_output3, 10, 10, height - 12);
		}
		
		// Coordinates are scaled by the canvas size (presumably normalised to [0, 1] — TODO confirm).
		// Labels are drawn after the marker, so they inherit its red fill.
		for (let i = 0; i < keypoints.length; i += 1) {
				const x = keypoints_pos[i * 2] * width;
				const y = keypoints_pos[i * 2 + 1] * height;
				drawRectangle(x, y, 5, 5, '#ff0000');
				drawText(keypoints[i], 10, x, y);
		}
		
		const borderColor = result ? '#138707' : '#ff0000';
		drawRectangle(0, 0, width, 5, borderColor);
		drawRectangle(width - 5, 0, 5, height, borderColor);
		drawRectangle(0, 0, 5, height, borderColor);
		drawRectangle(0, height - 5, width, 5, borderColor);
		
		return {result, guide, guideResult};
};

/**
 * Lean variant of runBothInference: run detection, feed its first three outputs plus
 * `view` (as a scalar int32 tensor) to the checker, and return only the checker's
 * first output. Nothing is decoded, logged or timed.
 *
 * @param {InferenceSession} keypointSession - Detection session.
 * @param {InferenceSession} checkerSession - Checker session (4 inputs).
 * @param {Tensor} preprocessedData - Tensor fed to the detection session's first input.
 * @param {Int32Array} view - View code data for the checker's 4th input.
 * @returns {Promise<{checksArray: *}>} The `.data` of the checker's first output.
 */
export const runBothInferenceOptimized = async (
		keypointSession,
		checkerSession,
		preprocessedData,
		view
) => {
		// Use the public `inputNames` getters; `handler` is internal to onnxruntime-web.
		const outputDetection = await keypointSession.run({
				[keypointSession.inputNames[0]]: preprocessedData
		});
		
		const [keypointName, scoresName, classPredName] = keypointSession.outputNames;
		const checkerInputs = checkerSession.inputNames;
		const outputChecker = await checkerSession.run({
				[checkerInputs[0]]: outputDetection[keypointName],
				[checkerInputs[1]]: outputDetection[scoresName],
				[checkerInputs[2]]: outputDetection[classPredName],
				[checkerInputs[3]]: new Tensor('int32', view, []),
		});
		
		return {
				checksArray: outputChecker[checkerSession.outputNames[0]].data,
		};
};

/**
 * Lean variant of runDetection: capture the current frame, run both models, and
 * report the guide result without drawing any overlay. The canvas area is left cleared.
 *
 * @param {HTMLVideoElement} videoRef - Source element; its client size defines the capture area.
 * @param {CanvasRenderingContext2D} ctx - Context used to grab the frame's pixels.
 * @param keypointSession - Detection session (null/undefined aborts).
 * @param checkerSession - Checker session (null/undefined aborts).
 * @param {string} viewName - Key into `viewsMapping`; 'freeView' selects the free-view guide table.
 * @returns {Promise<{result: boolean, guide: *, guideResult: *}|number>} The outcome, or -1 when a session is missing.
 */
export const runDetectionOptimized = async (videoRef, ctx, keypointSession, checkerSession, viewName) => {
		const width = videoRef.clientWidth;
		const height = videoRef.clientHeight;
		
		if (keypointSession == null || checkerSession == null) {
				console.log('the sessions are not initialized!!')
				// Same end state as capturing a frame and clearing it, without the wasted pixel work.
				ctx.clearRect(0, 0, width, height);
				return -1
		}
		
		ctx.drawImage(videoRef, 0, 0, width, height);
		const inputTensor = imageDataToTensor(ctx.getImageData(0, 0, width, height));
		ctx.clearRect(0, 0, width, height);
		
		const view = new Int32Array([viewsMapping[viewName]]);
		const {checksArray} = await runBothInferenceOptimized(keypointSession, checkerSession, inputTensor, view);
		
		const guide = checksArray[0];
		// Guide code 0 means the check passed. Compare the first element strictly; the old
		// `checksArray == 0` stringified the whole array and failed for multi-element outputs.
		const result = Number(guide) === 0;
		const guideResult = viewName === 'freeView' ? guideMapping.freeView[guide] : guideMapping.constraintViews[guide];
		
		return {result, guide, guideResult};
};
