diff --git a/src/encoding/armor.js b/src/encoding/armor.js index 2d22739a..3ebb1d3e 100644 --- a/src/encoding/armor.js +++ b/src/encoding/armor.js @@ -267,8 +267,8 @@ function dearmor(input) { } })); data = stream.transformPair(data, async (readable, writable) => { - const checksumVerified = getCheckSum(stream.clone(readable)); - stream.pipe(readable, writable, { + const checksumVerified = getCheckSum(stream.passiveClone(readable)); + await stream.pipe(readable, writable, { preventClose: true }); const checksumVerifiedString = await stream.readToEnd(checksumVerified); @@ -306,7 +306,7 @@ function armor(messagetype, body, partindex, parttotal, customComment) { hash = body.hash; body = body.data; } - const bodyClone = stream.clone(body); + const bodyClone = stream.passiveClone(body); const result = []; switch (messagetype) { case enums.armor.multipart_section: diff --git a/src/packet/sym_encrypted_integrity_protected.js b/src/packet/sym_encrypted_integrity_protected.js index e03f01ef..001dd0e2 100644 --- a/src/packet/sym_encrypted_integrity_protected.js +++ b/src/packet/sym_encrypted_integrity_protected.js @@ -98,7 +98,7 @@ SymEncryptedIntegrityProtected.prototype.encrypt = async function (sessionKeyAlg const mdc = new Uint8Array([0xD3, 0x14]); // modification detection code packet let tohash = util.concat([bytes, mdc]); - const hash = crypto.hash.sha1(util.concat([prefix, stream.clone(tohash)])); + const hash = crypto.hash.sha1(util.concat([prefix, stream.passiveClone(tohash)])); tohash = util.concat([tohash, hash]); if (sessionKeyAlgorithm.substr(0, 3) === 'aes') { // AES optimizations. Native code for node, asmCrypto for browser. @@ -120,7 +120,7 @@ SymEncryptedIntegrityProtected.prototype.encrypt = async function (sessionKeyAlg */ SymEncryptedIntegrityProtected.prototype.decrypt = async function (sessionKeyAlgorithm, key) { const encrypted = stream.clone(this.encrypted); - const encryptedClone = stream.clone(encrypted); + const encryptedClone = stream.passiveClone(encrypted); let decrypted; if (sessionKeyAlgorithm.substr(0, 3) === 'aes') { // AES optimizations. Native code for node, asmCrypto for browser. decrypted = aesDecrypt(sessionKeyAlgorithm, encrypted, key); @@ -132,21 +132,22 @@ SymEncryptedIntegrityProtected.prototype.decrypt = async function (sessionKeyAlg // last packet and everything gets hashed except the hash itself const encryptedPrefix = await stream.readToEnd(stream.slice(encryptedClone, 0, crypto.cipher[sessionKeyAlgorithm].blockSize + 2)); const prefix = crypto.cfb.mdc(sessionKeyAlgorithm, key, encryptedPrefix); - const bytes = stream.slice(stream.clone(decrypted), 0, -20); - const tohash = util.concat([prefix, stream.clone(bytes)]); + const realHash = stream.slice(stream.passiveClone(decrypted), -20); + const bytes = stream.slice(decrypted, 0, -20); + const tohash = util.concat([prefix, stream.passiveClone(bytes)]); const verifyHash = Promise.all([ stream.readToEnd(crypto.hash.sha1(tohash)), - stream.readToEnd(stream.slice(decrypted, -20)) + stream.readToEnd(realHash) ]).then(([hash, mdc]) => { if (!util.equalsUint8Array(hash, mdc)) { throw new Error('Modification detected.'); } + return new Uint8Array(); }); let packetbytes = stream.slice(bytes, 0, -2); + packetbytes = stream.concat([packetbytes, stream.fromAsync(() => verifyHash)]); if (!util.isStream(encrypted) || !config.unsafe_stream) { - await verifyHash; - } else { - packetbytes = stream.concat([packetbytes, stream.fromAsync(() => verifyHash)]); + packetbytes = await stream.readToEnd(packetbytes); } await this.packets.read(packetbytes); return true; diff --git a/src/stream.js b/src/stream.js index dadebc33..c003a6fb 100644 --- a/src/stream.js +++ b/src/stream.js @@ -10,7 +10,7 @@ function toStream(input) { if (util.isStream(input)) { return input; } - return create({ + return new ReadableStream({ start(controller) { controller.enqueue(input); controller.close(); @@ -20,22 +20,9 @@ function toStream(input) { function concat(arrays) { arrays = arrays.map(toStream); - let outputController; - const transform = { - readable: new ReadableStream({ - start(_controller) { - outputController = _controller; - }, - async cancel(reason) { - await Promise.all(transforms.map(array => cancel(array, reason))); - } - }), - writable: new WritableStream({ - write: outputController.enqueue.bind(outputController), - close: outputController.close.bind(outputController), - abort: outputController.error.bind(outputController) - }) - }; + const transform = transformWithCancel(async function(reason) { + await Promise.all(transforms.map(array => cancel(array, reason))); + }); let prev = Promise.resolve(); const transforms = arrays.map((array, i) => transformPair(array, (readable, writable) => { prev = prev.then(() => pipe(readable, transform.writable, { @@ -54,19 +41,6 @@ function getWriter(input) { return input.getWriter(); } -function create(options, extraArg) { - const promises = new Map(); - const wrap = fn => fn && (controller => { - const returnValue = fn.call(options, controller, extraArg); - promises.set(fn, returnValue); - return returnValue; - }); - options.options = Object.assign({}, options); - options.start = wrap(options.start); - options.pull = wrap(options.pull); - return new ReadableStream(options); -} - async function pipe(input, target, options) { if (!util.isStream(input)) { input = toStream(input); @@ -83,12 +57,39 @@ async function pipe(input, target, options) { } function transformRaw(input, options) { - options.cancel = cancel.bind(input); const transformStream = new TransformStream(options); pipe(input, transformStream.writable); return transformStream.readable; } +function transformWithCancel(cancel) { + let backpressureChangePromiseResolve = function() {}; + let outputController; + return { + readable: new ReadableStream({ + start(controller) { + outputController = controller; + }, + pull() { + backpressureChangePromiseResolve(); + }, + cancel + }), + writable: new WritableStream({ + write: async function(chunk) { + outputController.enqueue(chunk); + if (outputController.desiredSize <= 0) { + await new Promise(resolve => { + backpressureChangePromiseResolve = resolve; + }); + } + }, + close: outputController.close.bind(outputController), + abort: outputController.error.bind(outputController) + }) + }; +} + function transform(input, process = () => undefined, finish = () => undefined) { if (util.isStream(input)) { return transformRaw(input, { @@ -131,23 +132,10 @@ function transformPair(input, fn) { } }); - let outputController; - const outgoing = { - readable: new ReadableStream({ - start(_controller) { - outputController = _controller; - }, - async cancel() { - incomingTransformController.error(canceledErr); - await pipeDonePromise; - } - }), - writable: new WritableStream({ - write: outputController.enqueue.bind(outputController), - close: outputController.close.bind(outputController), - abort: outputController.error.bind(outputController) - }) - }; + const outgoing = transformWithCancel(async function() { + incomingTransformController.error(canceledErr); + await pipeDonePromise; + }); Promise.resolve(fn(incoming.readable, outgoing.writable)).catch(e => { if (e !== canceledErr) { throw e; @@ -182,23 +170,53 @@ function tee(input) { function clone(input) { if (util.isStream(input)) { const teed = tee(input); - // Overwrite input.getReader, input.locked, etc to point to teed[0] - Object.entries(Object.getOwnPropertyDescriptors(ReadableStream.prototype)).forEach(([name, descriptor]) => { - if (name === 'constructor') { - return; - } - if (descriptor.value) { - descriptor.value = descriptor.value.bind(teed[0]); - } else { - descriptor.get = descriptor.get.bind(teed[0]); - } - Object.defineProperty(input, name, descriptor); - }); + overwrite(input, teed[0]); return teed[1]; } return slice(input); } +function passiveClone(input) { + if (util.isStream(input)) { + return new ReadableStream({ + start(controller) { + const transformed = transformPair(input, async (readable, writable) => { + const reader = getReader(readable); + const writer = getWriter(writable); + while (true) { + await writer.ready; + const { done, value } = await reader.read(); + if (done) { + try { controller.close(); } catch(e) {} + await writer.close(); + return; + } + try { controller.enqueue(value); } catch(e) {} + await writer.write(value); + } + }); + overwrite(input, transformed); + } + }); + } + return slice(input); +} + +function overwrite(input, clone) { + // Overwrite input.getReader, input.locked, etc to point to clone + Object.entries(Object.getOwnPropertyDescriptors(ReadableStream.prototype)).forEach(([name, descriptor]) => { + if (name === 'constructor') { + return; + } + if (descriptor.value) { + descriptor.value = descriptor.value.bind(clone); + } else { + descriptor.get = descriptor.get.bind(clone); + } + Object.defineProperty(input, name, descriptor); + }); +} + function slice(input, begin=0, end=Infinity) { if (util.isStream(input)) { if (begin >= 0 && end >= 0) { @@ -344,7 +362,7 @@ if (nodeStream) { } -export default { toStream, concat, getReader, getWriter, pipe, transformRaw, transform, transformPair, parse, clone, slice, readToEnd, cancel, nodeToWeb, webToNode, fromAsync }; +export default { toStream, concat, getReader, getWriter, pipe, transformRaw, transform, transformPair, parse, clone, passiveClone, slice, readToEnd, cancel, nodeToWeb, webToNode, fromAsync }; const doneReadingSet = new WeakSet(); diff --git a/test/general/streaming.js b/test/general/streaming.js index 01674dcd..b9baf718 100644 --- a/test/general/streaming.js +++ b/test/general/streaming.js @@ -500,4 +500,73 @@ describe('Streaming', function() { openpgp.config.aead_chunk_size_byte = aead_chunk_size_byteValue; } }); + + it("Don't pull entire input stream when we're not pulling encrypted stream", async function() { + let plaintext = []; + let i = 0; + const data = new ReadableStream({ + async pull(controller) { + if (i++ < 100) { + let randomBytes = await openpgp.crypto.random.getRandomBytes(1024); + controller.enqueue(randomBytes); + plaintext.push(randomBytes); + } else { + controller.close(); + } + await new Promise(setTimeout); + } + }); + const encrypted = await openpgp.encrypt({ + data, + passwords: ['test'], + }); + const reader = openpgp.stream.getReader(encrypted.data); + expect(await reader.readBytes(1024)).to.match(/^-----BEGIN PGP MESSAGE-----\r\n/); + if (i > 10) throw new Error('Data did not arrive early.'); + await new Promise(resolve => setTimeout(resolve, 3000)); + expect(i).to.be.lessThan(50); + }); + + it("Don't pull entire input stream when we're not pulling decrypted stream (draft04)", async function() { + let aead_protectValue = openpgp.config.aead_protect; + let aead_chunk_size_byteValue = openpgp.config.aead_chunk_size_byte; + openpgp.config.aead_protect = true; + openpgp.config.aead_chunk_size_byte = 4; + try { + let plaintext = []; + let i = 0; + const data = new ReadableStream({ + async pull(controller) { + if (i++ < 100) { + let randomBytes = await openpgp.crypto.random.getRandomBytes(1024); + controller.enqueue(randomBytes); + plaintext.push(randomBytes); + } else { + controller.close(); + } + await new Promise(setTimeout); + } + }); + const encrypted = await openpgp.encrypt({ + data, + passwords: ['test'], + }); + const msgAsciiArmored = encrypted.data; + const message = await openpgp.message.readArmored(msgAsciiArmored); + const decrypted = await openpgp.decrypt({ + passwords: ['test'], + message, + format: 'binary' + }); + expect(util.isStream(decrypted.data)).to.be.true; + const reader = openpgp.stream.getReader(decrypted.data); + expect(await reader.readBytes(1024)).to.deep.equal(plaintext[0]); + if (i > 10) throw new Error('Data did not arrive early.'); + await new Promise(resolve => setTimeout(resolve, 3000)); + expect(i).to.be.lessThan(50); + } finally { + openpgp.config.aead_protect = aead_protectValue; + openpgp.config.aead_chunk_size_byte = aead_chunk_size_byteValue; + } + }); });