diff --git a/README.md b/README.md index 2bc6e2b..14bce7e 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,21 @@ Task.startPolling(); await Task.schedule('sayHello', new Date(Date.now() + 1000)); ``` +## Parallelism + +You can control worker-level concurrency with `startPolling({ parallel })`. To limit how many +instances of a specific task can run concurrently on a worker, pass `maxParallelTasks` when +registering the handler. + +```javascript +Task.registerHandler('sendEmail', async function sendEmail(params) { + // ... +}, { maxParallelTasks: 2 }); + +// Allow up to 4 tasks to run in parallel on this worker +Task.startPolling({ parallel: 4 }); +``` + ## Params The 2nd param to `Task.schedule()` is an object that this framework will call the handler function with. @@ -57,4 +72,4 @@ await Task.schedule( { name: 'Friend' }, 5000 ); -``` \ No newline at end of file +``` diff --git a/src/taskSchema.js b/src/taskSchema.js index 26beb89..c59c9ec 100644 --- a/src/taskSchema.js +++ b/src/taskSchema.js @@ -116,7 +116,15 @@ taskSchema.methods.sideEffect = async function sideEffect(fn, params) { taskSchema.statics.startPolling = function startPolling(options) { const interval = options?.interval ?? 1000; const workerName = options?.workerName; - const pollOptions = workerName ? { workerName } : null; + const parallel = options?.parallel; + const pollOptions = {}; + if (workerName) { + pollOptions.workerName = workerName; + } + if (parallel != null) { + pollOptions.parallel = parallel; + } + const resolvedPollOptions = Object.keys(pollOptions).length ? pollOptions : null; let cancelled = false; let timeout = null; if (!this._cancel) { @@ -124,6 +132,7 @@ taskSchema.statics.startPolling = function startPolling(options) { this._cancel = () => { cancelled = true; clearTimeout(timeout); + this._cancel = null; }; } return this._cancel; @@ -138,7 +147,7 @@ taskSchema.statics.startPolling = function startPolling(options) { // Expire tasks that have timed out (refactored to separate function) await Task.expireTimedOutTasks(); - this._currentPoll = this.poll(pollOptions); + this._currentPoll = this.poll(resolvedPollOptions); await this._currentPoll.then( () => { timeout = setTimeout(() => doPoll.call(this), interval); @@ -196,9 +205,9 @@ taskSchema.statics.expireTimedOutTasks = async function expireTimedOutTasks() { } }; -taskSchema.statics.registerHandler = async function registerHandler(name, fn) { +taskSchema.statics.registerHandler = async function registerHandler(name, fn, options = {}) { this._handlers = this._handlers || new Map(); - this._handlers.set(name, fn); + this._handlers.set(name, { handler: fn, maxParallelTasks: options?.maxParallelTasks }); return this; }; @@ -235,7 +244,7 @@ taskSchema.statics.registerHandlers = async function registerHandlers(obj, prefi for (const key of Object.keys(obj)) { const fullPath = prefix ? `${prefix}.${key}` : key; if (typeof obj[key] === 'function') { - this._handlers.set(fullPath, obj[key]); + this._handlers.set(fullPath, { handler: obj[key] }); } else if (typeof obj[key] === 'object' && obj[key] != null) { this.registerHandlers(obj[key], fullPath); } @@ -253,13 +262,38 @@ taskSchema.statics.poll = async function poll(opts) { const workerName = opts?.workerName; const additionalParams = workerName ? { workerName } : {}; + const handlerLimits = new Map(); + if (this._handlers) { + for (const [name, handlerEntry] of this._handlers.entries()) { + const maxParallelTasks = typeof handlerEntry === 'function' ? null : handlerEntry?.maxParallelTasks; + if (typeof maxParallelTasks === 'number') { + handlerLimits.set(name, maxParallelTasks); + } + } + } + + const runningByName = new Map(); while (true) { const tasksInProgress = []; for (let i = 0; i < parallel; ++i) { const now = time.now(); + const blockedNames = []; + for (const [name, maxParallelTasks] of handlerLimits.entries()) { + const currentCount = runningByName.get(name) || 0; + if (currentCount >= maxParallelTasks) { + blockedNames.push(name); + } + } + const filter = { + status: 'pending', + scheduledAt: { $lte: now } + }; + if (blockedNames.length) { + filter.name = { $nin: blockedNames }; + } const task = await this.findOneAndUpdate( - { status: 'pending', scheduledAt: { $lte: now } }, + filter, { status: 'in_progress', startedRunningAt: now, @@ -273,7 +307,17 @@ taskSchema.statics.poll = async function poll(opts) { break; } - tasksInProgress.push(this.execute(task)); + const maxParallelTasks = handlerLimits.get(task.name); + if (typeof maxParallelTasks === 'number') { + const currentCount = runningByName.get(task.name) || 0; + runningByName.set(task.name, currentCount + 1); + } + tasksInProgress.push(this.execute(task).finally(() => { + if (typeof maxParallelTasks === 'number') { + const currentCount = runningByName.get(task.name) || 0; + runningByName.set(task.name, Math.max(currentCount - 1, 0)); + } + })); } if (tasksInProgress.length === 0) { @@ -285,7 +329,13 @@ taskSchema.statics.poll = async function poll(opts) { }; taskSchema.statics.execute = async function(task) { - if (!this._handlers.has(task.name)) { + if (!this._handlers || !this._handlers.has(task.name)) { + return null; + } + + const handlerEntry = this._handlers.get(task.name); + const handlerFn = typeof handlerEntry === 'function' ? handlerEntry : handlerEntry?.handler; + if (typeof handlerFn !== 'function') { return null; } @@ -306,7 +356,7 @@ taskSchema.statics.execute = async function(task) { if (typeof task.timeoutMS === 'number') { result = await Promise.race([ Promise.resolve( - this._handlers.get(task.name).call(task, task.params, task) + handlerFn.call(task, task.params, task) ), new Promise((_, reject) => { setTimeout(() => reject(new Error(`Task timed out after ${task.timeoutMS} ms`)), task.timeoutMS); @@ -314,7 +364,7 @@ taskSchema.statics.execute = async function(task) { ]); } else { result = await Promise.resolve( - this._handlers.get(task.name).call(task, task.params, task) + handlerFn.call(task, task.params, task) ); } task.status = 'succeeded'; diff --git a/test/task.test.js b/test/task.test.js index a4f958e..a6cae1d 100644 --- a/test/task.test.js +++ b/test/task.test.js @@ -6,6 +6,16 @@ const mongoose = require('mongoose'); const sinon = require('sinon'); const time = require('../src/time'); +const waitFor = async (predicate, timeoutMS = 1000) => { + const started = Date.now(); + while (!predicate()) { + if (Date.now() - started > timeoutMS) { + throw new Error('Timed out waiting for condition'); + } + await new Promise(resolve => setImmediate(resolve)); + } +}; + describe('Task', function() { let cancel = null; const now = new Date('2023-06-01'); @@ -199,6 +209,52 @@ describe('Task', function() { cancel(); }); + it('passes parallel option to startPolling()', async function() { + const pollStub = sinon.stub(Task, 'poll').callsFake(async () => {}); + + cancel = Task.startPolling({ interval: 100, parallel: 2 }); + + await Task._currentPoll; + + assert.ok(pollStub.calledWithMatch({ parallel: 2 })); + + cancel(); + pollStub.restore(); + }); + + it('honors maxParallelTasks per handler', async function() { + const pendingResolves = []; + let concurrent = 0; + let maxConcurrent = 0; + + Task.registerHandler('limitedTask', async () => { + concurrent += 1; + maxConcurrent = Math.max(maxConcurrent, concurrent); + await new Promise(resolve => pendingResolves.push(resolve)); + concurrent -= 1; + return 'done'; + }, { maxParallelTasks: 1 }); + + await Task.schedule('limitedTask', time.now(), {}); + await Task.schedule('limitedTask', time.now(), {}); + + const pollPromise = Task.poll({ parallel: 2 }); + + await waitFor(() => pendingResolves.length === 1); + assert.strictEqual(maxConcurrent, 1); + + pendingResolves.shift()(); + + await waitFor(() => pendingResolves.length === 1); + assert.strictEqual(maxConcurrent, 1); + + pendingResolves.shift()(); + + await pollPromise; + + assert.strictEqual(maxConcurrent, 1); + }); + it('catches errors in task', async function() { let resolve; let reject;