Add model manager that automatically manages models across processes #37113
base: master
Conversation
Summary of Changes
Hello @AMOOOMA, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a sophisticated model management system for Apache Beam's ML inference capabilities. The core "ModelManager" class, supported by "GPUMonitor" and "ResourceEstimator", intelligently handles the lifecycle of machine learning models, particularly in GPU-accelerated environments. It aims to prevent out-of-memory errors by dynamically estimating model memory requirements, isolating unknown models for profiling, and implementing a demand-aware eviction strategy. This system ensures efficient and concurrent execution of diverse ML models within Beam pipelines, optimizing GPU resource utilization and improving overall stability.
R: @damccorm
Stopping reviewer notifications for this pull request: review requested by someone other than the bot, ceding control. If you'd like to restart, comment
| logger.info("Initial Profile for %s: %s MB", model_tag, cost) | ||
|
|
||
| def add_observation( | ||
| self, active_snapshot: Dict[str, int], peak_memory: float): |
I'm having a little bit of a hard time following this. Is active_snapshot a map of model tags to the number of models loaded for that tag?
Yep! That's correct.
Could we update the log below to make that clear? Ideally it would format the dict in a human-readable way.
Yes for sure, updated the print string.
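For readers following along, here is a minimal sketch of what such a human-readable log could look like, assuming the snapshot maps model tags to loaded-instance counts; the helper below is illustrative, not the PR's actual code:

```python
import logging
from typing import Dict

logger = logging.getLogger(__name__)


def log_observation(active_snapshot: Dict[str, int], peak_memory: float) -> None:
  # Hypothetical helper: spell out the tag -> instance-count mapping so the
  # log line makes the snapshot's meaning obvious.
  snapshot_str = ", ".join(
      "%s: %d loaded" % (tag, count)
      for tag, count in sorted(active_snapshot.items()))
  logger.info(
      "Observed peak memory %.1f MB with active models {%s}",
      peak_memory, snapshot_str)
```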
damccorm
left a comment
I've been looking at it for a while today and I am still having a hard time understanding the full scope of this PR (even with this being the second long look), and it will probably take a few more passes.
For now, added some things that will help me review better, but if there are pieces we can separate out further to make this more reviewable (either by splitting large functions apart or by pulling classes out of the PR), that would be quite helpful.
return evicted_something

def _perform_eviction(self, key, tag, instance, score):
Can we add types to the args here and elsewhere?
Done.
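As an illustration of the request, a typed signature might look like the sketch below; the parameter types are assumptions inferred from the surrounding discussion, not the PR's final annotations:

```python
from typing import Any, Tuple


class _EvictionSketch:
  def _perform_eviction(
      self,
      key: Tuple[str, int],  # assumed: identity of the idle-pool entry
      tag: str,  # model tag being evicted
      instance: Any,  # the TrackedModelProxy wrapper being deleted
      score: float,  # eviction score, kept for logging/diagnostics
  ) -> None:
    ...
```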
del self._models[tag][i]
break

if hasattr(instance, "trackedModelProxy_unsafe_hard_delete"):
Does this ever not exist?
I was originally aiming for this model manager to be versatile enough to store other objects too, but I guess at this point that ship has sailed. We won't need this check anymore. Updated.
with self._load_lock:
  logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown)
  isolation_baseline_snap, _, _ = self._monitor.get_stats()
  instance = TrackedModelProxy(loader_func())
What is loader_func here - is this spawning the model in process?
Yep! It depends on what the user passes in; in the RunInference case it would spawn a model in process with the new MultiProcessShared util.
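For context, a hedged sketch of what a loader_func built on MultiProcessShared could look like; make_loader, model_handler, and the tag argument are illustrative names, not the PR's API:

```python
from apache_beam.utils import multi_process_shared


def make_loader(model_handler, tag):
  # Hypothetical loader_func: the actual loading happens inside a
  # MultiProcessShared wrapper so a single copy of the model is shared
  # across worker processes.
  def loader_func():
    shared = multi_process_shared.MultiProcessShared(
        model_handler.load_model, tag=tag)
    return shared.acquire()
  return loader_func
```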
ticket_num = next(self._ticket_counter)
my_id = object()

with self._cv:
Is thread contention going to be an issue here/do we need this? If we are using this model_manager in a multi_process_shared setting, the calls to this (and other functions) should be serialized.
I'm a bit worried that this serialization is going to cause broader perf issues
+1, I was also worried, but as a first step it's kept this way since model inference can take some time. I originally tried splitting the read and load locks so that loading models wouldn't block normal operations that grab instances from the idle pool, but there's a race condition if we release a model and load a model at the same time: the snapshot taken can be broken in some cases. We can probably figure out a smarter way to unblock this perf issue, though; I'll add a comment and defer it for now.
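A minimal sketch of the single condition-variable pattern being discussed, assuming one lock guards both the idle pool and the resource accounting; class and attribute names are illustrative:

```python
import threading


class _SingleLockSketch:
  # One condition variable guards both the idle pool and the resource
  # accounting, so loads and releases cannot interleave and corrupt the
  # memory snapshot (the trade-off is that everything serializes on it).
  def __init__(self):
    self._cv = threading.Condition()
    self._idle_pool = []

  def release_model(self, tag, instance):
    with self._cv:
      self._idle_pool.append((tag, instance))
      # Wake any waiter that is re-checking whether it can reuse or spawn.
      self._cv.notify_all()
```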
with self._cv:
  # FAST PATH
  if self._pending_isolation_count == 0 and not self._isolation_mode:
Is self._pending_isolation_count ever non-zero?
Good catch! This is leftover code from earlier iterations. Removed now.
self._total_active_jobs = 0
self._pending_reservations = 0.0

self._isolation_mode = False
Could you add comments describing what these are?
Done.
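As an illustration of the requested comments, the sketch below documents plausible meanings for these attributes; the semantics are inferred from the names and may differ from the PR's actual comments:

```python
class _ManagerStateSketch:
  """Sketch of the documented state; semantics are assumptions, not the PR's text."""

  def __init__(self):
    # Number of models currently checked out by callers and not yet released.
    self._total_active_jobs = 0
    # Memory (MB) already promised to loads that were admitted but have not
    # finished loading yet.
    self._pending_reservations = 0.0
    # True while an unknown model is being profiled alone on the GPU.
    self._isolation_mode = False
```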
It handles:
  1. LRU Caching of idle models.
  2. Resource estimation and admission control (preventing OOM).
  3. Dynamic eviction of low-priority models when space is needed.
Could you add more info on what makes a model low-priority here?
Done.
def all_models(self, tag) -> list[Any]:
  return self._models[tag]

def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any:
This function is pretty large and I find it hard to follow; could you try breaking it up into subfunctions?
Yes! Sorry about this giant change; I forgot it can be really hard to read on a first pass. Iterating on it just grew the code larger and larger, and I was under the false impression that it was still roughly the size of the first draft.
I added more comments and split this function into smaller pieces; hopefully that helps. Let me know if there's anything I can do to help! Thanks!
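For orientation, here is a hypothetical decomposition of acquire_model along the lines discussed; the helper names and ordering are assumptions, not the PR's actual structure:

```python
class _AcquireSketch:
  """Illustrative decomposition of acquire_model; helper names are made up."""

  def acquire_model(self, tag, loader_func):
    instance = self._try_reuse_idle(tag)  # fast path: hit in the idle LRU pool
    if instance is not None:
      return instance
    self._wait_for_turn(tag)  # ticket queue / admission control
    if self._is_unknown(tag):
      return self._load_in_isolation(tag, loader_func)  # profile alone first
    return self._load_with_reservation(tag, loader_func)  # reserve estimate, then load

  # Stubs standing in for the real helpers in the PR.
  def _try_reuse_idle(self, tag): ...
  def _wait_for_turn(self, tag): ...
  def _is_unknown(self, tag): ...
  def _load_in_isolation(self, tag, loader_func): ...
  def _load_with_reservation(self, tag, loader_func): ...
```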
…code and cleanup some code logics
Codecov Report
❌ Patch coverage is
Additional details and impacted files

@@              Coverage Diff              @@
##              master    #37113     +/-   ##
=============================================
- Coverage      40.38%    39.97%    -0.42%
  Complexity      3476      3476
=============================================
  Files           1226      1222        -4
  Lines         188339    187813      -526
  Branches        3607      3607
=============================================
- Hits           76058     75074      -984
- Misses        108883    109341      +458
  Partials        3398      3398
Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
damccorm
left a comment
Didn't get to the tests yet, but overall things generally look good. Left some comments below and I'll ask gemini to review as well
return False

logger.info("Unknown model %s detected. Flushing GPU.", tag)
self._delete_all_models()
Do we need to delete all models, or would it be good enough to just not feed them records for inference for a bit?
Yes, because otherwise we might not be able to fit a copy in. Feel free to ignore this comment
for key, (tag, instance, release_time) in self._idle_lru.items():
  candidate_demand = demand_map[tag]

  if not am_i_starving and candidate_demand >= my_demand:
Should we weight candidate_demand more heavily than my_demand since we already have a copy in memory? (aka try to be more conservative)
Yes, this is a TODO, feel free to ignore this comment (from live review)
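A sketch of the more conservative weighting being suggested; the bias factor and function name are assumptions, not values from the PR:

```python
EVICTION_BIAS = 1.5  # hypothetical factor favoring already-loaded copies


def should_skip_eviction(
    candidate_demand: int, my_demand: int, am_i_starving: bool) -> bool:
  # Keep the already-loaded copy unless the requester's demand clearly
  # exceeds the candidate's, since evicting and reloading is expensive.
  return not am_i_starving and candidate_demand * EVICTION_BIAS >= my_demand
```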
    curr,
    total_model_count,
    other_idle_count)
self._cv.wait(timeout=10.0)
I think the control flow reads more cleanly if we keep this self._cv.wait in acquire_model. As is, it is confusing why we're waiting in should_spawn_model. It is also unexpected for a caller that this might induce a hang, and unclear that they need the cv lock. I'd recommend logging the resource usage, and then logging the wait separately.
It took me a while to figure out why we're waiting (I think basically because resource usage won't change until someone calls notify_all)
Is the timeout here basically in case GPU memory does change without notify being called somehow?
Broadly - could you add comments to functions which expect the caller to hold self._cv?
Yes, the timeout is just a safety net in case we somehow miss the notify call, so it won't be stuck waiting forever and will eventually re-check usage. SG! Added some comments.
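A minimal sketch of the wait-with-timeout pattern described here, where the timeout only guards against a missed notify_all; function and parameter names are illustrative:

```python
import threading
from typing import Callable


def wait_for_capacity(
    cv: threading.Condition,
    has_capacity: Callable[[], bool],
    timeout_s: float = 10.0) -> None:
  # The caller re-checks the predicate in a loop; the timeout is only a
  # safety net in case a notify_all is missed (e.g. GPU memory freed
  # outside the manager's control).
  with cv:
    while not has_capacity():
      cv.wait(timeout=timeout_s)
```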
  est_cost = self._estimator.get_estimate(tag)
  break
else:
  # We waited, need to re-evaluate our turn
When will acquire_model get called? Is it possible to get stuck trying to acquire a model forever, and should we have some concept of max tries before we just fail and throw?
It's possible for it to retry forever, I agree; some sort of maximum should be enforced, and we need to let the user know something is broken. I will add a timeout or max retry count.
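One possible shape for the bounded-retry behavior mentioned above; the cap, sleep interval, and helper names are assumptions:

```python
import time

MAX_ACQUIRE_RETRIES = 100  # hypothetical cap; the final value may differ


def acquire_with_retry_cap(try_acquire, tag: str):
  # Fail loudly instead of spinning forever if the model can never be admitted.
  for _ in range(MAX_ACQUIRE_RETRIES):
    instance = try_acquire(tag)
    if instance is not None:
      return instance
    time.sleep(0.1)
  raise RuntimeError(
      "Unable to acquire model %r after %d attempts; the GPU may not have "
      "enough free memory for it." % (tag, MAX_ACQUIRE_RETRIES))
```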
def all_models(self, tag) -> list[Any]:
  return self._models[tag]

def enter_isolation_mode(self, tag: str, ticket_num: int) -> bool:
Nit: try_enter_isolation_mode might be clearer since it's not guaranteed to happen.
Done.
heapq.heapify(self._wait_queue)
self._cv.notify_all()

if should_spawn:
Can this ever be false at this point? If not, do we need this variable?
You are right; this used to be inside the loop, hence the need for the check, but the current implementation doesn't need it. Fixed now.
# Remove self from wait queue once done
if self._wait_queue and self._wait_queue[0][2] is my_id:
  heapq.heappop(self._wait_queue)
else:
When will this condition be false?
If this is just defensive coding, that's fine, but please note that (and probably warn)
Added some comments and a warning. This is mostly in case the priority is somehow updated and the queue is reshuffled.
for item in self._wait_queue:
  demand_map[item[3]] += 1

my_demand = demand_map[requesting_tag]
Should we factor in the number of models loaded when considering demand? If I'm reading this right, if I have 5 copies of model A and 3 requests in the queue, vs 3 copies of model B with 2 requests in the queue, we're more likely to evict model B. That doesn't seem right to me
Yes, that would be a good addition. I will add a TODO comment here, and we can probably revamp this demand balancing altogether in another PR.
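A sketch of how loaded copies could be factored into demand, as suggested; the formula is an assumption, not the PR's implementation:

```python
from collections import Counter
from typing import Dict


def effective_demand(
    queued_requests: Counter, loaded_copies: Dict[str, int], tag: str) -> float:
  # Normalize queued demand by the number of copies already loaded, so a tag
  # with many loaded copies looks less "starving" than its raw queue count.
  copies = max(loaded_copies.get(tag, 0), 1)
  return queued_requests[tag] / copies
```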
multi_process_shared.MultiProcessShared(
    lambda: "N/A", tag=instance).unsafe_hard_delete()
if hasattr(instance, 'mock_model_unsafe_hard_delete'):
  # Call the mock unsafe hard delete method for testing
Do we need this mock or can we just create the multiprocess shared objects and pass in tags in the test?
I tried creating the MultiProcessShared objects earlier, but mocking the resource usage is tricky: we need to manually free memory usage on the fake monitor, and without calling unsafe_hard_delete the only other option is editing __del__, which didn't seem reliable when I tested it.
self._monitor.refresh()
self._monitor.reset_peak()

def _force_reset(self):
Should this be private? How are we planning on this being used?
Yeah I mostly made this for tests, so normal clients should not try to call this.
|
/gemini review
Code Review
This pull request introduces a sophisticated ModelManager for handling ML models, with a strong focus on GPU memory management. The implementation is well-structured, separating concerns into GPUMonitor, ResourceEstimator, and the ModelManager itself. The logic for concurrency, resource contention, and model eviction is complex but appears robust. The accompanying tests are thorough and cover a wide range of scenarios. My review includes several suggestions to enhance security, code clarity, and efficiency in specific areas.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Added ModelManager as a util class that offers managed access to models; the client can request models without having to worry about managing GPU OOMs.
Also added various tests that check the functionality of all classes.
Classes
GPUMonitor
- start(): Begins background memory polling.
- stop(): Stops polling.
- reset_peak(): Resets peak usage tracking.
- get_stats() -> (current, peak, total): Returns memory stats.

ResourceEstimator
- is_unknown(model_tag: str) -> bool: Checks if the model needs profiling.
- get_estimate(model_tag: str, default_mb: float) -> float: Returns memory cost.
- set_initial_estimate(model_tag: str, cost: float): Manually sets cost.
- add_observation(active_snapshot, peak_memory): Updates the cost model via an NNLS solver.

ModelManager
- acquire_model(tag: str, loader_func: Callable) -> Any: Gets a model instance (handles isolation/concurrency).
- release_model(tag: str, instance: Any): Returns a model to the pool.
- force_reset(): Clears all models and caches.
- shutdown(): Cleans up resources.
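A hypothetical usage sketch based on the API summary above; the module path, constructor arguments, and the predict call are assumptions, not the PR's documented usage:

```python
from apache_beam.ml.inference import model_manager  # assumed module location


def run_batch(batch, load_my_model):
  manager = model_manager.ModelManager()  # constructor arguments assumed default
  model = manager.acquire_model(tag="my_model", loader_func=load_my_model)
  try:
    # Whatever inference call the loaded model exposes; shown for illustration.
    return model.predict(batch)
  finally:
    # Return the instance so it lands back in the idle pool, where it can be
    # reused or evicted to make room for other tags.
    manager.release_model("my_model", model)
```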