//ocrsaga.ts
import { all, call, delay, put, select, takeEvery, takeLatest } from "redux-saga/effects";

import { ocrService } from "src/services/api-services/OcrService";
import { FileReference, OcrResponse, GetOCrParams } from "src/pages/api/documents/ocr/common";
import { OcrAction } from "../actions/actions.constants";
import {
  cachedExtractsSelector,
  currentlyPollingSelector,
  dataNeededSelector,
  runningJobsSelector,
} from "../selectors/ocr.selector";

import { toastService } from "src/services/ToastService";
import {
  OcrCompletedAction,
  OcrErrorAction,
  PollOcrStatus,
  StartOcrSaga,
  UpdateOcrData,
} from "../actions/ocr.actions";
import { ExtractOutput } from "@interfold-ai/shared/models/extract/common";;
import { CalculateSpread, SetSpreadUserState } from "../actions/spread.action";
import { analysisTypeSelector } from "../selectors/spread.selector";
import { AnalysisType } from "@interfold-ai/shared/models/SpreadsTabConstants";
import { SPREAD_USER_STATE } from "src/redux/reducers/spread.reducer";
import { BaseApiError } from "src/services/api-services/BaseApiService";

type DocumentUploadId = string;

const DEFAULT_OCR_POLLING_INTERVAL_MS = 5000;

/**
 * Parses a polling interval in milliseconds from a raw (possibly unset) string.
 *
 * Returns `fallbackMs` when the value is missing, not a number, or not strictly
 * positive. This avoids handing `NaN`/0/negative values to `delay`, which would
 * effectively poll back-to-back.
 *
 * @param raw - raw configuration value, e.g. from `process.env`.
 * @param fallbackMs - interval to use when `raw` is absent or invalid.
 * @returns a positive integer interval in milliseconds.
 */
function parsePollingIntervalMs(raw: string | undefined, fallbackMs: number): number {
  if (raw === undefined) {
    return fallbackMs;
  }
  const parsed = parseInt(raw, 10);
  return Number.isFinite(parsed) && parsed > 0 ? parsed : fallbackMs;
}

const OCR_POLLING_INTERVAL_MS = parsePollingIntervalMs(
  process.env.NEXT_PUBLIC_OCR_POLLING_INTERVAL_MS,
  DEFAULT_OCR_POLLING_INTERVAL_MS,
);

/**
 * Starts OCR for the files carried by a START_OCR action.
 *
 * Reads the current analysis type from the store and asks the OCR service to
 * start processing. The service response is forwarded as UPDATE_OCR_DATA.
 * The selector runs inside the `try`, so a failure there is reported through
 * OCR_ERROR like any other failure instead of escaping this saga.
 *
 * @param action - START_OCR action; its payload is passed through to `ocrService.startOcr`.
 */
export function* startOcr(action: ReturnType<typeof StartOcrSaga>): Generator {
  try {
    const analysisType = (yield select(analysisTypeSelector)) as AnalysisType;
    // Pass the service as `this` so the call still works if `startOcr` is a
    // prototype method rather than a bound/arrow property.
    const response = yield call([ocrService, ocrService.startOcr], action.payload, analysisType);
    yield put(UpdateOcrData(response as OcrResponse));
  } catch (e) {
    // Thrown values are not guaranteed to be Error instances.
    const error = e instanceof Error ? e : new Error(String(e));
    yield put(OcrErrorAction(error));
  }
}

/**
 * Merges the latest OCR response into the cached extracts and decides the next step.
 *
 * If every value in the merged record is an object, the merged record is
 * dispatched as OCR_COMPLETED. Otherwise, when polling is still active, the
 * saga waits OCR_POLLING_INTERVAL_MS and dispatches POLL_OCR_STATUS. When
 * polling is inactive, it stops quietly. Any failure is dispatched as OCR_ERROR.
 *
 * @param action - UPDATE_OCR_DATA action carrying the latest OCR response.
 */
export function* updateOcrData(action: { type: string; payload: OcrResponse }): Generator {
  type ExtractsById = Record<DocumentUploadId, ExtractOutput>;
  try {
    const latest = action.payload as OcrResponse;
    const cached = (yield select(cachedExtractsSelector)) as ExtractsById;
    // Entries from the latest response override cached entries with the same id.
    const merged = { ...cached, ...latest } as ExtractsById;

    // Any non-object value is treated as OCR still in progress for that document.
    const somePending = Object.values(merged).some((entry) => typeof entry !== "object");
    if (!somePending) {
      yield put(OcrCompletedAction(merged as ExtractsById));
      return;
    }

    const stillPolling = yield select(currentlyPollingSelector);
    if (!stillPolling) {
      return;
    }
    yield delay(OCR_POLLING_INTERVAL_MS);
    yield put(PollOcrStatus());
  } catch (err) {
    yield put(OcrErrorAction(err as Error));
  }
}

/**
 * Asks the OCR service for the status of documents that still need data.
 *
 * Does nothing when polling is switched off. Otherwise it pairs each file
 * reference from `dataNeededSelector` with its running job. The running-jobs
 * record is keyed by job or task id:
 * - numeric keys are attached as `taskId`;
 * - other keys are attached as `jobId`;
 * - an empty key leaves the reference unchanged.
 *
 * File references with no matching running job are skipped. If nothing is left
 * to poll, an OCR_ERROR is dispatched. The service response is forwarded as
 * UPDATE_OCR_DATA, and any failure is dispatched as OCR_ERROR.
 */
export function* pollOcrStatus(): Generator {
  try {
    const isPolling = yield select(currentlyPollingSelector);
    if (!isPolling) {
      return;
    }
    const fileReferences = (yield select(dataNeededSelector)) as FileReference[];
    const jobs = (yield select(runningJobsSelector)) as Record<string | number, FileReference>;

    // Index running jobs by documentRequestId once, instead of scanning every job
    // for every file reference. The first key seen for a request wins, matching
    // the previous `find`-based lookup.
    const jobKeyByRequestId = new Map<FileReference["documentRequestId"], string>();
    for (const [jobOrTaskId, job] of Object.entries(jobs)) {
      if (!jobKeyByRequestId.has(job.documentRequestId)) {
        jobKeyByRequestId.set(job.documentRequestId, jobOrTaskId);
      }
    }

    const dataNeeded: FileReference[] = [];
    for (const data of fileReferences) {
      if (!jobKeyByRequestId.has(data.documentRequestId)) {
        continue; // no running job for this document
      }
      // either can be jobId or taskId
      const jobOrTaskId = jobKeyByRequestId.get(data.documentRequestId);
      if (!jobOrTaskId) {
        dataNeeded.push(data);
        continue;
      }
      const key = isNaN(Number(jobOrTaskId)) ? "jobId" : "taskId";
      dataNeeded.push({ ...data, [key]: jobOrTaskId } as FileReference);
    }

    if (dataNeeded.length === 0) {
      yield put(OcrErrorAction(new Error("No jobs to poll.")));
      return;
    }

    const analysisType = (yield select(analysisTypeSelector)) as AnalysisType;
    const getOcrParams: GetOCrParams = { dataNeeded, analysisType };
    // Pass the service as `this` in case `checkOcrStatus` relies on it.
    const results = yield call([ocrService, ocrService.checkOcrStatus], getOcrParams);
    // The service may resolve with an Error instead of rejecting; treat both alike.
    if (results instanceof Error) {
      throw results;
    }
    yield put(UpdateOcrData(results as OcrResponse));
  } catch (e) {
    // Thrown values are not guaranteed to be Error instances.
    const error = e instanceof Error ? e : new Error(String(e));
    yield put(OcrErrorAction(error));
  }
}

/**
 * Routes a finished OCR run to the next step of the current analysis workflow.
 *
 * - NOI analysis workflows (new loan / portfolio management) move the user to
 *   choosing assets.
 * - Table extraction and the personal/business cash-flow and general spread
 *   workflows calculate the spread from the extract outputs.
 * - Any other analysis type dispatches nothing.
 *
 * @param action - OCR_COMPLETED action whose payload maps document upload ids to extract outputs.
 */
export function* ocrCompleted(action: ReturnType<typeof OcrCompletedAction>): Generator {
  const workflow = (yield select(analysisTypeSelector)) as AnalysisType;
  const extractOutputs: Record<string, ExtractOutput> = action.payload;

  // Shallow copy so the spread calculation does not receive the action payload object itself.
  const finalData = { ...extractOutputs };

  switch (workflow) {
    case AnalysisType.NOI_ANALYSIS_NEW_LOAN:
    case AnalysisType.NOI_ANALYSIS_PORTFOLIO_MANAGEMENT:
      yield put(SetSpreadUserState(SPREAD_USER_STATE.CHOOSING_ASSETS));
      break;

    case AnalysisType.EXTRACT_TABLES:
    case AnalysisType.PERSONAL_CASH_FLOW:
    case AnalysisType.BUSINESS_CASH_FLOW:
    case AnalysisType.GENERAL_SPREADS:
      yield put(CalculateSpread(finalData));
      break;

    default:
      // NOTE(review): other analysis types have no post-OCR step here; confirm
      // that leaving the spread user state untouched is intended.
      break;
  }
}

/**
 * Reports an OCR failure to the user and returns them to file selection.
 *
 * Errors of type BaseApiError are not toasted here, because BaseApiService
 * already shows an error toast for them. Every other error gets an error toast.
 * In all cases the spread user state is reset to CHOOSING_FILES.
 *
 * @param action - OCR_ERROR action whose payload is the failure being reported.
 */
export function* ocrErrored(action: ReturnType<typeof OcrErrorAction>): Generator {
  const error = action.payload;
  if (!(error instanceof BaseApiError)) {
    // OCR_ERROR payloads can come from `catch` blocks, where the thrown value is
    // not guaranteed to be an Error. Fall back to its string form rather than
    // toasting `undefined`, or throwing on a null payload and skipping the
    // state reset below.
    const message = error instanceof Error && error.message ? error.message : String(error);
    // Pass the service as `this` in case `showError` relies on it.
    yield call([toastService, toastService.showError], message);
  }
  yield put(SetSpreadUserState(SPREAD_USER_STATE.CHOOSING_FILES));
}

/**
 * Root OCR saga: registers one watcher per OCR action and runs them all together.
 *
 * START_OCR, POLL_OCR_STATUS, OCR_COMPLETED and OCR_ERROR are watched with
 * `takeLatest`, so a newer action of the same type cancels a still-running
 * handler. UPDATE_OCR_DATA is watched with `takeEvery`, so every update gets
 * its own handler.
 */
export function* ocrSaga(): Generator {
  const watchers = [
    takeLatest(OcrAction.START_OCR, startOcr),
    takeLatest(OcrAction.POLL_OCR_STATUS, pollOcrStatus),
    takeEvery(OcrAction.UPDATE_OCR_DATA, updateOcrData),
    takeLatest(OcrAction.OCR_COMPLETED, ocrCompleted),
    takeLatest(OcrAction.OCR_ERROR, ocrErrored),
  ];
  yield all(watchers);
}

export default ocrSaga;
