|
1 | 1 | defmodule Tqdm do |
| 2 | + |
| 3 | + @num_bars 10 |
| 4 | + |
| 5 | + def tqdm(enumerable, opts \\ []) do |
| 6 | + now = :erlang.monotonic_time() |
| 7 | + |
| 8 | + state = %{ |
| 9 | + n: 0, |
| 10 | + last_print_n: 0, |
| 11 | + start_time: now, |
| 12 | + last_print_time: now, |
| 13 | + last_printed_length: 0, |
| 14 | + prefix: Keyword.get(opts, :description, "") |> prefix(), |
| 15 | + total: Keyword.get_lazy(opts, :total, fn -> Enum.count(enumerable) end), |
| 16 | + clear: Keyword.get(opts, :clear, true), |
| 17 | + device: Keyword.get(opts, :device, :stderr), |
| 18 | + min_interval: Keyword.get(opts, :min_interval, 100), |
| 19 | + min_iterations: Keyword.get(opts, :min_iterations, 1) |
| 20 | + } |
| 21 | + |
| 22 | + Stream.transform(enumerable, fn -> state end, &do_tqdm/2, &do_tqdm_after/1) |
| 23 | + end |
| 24 | + |
| 25 | + defp prefix(""), do: "" |
| 26 | + defp prefix(description), do: description <> ": " |
| 27 | + |
| 28 | + defp do_tqdm(element, %{n: 0} = state) do |
| 29 | + {[element], %{print_status(state, :erlang.monotonic_time()) | n: 1}} |
| 30 | + end |
| 31 | + |
| 32 | + defp do_tqdm(element, %{n: n, last_print_n: last_print_n, min_iterations: min_iterations} = state) |
| 33 | + when n - last_print_n < min_iterations, |
| 34 | + do: {[element], %{state | n: n + 1}} |
| 35 | + |
| 36 | + defp do_tqdm(element, %{n: n, last_print_time: last_print_time, min_interval: min_interval} = state) do |
| 37 | + now = :erlang.monotonic_time() |
| 38 | + |
| 39 | + if :erlang.convert_time_unit(now - last_print_time, :native, :milli_seconds) >= min_interval do |
| 40 | + state = %{print_status(state, now) | last_print_n: n, last_print_time: :erlang.monotonic_time()} |
| 41 | + end |
| 42 | + |
| 43 | + {[element], %{state | n: n + 1}} |
| 44 | + end |
| 45 | + |
| 46 | + defp do_tqdm_after(state) do |
| 47 | + state = print_status(state, :erlang.monotonic_time()) |
| 48 | + |
| 49 | + finish = |
| 50 | + if state.clear do |
| 51 | + "\r" <> String.duplicate(" ", String.length(state.prefix) + state.last_printed_length) <> "\r" |
| 52 | + else |
| 53 | + "\n" |
| 54 | + end |
| 55 | + |
| 56 | + IO.write(state.device, finish) |
| 57 | + end |
| 58 | + |
| 59 | + defp print_status(state, now) do |
| 60 | + status = format_status(state, now) |
| 61 | + status_length = String.length(status) |
| 62 | + |
| 63 | + padding = String.duplicate(" ", max(state.last_printed_length - status_length, 0)) |
| 64 | + |
| 65 | + IO.write(state.device, "\r#{state.prefix}#{status}#{padding}") |
| 66 | + |
| 67 | + %{state | last_printed_length: status_length} |
| 68 | + end |
| 69 | + |
| 70 | + defp format_status(%{n: n, total: total, start_time: start_time}, now) do |
| 71 | + elapsed = :erlang.convert_time_unit(now - start_time, :native, :micro_seconds) |
| 72 | + |
| 73 | + total = if n <= total, do: total |
| 74 | + |
| 75 | + elapsed_str = format_interval(elapsed, false) |
| 76 | + |
| 77 | + rate = if elapsed > 0, do: Float.round(n / (elapsed / 1_000_000), 2), else: "?" |
| 78 | + |
| 79 | + if total do |
| 80 | + progress = n / total |
| 81 | + |
| 82 | + num_bars = trunc(progress * @num_bars) |
| 83 | + bar = String.duplicate("#", num_bars) <> String.duplicate("-", @num_bars - num_bars) |
| 84 | + |
| 85 | + percentage = "#{Float.round(progress * 100)}%" |
| 86 | + |
| 87 | + left_str = if n > 0, do: format_interval(elapsed / n * (total - n), true), else: "?" |
| 88 | + |
| 89 | + "|#{bar}| #{n}/#{total} #{percentage} [elapsed: #{elapsed_str} left: #{left_str}, #{rate} iters/sec]" |
| 90 | + else |
| 91 | + "#{n} [elapsed: #{elapsed_str}, #{rate} iters/sec]" |
| 92 | + end |
| 93 | + end |
| 94 | + |
| 95 | + defp format_interval(elapsed, trunc_seconds) do |
| 96 | + minutes = trunc(elapsed / 60_000_000) |
| 97 | + hours = div(minutes, 60) |
| 98 | + rem_minutes = minutes - hours * 60 |
| 99 | + micro_seconds = elapsed - minutes * 60_000_000 |
| 100 | + seconds = micro_seconds / 1_000_000 |
| 101 | + |
| 102 | + if trunc_seconds do |
| 103 | + seconds = trunc(seconds) |
| 104 | + end |
| 105 | + |
| 106 | + hours_str = format_time_component(hours) |
| 107 | + minutes_str = format_time_component(rem_minutes) |
| 108 | + seconds_str = format_time_component(seconds) |
| 109 | + |
| 110 | + "#{hours_str}:#{minutes_str}:#{seconds_str}" |
| 111 | + end |
| 112 | + |
| 113 | + defp format_time_component(time) do |
| 114 | + time_string = to_string(time) |
| 115 | + |
| 116 | + if time < 10 do |
| 117 | + "0" <> time_string |
| 118 | + else |
| 119 | + time_string |
| 120 | + end |
| 121 | + end |
2 | 122 | end |
0 commit comments