From 51c897b073991dc18af7d77d0dfbd4807620d121 Mon Sep 17 00:00:00 2001
From: Daniel Huigens <d.huigens@protonmail.com>
Date: Mon, 4 Jun 2018 17:19:57 +0200
Subject: [PATCH] Cancelling

---
 src/encoding/armor.js                      |   7 +-
 src/packet/packet.js                       |   3 +-
 src/packet/packetlist.js                   |   3 +-
 src/packet/sym_encrypted_aead_protected.js |   7 +-
 src/stream.js                              | 127 ++++++++++++++++-----
 test/general/streaming.js                  |  81 +++++++++++++
 6 files changed, 188 insertions(+), 40 deletions(-)

diff --git a/src/encoding/armor.js b/src/encoding/armor.js
index 65da3d79..9d2057af 100644
--- a/src/encoding/armor.js
+++ b/src/encoding/armor.js
@@ -206,17 +206,18 @@ function dearmor(input) {
       const reSplit = /^-----[^-]+-----$/;
       const reEmptyLine = /^[ \f\r\t\u00a0\u2000-\u200a\u202f\u205f\u3000]*$/;
 
-      const reader = stream.getReader(input);
       let type;
       const headers = [];
       let lastHeaders = headers;
       let headersDone;
       let text = [];
       let textDone;
+      let reader;
       let controller;
-      let data = base64.decode(new ReadableStream({
-        async start(_controller) {
+      let data = base64.decode(stream.from(input, {
+        start(_controller, _reader) {
           controller = _controller;
+          reader = _reader;
         }
       }));
       let checksum;
diff --git a/src/packet/packet.js b/src/packet/packet.js
index 1ab4da36..bc4e0b6d 100644
--- a/src/packet/packet.js
+++ b/src/packet/packet.js
@@ -223,7 +223,8 @@ export default {
                 // eslint-disable-next-line no-loop-func
                 async start(_controller) {
                   controller = _controller;
-                }
+                },
+                cancel: stream.cancel.bind(input)
               });
               callback({ tag, packet });
             }
diff --git a/src/packet/packetlist.js b/src/packet/packetlist.js
index 3c27ab4a..232e8780 100644
--- a/src/packet/packetlist.js
+++ b/src/packet/packetlist.js
@@ -62,7 +62,8 @@ List.prototype.read = async function (bytes) {
       } catch(e) {
         controller.error(e);
       }
-    }
+    },
+    cancel: stream.cancel.bind(bytes)
   });
 
   // Wait until first few packets have been read
diff --git a/src/packet/sym_encrypted_aead_protected.js b/src/packet/sym_encrypted_aead_protected.js
index d5a94364..e9fdf5ee 100644
--- a/src/packet/sym_encrypted_aead_protected.js
+++ b/src/packet/sym_encrypted_aead_protected.js
@@ -138,14 +138,13 @@ SymEncryptedAEADProtected.prototype.crypt = async function (fn, key, data) {
     const adataView = new DataView(adataBuffer);
     const chunkIndexArray = new Uint8Array(adataBuffer, 5, 8);
     adataArray.set([0xC0 | this.tag, this.version, this.cipherAlgo, this.aeadAlgo, this.chunkSizeByte], 0);
-    const reader = stream.getReader(data);
     let chunkIndex = 0;
     let latestPromise = Promise.resolve();
     let cryptedBytes = 0;
     let queuedBytes = 0;
     const iv = this.iv;
-    return new ReadableStream({
-      async pull(controller) {
+    return stream.from(data, {
+      async pull(controller, reader) {
         let chunk = await reader.readBytes(chunkSize + tagLengthIfDecrypting) || new Uint8Array();
         const finalChunk = chunk.subarray(chunk.length - tagLengthIfDecrypting);
         chunk = chunk.subarray(0, chunk.length - tagLengthIfDecrypting);
@@ -174,7 +173,7 @@ SymEncryptedAEADProtected.prototype.crypt = async function (fn, key, data) {
         }
         if (!done) {
           adataView.setInt32(5 + 4, ++chunkIndex); // Should be setInt64(5, ...)
-          await this.pull(controller);
+          await this.pull(controller, reader);
         } else {
           controller.close();
         }
diff --git a/src/stream.js b/src/stream.js
index 74f00559..bff34a02 100644
--- a/src/stream.js
+++ b/src/stream.js
@@ -1,15 +1,15 @@
 import util from './util';
 
-if (typeof ReadableStream === 'undefined') {
+// if (typeof ReadableStream === 'undefined') {
   Object.assign(typeof window !== 'undefined' ? window : global, require('web-streams-polyfill'));
-}
+// }
 
 const nodeStream = util.getNodeStream();
 
 function concat(arrays) {
   const readers = arrays.map(getReader);
   let current = 0;
-  return new ReadableStream({
+  return create({
     async pull(controller) {
       try {
         const { done, value } = await readers[current].read();
@@ -18,11 +18,15 @@ function concat(arrays) {
         } else if (++current === arrays.length) {
           controller.close();
         } else {
-          await this.pull(controller); // ??? Chrome bug?
+          await this.pull(controller);
         }
       } catch(e) {
         controller.error(e);
       }
+    },
+    cancel() {
+      readers.forEach(reader => reader.releaseLock());
+      return Promise.all(arrays.map(cancel));
     }
   });
 }
@@ -31,16 +35,55 @@ function getReader(input) {
   return new Reader(input);
 }
 
+function create(options, extraArg) {
+  const promises = new Map();
+  const wrap = fn => fn && (controller => {
+    const returnValue = fn.call(options, controller, extraArg);
+    promises.set(fn, returnValue);
+    return returnValue;
+  });
+  options.start = wrap(options.start);
+  options.pull = wrap(options.pull);
+  const _cancel = options.cancel;
+  options.cancel = async controller => {
+    try {
+      console.log('cancel wrapper', options);
+      await promises.get(options.start);
+      console.log('awaited start');
+      await promises.get(options.pull);
+      console.log('awaited pull');
+    } finally {
+      if (_cancel) return _cancel.call(options, controller, extraArg);
+    }
+  };
+  options.options = options;
+  return new ReadableStream(options);
+}
+
+function from(input, options) {
+  const reader = getReader(input);
+  if (!options.cancel) {
+    options.cancel = (controller, reader) => {
+      console.log('from() cancel', stream, input);
+      reader.releaseLock();
+      return cancel(input);
+    };
+  }
+  options.from = input;
+  const stream = create(options, reader);
+  stream.from = input;
+  return stream;
+}
+
 function transform(input, process = () => undefined, finish = () => undefined) {
   if (util.isStream(input)) {
-    const reader = getReader(input);
-    return new ReadableStream({
-      async pull(controller) {
+    return from(input, {
+      async pull(controller, reader) {
         try {
           const { done, value } = await reader.read();
           const result = await (!done ? process : finish)(value);
           if (result !== undefined) controller.enqueue(result);
-          else if (!done) await this.pull(controller); // ??? Chrome bug?
+          else if (!done) await this.pull(controller, reader);
           if (done) controller.close();
         } catch(e) {
           controller.error(e);
@@ -68,7 +111,9 @@ function clone(input) {
     const teed = tee(input);
     // Overwrite input.getReader, input.locked, etc to point to teed[0]
     Object.entries(Object.getOwnPropertyDescriptors(ReadableStream.prototype)).forEach(([name, descriptor]) => {
-      if (name === 'constructor') return;
+      if (name === 'constructor') {
+        return;
+      }
       if (descriptor.value) {
         descriptor.value = descriptor.value.bind(teed[0]);
       } else {
@@ -84,17 +129,16 @@ function clone(input) {
 function slice(input, begin=0, end=Infinity) {
   if (util.isStream(input)) {
     if (begin >= 0 && end >= 0) {
-      const reader = getReader(input);
       let bytesRead = 0;
-      return new ReadableStream({
-        async pull (controller) {
+      return from(input, {
+        async pull (controller, reader) {
           const { done, value } = await reader.read();
           if (!done && bytesRead < end) {
             if (bytesRead + value.length >= begin) {
               controller.enqueue(slice(value, Math.max(begin - bytesRead, 0), end - bytesRead));
             }
             bytesRead += value.length;
-            await this.pull(controller); // Only necessary if the above call to enqueue() didn't happen
+            await this.pull(controller, reader); // Only necessary if the above call to enqueue() didn't happen
           } else {
             controller.close();
           }
@@ -229,10 +273,10 @@ if (nodeStream) {
 }
 
 
-export default { concat, getReader, transform, clone, slice, readToEnd, cancel, nodeToWeb, webToNode, fromAsync };
+export default { concat, getReader, from, transform, clone, slice, readToEnd, cancel, nodeToWeb, webToNode, fromAsync };
 
 
-/*const readerAcquiredMap = new Map();
+const readerAcquiredMap = new Map();
 
 const _getReader = ReadableStream.prototype.getReader;
 ReadableStream.prototype.getReader = function() {
@@ -245,7 +289,9 @@ ReadableStream.prototype.getReader = function() {
   const reader = _getReader.apply(this, arguments);
   const _releaseLock = reader.releaseLock;
   reader.releaseLock = function() {
-    readerAcquiredMap.delete(_this);
+    try {
+      readerAcquiredMap.delete(_this);
+    } catch(e) {}
     return _releaseLock.apply(this, arguments);
   };
   return reader;
@@ -259,7 +305,20 @@ ReadableStream.prototype.tee = function() {
     readerAcquiredMap.set(this, new Error('Reader for this ReadableStream already acquired here.'));
   }
   return _tee.apply(this, arguments);
-};*/
+};
+
+const _cancel = ReadableStream.prototype.cancel;
+ReadableStream.prototype.cancel = function() {
+  try {
+    return _cancel.apply(this, arguments);
+  } finally {
+    if (readerAcquiredMap.has(this)) {
+      console.error(readerAcquiredMap.get(this));
+    } else {
+      readerAcquiredMap.set(this, new Error('Reader for this ReadableStream already acquired here.'));
+    }
+  }
+};
 
 
 const doneReadingSet = new WeakSet();
@@ -284,7 +343,9 @@ function Reader(input) {
   };
   this._releaseLock = () => {
     if (doneReading) {
-      doneReadingSet.add(input);
+      try {
+        doneReadingSet.add(input);
+      } catch(e) {}
     }
   };
 }
@@ -298,7 +359,9 @@ Reader.prototype.read = async function() {
 };
 
 Reader.prototype.releaseLock = function() {
-  this.stream.externalBuffer = this.externalBuffer;
+  if (this.externalBuffer) {
+    this.stream.externalBuffer = this.externalBuffer;
+  }
   this._releaseLock();
 };
 
@@ -365,19 +428,21 @@ Reader.prototype.unshift = function(...values) {
 };
 
 Reader.prototype.substream = function() {
-  return new ReadableStream({ pull: pullFrom(this) });
-};
-
-function pullFrom(reader) {
-  return async controller => {
-    const { done, value } = await reader.read();
-    if (!done) {
-      controller.enqueue(value);
-    } else {
-      controller.close();
+  return Object.assign(create({
+    pull: async controller => {
+      const { done, value } = await this.read();
+      if (!done) {
+        controller.enqueue(value);
+      } else {
+        controller.close();
+      }
+    },
+    cancel: () => {
+      this.releaseLock();
+      return cancel(this.stream);
     }
-  };
-}
+  }), { from: this.stream });
+};
 
 Reader.prototype.readToEnd = async function(join=util.concat) {
   const result = [];
diff --git a/test/general/streaming.js b/test/general/streaming.js
index 4a99633c..866f4cd4 100644
--- a/test/general/streaming.js
+++ b/test/general/streaming.js
@@ -127,6 +127,38 @@ describe('Streaming', function() {
     expect(decrypted.data).to.deep.equal(util.concatUint8Array(plaintext));
   });
 
+  it('Input stream should be canceled when canceling encrypted stream', async function() {
+    let plaintext = [];
+    let i = 0;
+    let canceled = false;
+    const data = new ReadableStream({
+      async pull(controller) {
+        if (i++ < 10) {
+          let randomBytes = await openpgp.crypto.random.getRandomBytes(1024);
+          controller.enqueue(randomBytes);
+          plaintext.push(randomBytes);
+        } else {
+          controller.close();
+        }
+      },
+      cancel() {
+        canceled = true;
+      }
+    });
+    const encrypted = await openpgp.encrypt({
+      data,
+      passwords: ['test'],
+    });
+    const reader = openpgp.stream.getReader(encrypted.data);
+    console.log('read start');
+    expect(await reader.readBytes(1024)).to.match(/^-----BEGIN PGP MESSAGE-----\r\nVersion: OpenPGP.js VERSION\r\nComment: https:\/\/openpgpjs.org\r\n\r\n/);
+    console.log('read end');
+    if (i > 10) throw new Error('Data did not arrive early.');
+    reader.releaseLock();
+    await openpgp.stream.cancel(encrypted.data);
+    expect(canceled).to.be.true;
+  });
+
   it('Encrypt and decrypt larger message roundtrip', async function() {
     let plaintext = [];
     let i = 0;
@@ -378,4 +410,53 @@ describe('Streaming', function() {
       openpgp.config.aead_chunk_size_byte = aead_chunk_size_byteValue;
     }
   });
+
+  it('Input stream should be canceled when canceling decrypted stream (draft04)', async function() {
+    let aead_protectValue = openpgp.config.aead_protect;
+    let aead_chunk_size_byteValue = openpgp.config.aead_chunk_size_byte;
+    openpgp.config.aead_protect = true;
+    openpgp.config.aead_chunk_size_byte = 4;
+    try {
+      let plaintext = [];
+      let i = 0;
+      let canceled = false;
+      const data = new ReadableStream({
+        async pull(controller) {
+          await new Promise(setTimeout);
+          if (i++ < 10) {
+            let randomBytes = await openpgp.crypto.random.getRandomBytes(1024);
+            controller.enqueue(randomBytes);
+            plaintext.push(randomBytes);
+          } else {
+            controller.close();
+          }
+        },
+        cancel() {
+          canceled = true;
+        }
+      });
+      const encrypted = await openpgp.encrypt({
+        data,
+        passwords: ['test'],
+      });
+
+      const msgAsciiArmored = encrypted.data;
+      const message = await openpgp.message.readArmored(msgAsciiArmored);
+      const decrypted = await openpgp.decrypt({
+        passwords: ['test'],
+        message,
+        format: 'binary'
+      });
+      expect(util.isStream(decrypted.data)).to.be.true;
+      const reader = openpgp.stream.getReader(openpgp.stream.clone(decrypted.data));
+      expect(await reader.readBytes(1024)).to.deep.equal(plaintext[0]);
+      if (i > 10) throw new Error('Data did not arrive early.');
+      reader.releaseLock();
+      await openpgp.stream.cancel(decrypted.data);
+      expect(canceled).to.be.true;
+    } finally {
+      openpgp.config.aead_protect = aead_protectValue;
+      openpgp.config.aead_chunk_size_byte = aead_chunk_size_byteValue;
+    }
+  });
 });