// Copyright 2022 Storro B.V.
// All rights reserved.
// Dit werk is auteursrechtelijk beschermd.
//
// This class should function equivalently to
// Storro/Base/Cryptography/ConvergentEncryption.h
//
import { decode, encode } from '@stablelib/utf8';
import base64url from 'base64url';
import { EcasValue } from '../Serialization/EcasValue';
import { concat } from '../Util/Concat';
import { equal } from '../Util/Equal';
import { number32ToUint8Array } from '../Util/Numeric';
import { Hash } from './Hash';
import { Algorithm, SymmetricKey } from './SymmetricKey';

// Largest number of plaintext bytes encrypted under a single segment nonce (256 KiB).
// This must be equal to Storro's version.
const maxSegmentSize = 256 * 1024;

/** Serialized form of a ConvergentEncryption instance, as produced by its toJson(). */
export interface ConvergentEncryptionJson {
  /** The realm salt (32 bytes when valid), base64url-encoded. */
  realmSalt: string;
}

/**
 * Output of convergent encryption: the salted value hash paired with the
 * serialized ciphertext (encrypted segments, associated data, and a trailing
 * one-byte associated-data length).
 */
export class ConvergentValue {
  constructor(
    private readonly valueHash: Uint8Array,
    private readonly ciphertext: Uint8Array,
  ) {}

  /** Salted hash of the plaintext; decryption needs it to rebuild the key and segment nonces. */
  getValueHash(): Uint8Array {
    return this.valueHash;
  }

  /** The serialized ciphertext as produced by ConvergentEncryption.encrypt(). */
  getCiphertext(): Uint8Array {
    return this.ciphertext;
  }

  /**
   * True when both the value hash and the ciphertext compare equal using the
   * `equal` helper (presumably a byte-wise comparison).
   */
  public isEqual(other: ConvergentValue): boolean {
    // The hash is typically much shorter than the ciphertext, so most
    // mismatches are detected without scanning the ciphertext.
    return equal(this.valueHash, other.valueHash) && equal(this.ciphertext, other.ciphertext);
  }
}

// Cipher tag size in bytes; parseAssociated uses it for its minimum-length check.
const tagBytes = 0x10;

// This should work equivalently to the cpp version:
// https://gitlab.coblue.eu/storro/storro/-/blob/QuickShareCrypto/Base/Cryptography/ConvergentEncryption.cpp#L243-254
//
// Extracts the associated data from the tail of a serialized ciphertext laid out as
//   encrypted segments || associated || associatedLength (1 byte).
// Throws a string (not an Error) when the input is too short or the length
// byte does not leave at least one byte in front of the associated data.
function parseAssociated(ciphertext: Uint8Array): Uint8Array {
  // Need room for a tag, at least one ciphertext byte and the length byte.
  if (ciphertext.length < tagBytes + 2) {
    throw 'Encrypted value is too small';
  }

  const lengthByteIndex = ciphertext.length - 1;
  const associatedLength = ciphertext[lengthByteIndex];
  if (associatedLength >= lengthByteIndex) {
    throw 'Encrypted value associated length is corrupt';
  }

  // The associated bytes sit directly in front of the trailing length byte.
  const start = lengthByteIndex - associatedLength;
  return ciphertext.slice(start, lengthByteIndex);
}

/**
 * Convergent (content-derived) encryption of ECAS values.
 *
 *   valueHash = blake2b(realmSalt || blake2bTree(value))
 *   key       = blake2b(realmSalt || valueHash || associated)
 *
 * The plaintext is split into segments of at most `maxSegmentSize` bytes.
 * Every segment is encrypted with the same key under the nonce
 * valueHash[0..20) || number32ToUint8Array(segmentIndex). The serialized
 * ciphertext is
 *   segment_0 || ... || segment_k || associated || associatedLength (1 byte).
 *
 * See Zooid/Source/Cryptography/ConvergentEncryption.{h, cpp} for details.
 */
export class ConvergentEncryption {
  // https://gitlab.coblue.eu/storro/storro/-/blob/master/Base/Cryptography/ConvergentEncryption.cpp#L22
  static maxSegmentSize = 262144;

  // https://gitlab.coblue.eu/storro/storro/-/blob/master/Base/Cryptography/ConvergentEncryption.cpp#L19
  static tagBytes = 16;

  // The associated-data length is serialized as a single trailing byte.
  private static readonly maxAssociatedBytes = 255;

  /**
   * @param realmSalt Salt mixed into every value hash and key derivation.
   * @throws Error if `realmSalt` is not exactly 32 bytes long.
   */
  constructor(public readonly realmSalt: Uint8Array) {
    if (this.realmSalt.length !== 32) throw new Error('RealmSalt needs to be 32 bytes');
  }

  /** Serializes the realm salt as base64url. */
  public toJson(): ConvergentEncryptionJson {
    return { realmSalt: base64url.encode(Buffer.from(this.realmSalt)) };
  }

  /**
   * Restores an instance from the output of toJson().
   * @throws Error if the decoded salt is not 32 bytes long.
   */
  public static fromJson(json: ConvergentEncryptionJson): ConvergentEncryption {
    return new ConvergentEncryption(base64url.toBuffer(json.realmSalt));
  }

  /** Key size in bytes reported by this scheme. */
  keySize(): number {
    return 32;
  }

  /**
   * Encrypts `ecasValue` convergently (see the class comment for the derivation).
   *
   * NOTE(review): an earlier TODO claimed this only implements "the simpler
   * algorithm for small values"; confirm the segmenting matches the C++ version.
   *
   * @throws 'Value needs to be non-empty' (a string, kept for compatibility) for an empty value;
   *   an empty value would produce a ciphertext that decrypt() always rejects.
   * @throws Error if the associated data is longer than 255 bytes.
   */
  async encrypt(ecasValue: EcasValue): Promise<ConvergentValue> {
    const value = ecasValue.getValue();
    const associated = ecasValue.getAssociated();
    ConvergentEncryption.assertEncryptable(value, associated);
    const plaintextHash = await Hash.blake2bTree(value);
    const valueHash = await Hash.blake2b(concat([this.realmSalt, plaintextHash]));
    const rawKey = await Hash.blake2b(concat([this.realmSalt, valueHash, associated]));
    const symKey = new SymmetricKey(rawKey, Algorithm.XCHACHA);
    const encrypted = ConvergentEncryption.sealSegments(symKey, value, valueHash, associated);
    return new ConvergentValue(valueHash, encrypted);
  }

  /**
   * Synchronous version of encrypt(); intended to produce the same output.
   * @throws 'Value needs to be non-empty' (a string) for an empty value.
   * @throws Error if the associated data is longer than 255 bytes.
   */
  public encryptSync(ecasValue: EcasValue): ConvergentValue {
    const value = ecasValue.getValue();
    const associated = ecasValue.getAssociated();
    ConvergentEncryption.assertEncryptable(value, associated);
    const plaintextHash = Hash.blake2bTreeSync(value);
    const valueHash = Hash.blake2bSync(concat([this.realmSalt, plaintextHash]));
    // The associated data must be part of the key derivation, exactly as in
    // encrypt() and decrypt(); otherwise values carrying associated data
    // could not be decrypted.
    const rawKey = Hash.blake2bSync(concat([this.realmSalt, valueHash, associated]));
    const symKey = new SymmetricKey(rawKey, Algorithm.XCHACHA);
    const encrypted = ConvergentEncryption.sealSegments(symKey, value, valueHash, associated);
    return new ConvergentValue(valueHash, encrypted);
  }

  /**
   * Decrypts a value produced by encrypt()/encryptSync().
   * @throws a string from parseAssociated when the trailing length byte is inconsistent.
   * @throws Error('Input is too short') when no encrypted segment is present.
   * Whatever SymmetricKey.decrypt throws on authentication failure also propagates.
   */
  async decrypt(convValue: ConvergentValue): Promise<EcasValue> {
    const valueHash = convValue.getValueHash();
    const { associated, encrypted } = ConvergentEncryption.splitCiphertext(convValue.getCiphertext());
    const rawKey = await Hash.blake2b(concat([this.realmSalt, valueHash, associated]));
    const symKey = new SymmetricKey(rawKey, Algorithm.XCHACHA);
    const value = ConvergentEncryption.openSegments(symKey, encrypted, valueHash, associated);
    return new EcasValue(value, associated);
  }

  /** Synchronous version of decrypt(); throws under the same conditions. */
  public decryptSync(convValue: ConvergentValue): EcasValue {
    const valueHash = convValue.getValueHash();
    const { associated, encrypted } = ConvergentEncryption.splitCiphertext(convValue.getCiphertext());
    const rawKey = Hash.blake2bSync(concat([this.realmSalt, valueHash, associated]));
    const symKey = new SymmetricKey(rawKey, Algorithm.XCHACHA);
    const value = ConvergentEncryption.openSegments(symKey, encrypted, valueHash, associated);
    return new EcasValue(value, associated);
  }

  /**
   * Encrypts a UTF-8 name with `contentKey` under a nonce derived only from the key.
   *
   * SECURITY NOTE(review): every name encrypted with the same content key
   * reuses the same nonce. If the key's cipher is a stream-cipher AEAD such as
   * XChaCha20-Poly1305 (the XCHACHA nonce size used here suggests so), nonce
   * reuse leaks the XOR of plaintexts and weakens authentication. This is kept
   * unchanged so existing ciphertexts stay readable; a fix needs a per-name
   * nonce stored alongside the ciphertext and a format migration.
   */
  public static encryptName(contentKey: SymmetricKey, name: string): Uint8Array {
    const plaintext = encode(name);
    return contentKey.encrypt(plaintext, ConvergentEncryption.nameNonce(contentKey));
  }

  /** Inverse of encryptName(); see the security note there. */
  public static decryptName(contentKey: SymmetricKey, ciphertext: Uint8Array): string {
    const nameAsUint8Array = contentKey.decrypt(ciphertext, ConvergentEncryption.nameNonce(contentKey));
    return decode(nameAsUint8Array);
  }

  /** Static per-key nonce for name encryption: blake2b(rawKey) truncated to the XChaCha nonce size. */
  private static nameNonce(contentKey: SymmetricKey): Uint8Array {
    const digest = Hash.blake2bSync(contentKey.getRawKey());
    return digest.slice(0, SymmetricKey.nonceBytesForAlgo(Algorithm.XCHACHA));
  }

  /** Rejects inputs that cannot be represented in the serialized format. */
  private static assertEncryptable(value: Uint8Array, associated: Uint8Array): void {
    // A string is thrown (not an Error) for compatibility with existing callers of encryptSync().
    if (value.byteLength === 0) throw 'Value needs to be non-empty';
    if (associated.length > ConvergentEncryption.maxAssociatedBytes) {
      // Without this check the length byte would silently wrap modulo 256.
      throw new Error(`Associated data must be at most ${ConvergentEncryption.maxAssociatedBytes} bytes`);
    }
  }

  /** Nonce for segment `index`: the first 20 bytes of the value hash followed by the encoded index. */
  private static segmentNonce(valueHash: Uint8Array, index: number): Uint8Array {
    return concat([valueHash.slice(0, 20), number32ToUint8Array(index)]);
  }

  /** Encrypts `value` segment by segment and appends the associated data and its one-byte length. */
  private static sealSegments(
    symKey: SymmetricKey,
    value: Uint8Array,
    valueHash: Uint8Array,
    associated: Uint8Array,
  ): Uint8Array {
    const parts: Uint8Array[] = [];
    for (let offset = 0, index = 0; offset < value.byteLength; offset += maxSegmentSize, index++) {
      // subarray() honours value.byteOffset; `new Uint8Array(value.buffer, offset, size)`
      // reads the wrong bytes when `value` is a view (e.g. a pooled Node Buffer).
      const segment = value.subarray(offset, offset + maxSegmentSize);
      parts.push(symKey.encrypt(segment, ConvergentEncryption.segmentNonce(valueHash, index), associated));
    }
    parts.push(associated, new Uint8Array([associated.length]));
    return concat(parts);
  }

  /** Splits a serialized ciphertext into its trailing associated data and the encrypted segments. */
  private static splitCiphertext(ciphertext: Uint8Array): { associated: Uint8Array; encrypted: Uint8Array } {
    const associated = parseAssociated(ciphertext);
    // Drop the associated data and its one-byte length from the tail.
    const encrypted = ciphertext.slice(0, ciphertext.length - (associated.length + 1));
    return { associated, encrypted };
  }

  /** Decrypts the concatenated segments produced by sealSegments(). */
  private static openSegments(
    symKey: SymmetricKey,
    encrypted: Uint8Array,
    valueHash: Uint8Array,
    associated: Uint8Array,
  ): Uint8Array {
    const tagLength = symKey.tagBytes();
    // A valid value has at least one non-empty segment, i.e. more than a bare tag.
    if (encrypted.length <= tagLength) throw new Error('Input is too short');
    const maxCiphertextSegment = maxSegmentSize + tagLength;
    const parts: Uint8Array[] = [];
    for (let offset = 0, index = 0; offset < encrypted.length; offset += maxCiphertextSegment, index++) {
      // subarray() honours encrypted.byteOffset (Buffer.slice returns a view, not a copy).
      const segment = encrypted.subarray(offset, offset + maxCiphertextSegment);
      parts.push(symKey.decrypt(segment, ConvergentEncryption.segmentNonce(valueHash, index), associated));
    }
    return concat(parts);
  }
}
