diff --git a/rx.lua b/rx.lua index bef15ac..fcb5cf2 100644 --- a/rx.lua +++ b/rx.lua @@ -963,24 +963,29 @@ end function Observable:flatten() return Observable.create(function(observer) local subscriptions = {} + local remaining = 1 local function onError(message) return observer:onError(message) end + local function onCompleted() + remaining = remaining - 1 + if remaining == 0 then + return observer:onCompleted() + end + end + local function onNext(observable) local function innerOnNext(...) observer:onNext(...) end - local subscription = observable:subscribe(innerOnNext, onError, util.noop) + remaining = remaining + 1 + local subscription = observable:subscribe(innerOnNext, onError, onCompleted) subscriptions[#subscriptions + 1] = subscription end - local function onCompleted() - return observer:onCompleted() - end - subscriptions[#subscriptions + 1] = self:subscribe(onNext, onError, onCompleted) return Subscription.create(function () for i = 1, #subscriptions do diff --git a/src/operators/flatten.lua b/src/operators/flatten.lua index 381ce29..e5e898f 100644 --- a/src/operators/flatten.lua +++ b/src/operators/flatten.lua @@ -7,24 +7,29 @@ local util = require 'util' function Observable:flatten() return Observable.create(function(observer) local subscriptions = {} + local remaining = 1 local function onError(message) return observer:onError(message) end + local function onCompleted() + remaining = remaining - 1 + if remaining == 0 then + return observer:onCompleted() + end + end + local function onNext(observable) local function innerOnNext(...) observer:onNext(...) end - local subscription = observable:subscribe(innerOnNext, onError, util.noop) + remaining = remaining + 1 + local subscription = observable:subscribe(innerOnNext, onError, onCompleted) subscriptions[#subscriptions + 1] = subscription end - local function onCompleted() - return observer:onCompleted() - end - subscriptions[#subscriptions + 1] = self:subscribe(onNext, onError, onCompleted) return Subscription.create(function () for i = 1, #subscriptions do diff --git a/tests/flatMap.lua b/tests/flatMap.lua index 8b6749a..e339ac3 100644 --- a/tests/flatMap.lua +++ b/tests/flatMap.lua @@ -19,4 +19,17 @@ describe('flatMap', function() expect(observable).to.produce(1, 2, 3, 2, 3, 3) end) + + it('completes after all observables produced by its parent', function() + s = Rx.CooperativeScheduler.create() + local observable = Rx.Observable.fromRange(3):flatMap(function(i) + return Rx.Observable.fromRange(i, 3):delay(i, s) + end) + + local onNext, onError, onCompleted, order = observableSpy(observable) + repeat s:update(1) + until s:isEmpty() + expect(#onNext).to.equal(6) + expect(#onCompleted).to.equal(1) + end) end) diff --git a/tests/flatten.lua b/tests/flatten.lua index 9993490..69b044d 100644 --- a/tests/flatten.lua +++ b/tests/flatten.lua @@ -13,6 +13,19 @@ describe('flatten', function() expect(observable).to.produce(1, 2, 3, 2, 3, 3) end) + it('completes after all observables produced by its parent', function() + s = Rx.CooperativeScheduler.create() + local observable = Rx.Observable.fromRange(3):map(function(i) + return Rx.Observable.fromRange(i, 3):delay(i, s) + end):flatten() + + local onNext, onError, onCompleted, order = observableSpy(observable) + repeat s:update(1) + until s:isEmpty() + expect(#onNext).to.equal(6) + expect(#onCompleted).to.equal(1) + end) + it('should unsubscribe from all source observables', function() local unsubscribeA = spy() local observableA = Rx.Observable.create(function(observer)