diff --git a/Makefile b/Makefile
index 2e7b676..ed126aa 100644
--- a/Makefile
+++ b/Makefile
@@ -1,19 +1,29 @@
-all : bin/freq01cpp bin/freq02cpp bin/freq03cpp bin/freq01rs bin/hack01cpp bin/freq01go bin/freq01scala.jar
+CXX := g++
+CXXFLAGS ?= -O3
+
+ifeq ($(shell uname -s),Darwin)
+CXXFLAGS += -Dlseek64=lseek -DMAP_POPULATE=0
+endif
+
+all : bin/freq01cpp bin/freq02cpp bin/freq03cpp bin/freq04cpp bin/freq01rs bin/hack01cpp bin/freq01go bin/freq01scala.jar
 
 clean:
-	rm -f bin/freq01cpp bin/freq02cpp bin/freq03cpp bin/freq01rs bin/hack01cpp bin/freq01go
+	rm -f bin/freq01cpp bin/freq02cpp bin/freq03cpp bin/freq04cpp bin/freq01rs bin/hack01cpp bin/freq01go
 
 bin/freq01cpp: src/freq01.cpp
-	g++ -O3 -o bin/freq01cpp src/freq01.cpp
+	$(CXX) $(CXXFLAGS) -o bin/freq01cpp src/freq01.cpp
 
 bin/freq02cpp: src/freq02.cpp
-	g++ -O3 -o bin/freq02cpp src/freq02.cpp
+	$(CXX) $(CXXFLAGS) -o bin/freq02cpp src/freq02.cpp
 
 bin/freq03cpp: src/freq03.cpp
-	g++ -O3 -o bin/freq03cpp src/freq03.cpp
+	$(CXX) $(CXXFLAGS) -o bin/freq03cpp src/freq03.cpp
+
+bin/freq04cpp: src/freq04.cpp
+	$(CXX) $(CXXFLAGS) -std=c++11 -o bin/freq04cpp src/freq04.cpp
 
 bin/hack01cpp: src/hack01.cpp
-	g++ -O3 -o bin/hack01cpp src/hack01.cpp
+	$(CXX) $(CXXFLAGS) -o bin/hack01cpp src/hack01.cpp
 
 bin/freq01rs: src/freq01.rs
 	cd build/rust && cargo build --release
diff --git a/bench.py b/bench.py
index 0841030..6daeea5 100755
--- a/bench.py
+++ b/bench.py
@@ -18,6 +18,7 @@ def run1(args, src_name, num_runs):
 runs = [
 	[['java', '-jar', './bin/freq01scala.jar'], 'freq01.scala', 3],
 	[['python', './src/freq01.py'], 'freq01.py', 3],
+	[['./bin/freq04cpp' + EXE], 'freq04.cpp'],
 	[['./bin/freq03cpp' + EXE], 'freq03.cpp'],
 	[['./bin/freq02cpp' + EXE], 'freq02.cpp'],
 	[['./bin/freq01cpp' + EXE], 'freq01.cpp'],
diff --git a/mk.cmd b/mk.cmd
index ff772a5..ea14252 100644
--- a/mk.cmd
+++ b/mk.cmd
@@ -7,6 +7,9 @@ call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary
 echo === freq03.cpp
 cl.exe /permissive- /GS /GL /W3 /Gy /Zc:wchar_t /Gm- /O2 /sdl /Zc:inline /fp:precise /D "_MBCS" /errorReport:prompt /WX- /Zc:forScope /Gd /Oi /MD /FC /EHsc /nologo /diagnostics:column /Fo"bin\\" /Fe"bin\freq03cpp.exe" src\freq03.cpp src\ext\windows-mmap.c
 
+echo === freq04.cpp
+cl.exe /permissive- /GS /GL /W3 /Gy /Zc:wchar_t /Gm- /O2 /sdl /Zc:inline /fp:precise /D "_MBCS" /errorReport:prompt /WX- /Zc:forScope /Gd /Oi /MD /FC /EHsc /nologo /diagnostics:column /Fo"bin\\" /Fe"bin\freq04cpp.exe" src\freq04.cpp
+
 echo === freq01.cpp
 cl.exe /permissive- /GS /GL /W3 /Gy /Zc:wchar_t /Gm- /O2 /sdl /Zc:inline /fp:precise /D "_MBCS" /errorReport:prompt /WX- /Zc:forScope /Gd /Oi /MD /FC /EHsc /nologo /diagnostics:column /Fo"bin\\" /Fe"bin\freq01cpp.exe" src\freq01.cpp src\ext\windows-mmap.c
 
diff --git a/src/freq04.cpp b/src/freq04.cpp
new file mode 100644
index 0000000..10cd2eb
--- /dev/null
+++ b/src/freq04.cpp
@@ -0,0 +1,891 @@
+#define _CRT_SECURE_NO_WARNINGS
+
+#include <algorithm> // std::sort
+#include <assert.h>  // assert
+#include <fcntl.h>   // open, O_RDONLY
+#include <stdio.h>   // FILE, fopen, fwrite, printf
+#include <string.h>  // memcpy, strcmp, strlen
+#include <vector>    // std::vector
+
+#if defined(__ARM_NEON) || defined(__ARM_NEON__) || defined(__aarch64__) || defined(_M_ARM64)
+#define FREQ_HAS_ARM_NEON 1
+#else
+#define FREQ_HAS_ARM_NEON 0
+#endif
+
+#if defined(__SSE2__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
+#define FREQ_HAS_X86_SSE2 1
+#else
+#define FREQ_HAS_X86_SSE2 0
+#endif
+
+#ifndef FREQ_USE_NEON // may be predefined on the command line to force a choice
+#if FREQ_HAS_ARM_NEON
+#define FREQ_USE_NEON 1
+#else
+#define FREQ_USE_NEON 0
+#endif
+#endif
+
+#ifndef FREQ_USE_SSE2 // may be predefined on the command line to force a choice
+#if !FREQ_USE_NEON && FREQ_HAS_X86_SSE2
+#define FREQ_USE_SSE2 1
+#else
+#define FREQ_USE_SSE2 0
+#endif
+#endif
+
+#if FREQ_USE_NEON
+#if !FREQ_HAS_ARM_NEON
+#error "FREQ_USE_NEON requires an ARM NEON target"
+#endif
+#if defined(_MSC_VER) && defined(_M_ARM64) && !defined(__clang__)
+#include <arm64_neon.h>
+#else
+#include <arm_neon.h>
+#endif
+#endif
+
+#if FREQ_USE_SSE2
+#if !FREQ_HAS_X86_SSE2
+#error "FREQ_USE_SSE2 requires an x86 SSE2 target"
+#endif
+#include <emmintrin.h>
+#endif
+
+#if FREQ_USE_NEON && FREQ_USE_SSE2
+#error "Enable only one SIMD scanner"
+#endif
+
+#if FREQ_USE_NEON || FREQ_USE_SSE2
+#define FREQ_USE_SIMD 1
+#else
+#define FREQ_USE_SIMD 0
+#endif
+
+#if defined(_WIN32) || defined(WIN32)
+#define FREQ_WINDOWS 1
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#ifndef WIN32_LEAN_AND_MEAN
+#define WIN32_LEAN_AND_MEAN
+#endif
+#include <windows.h> // CreateFileMappingA, MapViewOfFile, UnmapViewOfFile
+#if defined(_MSC_VER)
+#include <intrin.h> // _BitScanForward, _BitScanForward64
+#endif
+#include <io.h> // _open, _close, _lseeki64, _get_osfhandle
+#define open _open
+#define close _close
+#ifndef O_BINARY
+#define O_BINARY _O_BINARY
+#endif
+#if defined(_MSC_VER)
+#pragma warning(disable : 4996)
+#endif
+#else
+#define FREQ_WINDOWS 0
+#include <sys/mman.h> // mmap, munmap
+#include <unistd.h>   // lseek, close
+#ifndef O_BINARY
+#define O_BINARY 0
+#endif
+#ifndef MAP_POPULATE
+#define MAP_POPULATE 0
+#endif
+#endif
+
+#if !defined(_MSC_VER) && !defined(__forceinline)
+#define __forceinline __attribute__((always_inline)) inline
+#endif
+
+typedef unsigned char byte;
+typedef unsigned int uint;
+typedef unsigned long long ullong;
+
+/*
+Performance changes relative to freq03.cpp:
+
+- Keep the original cache-dense 16-byte hash entry, but let it store either a
+  long key pointer or an inline short key. Words up to 8 bytes avoid the key-pool
+  allocation/dereference on hash-table hits.
+- Count words of length 1..5 in a direct-address table using 5 bits per letter.
+  This removes hashing, probing, and string comparison from the shortest and
+  most frequent natural-language tokens; a touched-index list keeps dump from
+  scanning the whole direct table.
+- Preserve the simple linear-probing hash table, but split the lookup path into
+  acquire_short() and acquire_long(). Inline words avoid long-string checks, and
+  long words avoid short-key comparisons.
+- Avoid recomputing information the tokenizer already has: long-word insertion
+  receives the known length, complete inline words use a hash derived from the
+  packed short key, and complete very-short words build their direct-table code
+  with a packed 64-bit helper when lookahead is safe.
+- Classify input in 64-byte SIMD blocks on supported targets. The scanner builds
+  a 64-bit "is ASCII letter" mask (NEON on ARM, SSE2 on x86/x86_64), then iterates
+  letter runs with ctz. Complete in-block words bypass the pending lowercase
+  buffer; only boundary-crossing words and longer words use it.
+- Dump only live entries. Before sorting, used hash slots plus touched direct
+  counters are compacted to a dense prefix, so sort work tracks unique word
+  count rather than the 512K-slot table capacity.
+- Use a 1 MiB buffered writer and custom integer formatting instead of fprintf
+  per output row.
+
+The refactor also adds portability scaffolding for macOS/Linux/Windows and
+ARM/x86 SIMD selection; those pieces are for build coverage, not primary speed.
+*/
+
+#if FREQ_USE_SIMD
+static __forceinline int ctz64(ullong x) { // index of the lowest set bit of x
+    // Callers guarantee x != 0.
+#if defined(_MSC_VER)
+    unsigned long idx;
+#if defined(_M_X64) || defined(_M_ARM64)
+    _BitScanForward64(&idx, x);
+    return (int)idx;
+#else
+    if (_BitScanForward(&idx, (unsigned long)x)) // low 32 bits first
+        return (int)idx;
+    _BitScanForward(&idx, (unsigned long)(x >> 32)); // then the high 32 bits
+    return (int)idx + 32;
+#endif
+#else
+    return __builtin_ctzll(x);
+#endif
+}
+#endif
+
+static const uint FNV_PRIME = 0x1000193u;
+#if FREQ_USE_SIMD
+static const ullong LOWERCASE_BITS = 0x2020202020202020ull; // 0x20 in every byte
+static const ullong ASCII_A_BYTES = 0x6161616161616161ull;  // 'a' in every byte
+#endif
+static const ullong SHORT_HASH_MIX = 11400714819323198485ull;
+static const uint HASH_FLAG_SHORT = 0x80000000u;
+static const uint HASH_BITS = ~HASH_FLAG_SHORT;
+
+static bool get_file_size(int fd, ullong &size) { // seeks fd to its end; false on failure
+#if FREQ_WINDOWS
+    const __int64 end = _lseeki64(fd, 0, SEEK_END);
+#else
+    const off_t end = lseek(fd, 0, SEEK_END);
+#endif
+    if (end < 0)
+        return false;
+    size = (ullong)end;
+    return true;
+}
+
+static void *map_readonly_file(int fd, ullong size) { // NULL for size == 0 or on failure
+    if (!size)
+        return NULL;
+
+#if FREQ_WINDOWS
+    HANDLE file = (HANDLE)_get_osfhandle(fd);
+    if (file == INVALID_HANDLE_VALUE)
+        return NULL;
+
+    HANDLE mapping = CreateFileMappingA(file, NULL, PAGE_READONLY, DWORD(size >> 32),
+                                        DWORD(size), NULL);
+    if (!mapping)
+        return NULL;
+
+    void *p = MapViewOfFile(mapping, FILE_MAP_READ, 0, 0, 0);
+    CloseHandle(mapping); // closed whether or not the view was created
+    return p;
+#else
+    void *p = mmap(NULL, size, PROT_READ, MAP_PRIVATE | MAP_POPULATE, fd, 0);
+    return (p == MAP_FAILED) ? NULL : p;
+#endif
+}
+
+static void unmap_readonly_file(const void *p, ullong size) { // no-op for NULL or size == 0
+    if (!p || !size)
+        return;
+
+#if FREQ_WINDOWS
+    UnmapViewOfFile(p);
+#else
+    munmap((void *)p, size);
+#endif
+}
+
+// Words of length 1..5 are direct-addressed as 5 bits per letter.
+static const int SMALL1 = 1 << 5;
+static const int SMALL2 = 1 << 10;
+static const int SMALL3 = 1 << 15;
+static const int SMALL4 = 1 << 20;
+static const int SMALL5 = 1 << 25;
+static const int SMALL_TOTAL = SMALL1 + SMALL2 + SMALL3 + SMALL4 + SMALL5; // 34,636,832 slots
+static const int SMALL_OFFSET_BY_LEN[6] = { // first slot of each length's sub-table
+    0,
+    0,
+    SMALL1,
+    SMALL1 + SMALL2,
+    SMALL1 + SMALL2 + SMALL3,
+    SMALL1 + SMALL2 + SMALL3 + SMALL4,
+};
+#if FREQ_USE_SIMD
+static const uint SMALL_CODE_MASK[6] = { // keeps the low 5 * len bits of a packed code
+    0,
+    (1u << 5) - 1,
+    (1u << 10) - 1,
+    (1u << 15) - 1,
+    (1u << 20) - 1,
+    (1u << 25) - 1,
+};
+#endif
+
+// Words up to length 8 can live inline in entry::shortkey. Bytes 0..7 hold
+// lowercase ASCII letters; their high bits are free, so bytes 0..3 also
+// carry the 4-bit length value.
+static const ullong SHORT_MASK[9] = { // [n] keeps the low n bytes of a packed word
+    0ull,
+    ~0ull >> 56,
+    ~0ull >> 48,
+    ~0ull >> 40,
+    ~0ull >> 32,
+    ~0ull >> 24,
+    ~0ull >> 16,
+    ~0ull >> 8,
+    ~0ull,
+};
+static const ullong SHORT_LEN_FLAG[9] = { // length n, one bit each in bit 7 of bytes 0..3
+    0ull,
+    1ull << 7,
+    1ull << 15,
+    (1ull << 15) | (1ull << 7),
+    1ull << 23,
+    (1ull << 23) | (1ull << 7),
+    (1ull << 23) | (1ull << 15),
+    (1ull << 23) | (1ull << 15) | (1ull << 7),
+    1ull << 31,
+};
+
+static __forceinline uint hash_word(const byte *key, int len) { // xor-then-multiply per byte
+    uint acc = 0;
+    for (int left = len; left > 0; --left)
+        acc = (acc ^ *key++) * FNV_PRIME;
+    return acc;
+}
+
+// The word must already be lowercase; 8 bytes are always read from key.
+static __forceinline ullong pack_shortkey_prelowered(const byte *key, int len) {
+    ullong raw;
+    memcpy(&raw, key, sizeof(raw));
+    return SHORT_LEN_FLAG[len] | (raw & SHORT_MASK[len]);
+}
+
+#if FREQ_USE_SIMD
+// The word may be mixed-case: every byte gets bit 0x20 set before packing.
+static __forceinline ullong pack_shortkey_lowercase(const byte *key, int len) {
+    ullong raw;
+    memcpy(&raw, key, sizeof(raw));
+    const ullong lowered = raw | LOWERCASE_BITS;
+    return SHORT_LEN_FLAG[len] | (lowered & SHORT_MASK[len]);
+}
+
+static __forceinline int make_small_code_scalar(const byte *key, int len) {
+    int packed = 0;
+    for (int k = 0, shift = 0; k < len; k++, shift += 5)
+        packed |= ((key[k] | 0x20) - 'a') << shift;
+    return packed;
+}
+
+// Caller must guarantee that reading 8 bytes from key stays inside the mapped block.
+static __forceinline int make_small_code_lower8(const byte *key, int len) {
+    ullong x = 0;
+    memcpy(&x, key, 8);
+    x = (x | LOWERCASE_BITS) - ASCII_A_BYTES; // each letter byte becomes 0..25
+    const uint code = uint((x & 0x1full) | ((x >> 3) & 0x3e0ull) |
+                           ((x >> 6) & 0x7c00ull) | ((x >> 9) & 0xf8000ull) |
+                           ((x >> 12) & 0x1f00000ull));
+    return code & SMALL_CODE_MASK[len]; // drop bits from bytes past the word
+}
+#endif
+
+static __forceinline uint hash_shortkey(ullong shortkey) { // multiply, then fold high into low
+    const ullong mixed = shortkey * SHORT_HASH_MIX;
+    return uint(mixed ^ (mixed >> 32));
+}
+
+struct entry {
+    union {
+        const char *key = NULL; // long key pointer
+        ullong shortkey;        // inline key for len <= 8
+    };
+    uint hash = 0; // top bit is short-key flag; lower bits are hash
+    int value = 0;
+
+    // For short entries, shortkey overlays key and includes a length flag,
+    // so the union value is non-null just like a real key pointer.
+    bool used() const { return key != NULL; }
+    bool is_short() const { return (hash & HASH_FLAG_SHORT) != 0; }
+    int short_len() const { // reassembles the length from bit 7 of bytes 0..3
+        return int(((shortkey >> 7) & 1) | ((shortkey >> 14) & 2) |
+                   ((shortkey >> 21) & 4) | ((shortkey >> 28) & 8));
+    }
+    byte short_char(int i) const { return byte(shortkey >> (i * 8)) & 0x7f; }
+
+    bool operator<(const entry &b) const;
+};
+
+static_assert(sizeof(entry) == 16, "entry must stay cache-dense");
+
+static int cmp_short_short(const entry &a, const entry &b) { // <0, 0, >0 like strcmp
+    const int alen = a.short_len();
+    const int blen = b.short_len();
+    const int m = (alen < blen) ? alen : blen;
+    for (int i = 0; i < m; i++) {
+        const byte ac = a.short_char(i);
+        const byte bc = b.short_char(i);
+        if (ac != bc)
+            return (ac < bc) ? -1 : 1;
+    }
+    return (alen > blen) - (alen < blen); // equal prefix: the shorter word sorts first
+}
+
+static int cmp_short_long(const entry &a, const char *b) { // inline key a vs C string b
+    const int alen = a.short_len();
+    for (int i = 0; i < alen; i++) {
+        const byte ac = a.short_char(i);
+        const byte bc = byte(b[i]); // a NUL in b compares below any letter, ending the loop
+        if (ac != bc)
+            return (ac < bc) ? -1 : 1;
+    }
+    return b[alen] ? -1 : 0;
+}
+
+bool entry::operator<(const entry &b) const { // higher count first, then key ascending
+    if (value != b.value)
+        return value > b.value;
+    if (!value)
+        return false;
+
+    if (!is_short() && !b.is_short())
+        return strcmp(key, b.key) < 0;
+    if (is_short() && b.is_short())
+        return cmp_short_short(*this, b) < 0;
+    if (is_short())
+        return cmp_short_long(*this, b.key) < 0;
+    return cmp_short_long(b, key) > 0;
+}
+
+struct keychunk {
+    char *buffer;   // storage space
+    char *curptr;   // current pointer
+    char *maxptr;   // max current pointer, exclusive
+    keychunk *next; // next chunk in the list
+
+    explicit keychunk(int size, keychunk *anext) {
+        buffer = new char[size];
+        curptr = buffer;
+        maxptr = buffer + size;
+        next = anext;
+    }
+
+    ~keychunk() {
+        delete[] buffer;
+        delete next; // frees the rest of the chain recursively
+    }
+};
+
+class outbuf {
+  private:
+    // 1 MiB output buffer, heap-backed so a local outbuf cannot exhaust a 1 MiB stack.
+    static const int OUT_SIZE = 1 << 20;
+    FILE *out;
+    std::vector<char> buf;
+    int pos = 0;
+
+  public:
+    explicit outbuf(FILE *aout) : out(aout), buf(OUT_SIZE) {}
+    ~outbuf() { flush(); }
+
+    void flush() {
+        if (pos) {
+            fwrite(buf.data(), 1, pos, out);
+            pos = 0;
+        }
+    }
+
+    void put(char c) {
+        if (pos == OUT_SIZE)
+            flush();
+        buf[pos++] = c;
+    }
+
+    void write(const char *p, int len) { // len must not exceed OUT_SIZE
+        if (pos + len > OUT_SIZE)
+            flush();
+        memcpy(buf.data() + pos, p, len);
+        pos += len;
+    }
+
+    void write_int(int v) { // decimal digits of a non-negative v
+        char tmp[16];
+        int n = 0;
+        do {
+            tmp[n++] = char('0' + (v % 10));
+            v /= 10;
+        } while (v);
+        while (n--)
+            put(tmp[n]);
+    }
+};
+
+class strhash {
+  private:
+    static constexpr int INITIAL_SIZE = 512 * 1024; // MUST be a power of 2
+    static constexpr int KEYS_CHUNK = 1024 * 1024;  // no power of 2 restrictions here
+    static constexpr int MAX_LOAD_NUM = 9;
+    static constexpr int MAX_LOAD_DEN = 10;
+
+    entry *data;
+    int mask = INITIAL_SIZE - 1;
+    int capacity = INITIAL_SIZE;
+    int used_count = 0;
+    int maxused = capacity * MAX_LOAD_NUM / MAX_LOAD_DEN;
+    keychunk *keys;
+
+  public:
+    strhash() {
+        data = new entry[capacity];
+        keys = new keychunk(KEYS_CHUNK, NULL);
+    }
+
+    ~strhash() {
+        delete[] data;
+        delete keys;
+    }
+
+    int &acquire_short(uint hash, ullong shortkey);
+    int &acquire_long(const char *key, int len, uint hash);
+    void dump(FILE *out, const int *small_counts, const std::vector<int> &small_touched);
+
+  private:
+    int &add_short(int i, uint hash, ullong shortkey);
+    int &add_long(int i, uint hash, const char *key, int len);
+    const char *dup4(const char *key, int len);
+    void grow();
+    void append_small_word(const byte *word, int len, int value, int &slot);
+};
+
+static bool streq4(const char *aa, const char *bb) {
+    // Keys are padded with at least four zero bytes. Mainstream Linux,
+    // Windows, and macOS targets for this benchmark handle unaligned u32 loads.
+    const uint *a = (const uint *)aa;
+    const uint *b = (const uint *)bb;
+    while ((*a >> 24) && *a == *b) { // stop once the top byte of the loaded word is zero
+        a++;
+        b++;
+    }
+    return *a == *b;
+}
+
+__forceinline int &strhash::acquire_short(uint hash, ullong shortkey) {
+    const uint hbits = hash & HASH_BITS;
+    int i = hbits & mask;
+    while (data[i].used()) { // linear probing
+        if ((data[i].hash & HASH_BITS) == hbits && data[i].is_short() &&
+            data[i].shortkey == shortkey)
+            return data[i].value;
+        i = (i + 1) & mask;
+    }
+    return add_short(i, hash, shortkey);
+}
+
+__forceinline int &strhash::acquire_long(const char *key, int len, uint hash) {
+    const uint hbits = hash & HASH_BITS;
+    int i = hbits & mask;
+    while (data[i].used()) { // linear probing
+        if ((data[i].hash & HASH_BITS) == hbits && !data[i].is_short() &&
+            streq4(data[i].key, key))
+            return data[i].value;
+        i = (i + 1) & mask;
+    }
+    return add_long(i, hash, key, len);
+}
+
+int &strhash::add_short(int i, uint hash, ullong shortkey) {
+    assert(!data[i].used());
+    data[i].hash = (hash & HASH_BITS) | HASH_FLAG_SHORT;
+    data[i].value = 0;
+    data[i].shortkey = shortkey;
+
+    if (++used_count <= maxused)
+        return data[i].value;
+
+    grow(); // slot i is stale after the rehash, so look the key up again
+    return acquire_short(hash, shortkey);
+}
+
+int &strhash::add_long(int i, uint hash, const char *key, int len) {
+    assert(!data[i].used());
+    data[i].hash = hash & HASH_BITS;
+    data[i].value = 0;
+    data[i].key = dup4(key, len);
+
+    if (++used_count <= maxused)
+        return data[i].value;
+
+    grow(); // slot i is stale after the rehash, so look the key up again
+    return acquire_long(key, len, hash);
+}
+
+const char *strhash::dup4(const char *key, int len) {
+    int gap = 4 - (len & 3); // 1..4 zero bytes, rounding the copy up to a multiple of 4
+    int lg = len + gap;
+
+    if (keys->curptr + lg > keys->maxptr) // int() yields a prvalue, so KEYS_CHUNK is not odr-used
+        keys = new keychunk((lg > KEYS_CHUNK) ? lg : int(KEYS_CHUNK), keys);
+
+    assert(keys->curptr + lg <= keys->maxptr);
+    uint zero = 0;
+    memcpy(keys->curptr, key, len);
+    memcpy(keys->curptr + len, &zero, gap);
+    keys->curptr += lg;
+    return keys->curptr - lg;
+}
+
+void strhash::grow() { // doubles the table and reinserts every used entry
+    entry *newdata = new entry[2 * capacity];
+    mask = 2 * capacity - 1;
+
+    for (int i = 0; i < capacity; i++)
+        if (data[i].used()) {
+            int j = (data[i].hash & HASH_BITS) & mask;
+            while (newdata[j].used())
+                j = (j + 1) & mask;
+            newdata[j] = data[i];
+        }
+
+    delete[] data;
+    data = newdata;
+    capacity *= 2;
+    maxused = capacity * MAX_LOAD_NUM / MAX_LOAD_DEN;
+}
+
+void strhash::append_small_word(const byte *word, int len, int value, int &slot) {
+    if (used_count >= capacity)
+        grow();
+
+    while (data[slot].used()) // slot only moves forward; the entry is not hash-placed
+        slot++;
+
+    data[slot].shortkey = pack_shortkey_prelowered(word, len);
+    data[slot].hash = HASH_FLAG_SHORT;
+    data[slot].value = value;
+    used_count++;
+}
+
+void strhash::dump(FILE *out, const int *small_counts, const std::vector<int> &small_touched) {
+    byte word[8] = {};
+    int slot = 0;
+    for (int idx : small_touched) { // move direct-table counters into free hash slots
+        int len = 5;
+        while (idx < SMALL_OFFSET_BY_LEN[len])
+            len--;
+        const int code = idx - SMALL_OFFSET_BY_LEN[len];
+
+        for (int i = 0; i < len; i++)
+            word[i] = byte('a' + ((code >> (i * 5)) & 31));
+        append_small_word(word, len, small_counts[idx], slot);
+    }
+
+    int live_count = 0;
+    for (int i = 0; i < capacity; i++) { // compact used entries to a dense prefix
+        if (data[i].used())
+            data[live_count++] = data[i];
+    }
+    std::sort(data, data + live_count);
+
+    outbuf wr(out);
+    char sbuf[8];
+    for (int i = 0; i < live_count; i++) { // one "<count> <word>\n" row per entry
+        const entry &e = data[i];
+        wr.write_int(e.value);
+        wr.put(' ');
+        if (e.is_short()) {
+            const int len = e.short_len();
+            for (int j = 0; j < len; j++)
+                sbuf[j] = char(e.short_char(j));
+            wr.write(sbuf, len);
+        } else {
+            wr.write(e.key, (int)strlen(e.key));
+        }
+        wr.put('\n');
+    }
+}
+
+static __forceinline void count_small_word(int *small_counts, std::vector<int> &small_touched,
+                                           int len, int code) {
+    const int idx = SMALL_OFFSET_BY_LEN[len] + code;
+    if (small_counts[idx]++ == 0) // first hit: remember the index for dump()
+        small_touched.push_back(idx);
+}
+
+static __forceinline void count_shortkey_word(strhash &freqs, ullong shortkey) {
+    freqs.acquire_short(hash_shortkey(shortkey), shortkey)++;
+}
+
+static __forceinline void count_buffered_word(strhash &freqs, byte *word, int len, uint hash) {
+    if (len <= 8) { // inline key: hash is derived from the packed key, not from `hash`
+        count_shortkey_word(freqs, pack_shortkey_prelowered(word, len));
+    } else {
+        freqs.acquire_long((char *)word, len, hash)++;
+    }
+}
+
+static __forceinline void finish_pending_word(strhash &freqs, byte *wbuf, byte *&wcur,
+                                              uint &h, int &small_code, int *small_counts,
+                                              std::vector<int> &small_touched) {
+    if (wcur == wbuf) // nothing buffered
+        return;
+
+    const int len = (int)(wcur - wbuf);
+    *(uint *)wcur = 0; // zero padding expected by streq4() and the 8-byte key loads
+    if (len <= 5) {
+        count_small_word(small_counts, small_touched, len, small_code);
+    } else {
+        count_buffered_word(freqs, wbuf, len, h);
+    }
+    wcur = wbuf;
+    h = 0;
+    small_code = 0;
+}
+
+static __forceinline void append_letter_run(const byte *src, int count, byte *wbuf,
+                                            byte *&wcur, byte *wmax, uint &h,
+                                            int &small_code) {
+    for (int i = 0; i < count; i++) {
+        if (wcur < wmax) { // letters past the buffer limit are dropped
+            const byte c = byte(src[i] | 32);
+            const int len = (int)(wcur - wbuf);
+            *wcur++ = c;
+            if (len < 5) {
+                small_code |= (c - 'a') << (len * 5);
+            } else if (len == 5) { // sixth letter: start the hash over the first six bytes
+                h = hash_word(wbuf, 6);
+            } else {
+                h = (h ^ c) * FNV_PRIME;
+            }
+        }
+    }
+}
+
+#if FREQ_USE_SIMD
+static __forceinline void count_complete_word(strhash &freqs, const byte *src, int len,
+                                              byte *wbuf, int *small_counts,
+                                              std::vector<int> &small_touched,
+                                              bool has_8byte_lookahead) {
+    if (len <= 5) {
+        const int code = has_8byte_lookahead ? make_small_code_lower8(src, len)
+                                             : make_small_code_scalar(src, len);
+        count_small_word(small_counts, small_touched, len, code);
+        return;
+    }
+
+    if (len <= 8 && has_8byte_lookahead) { // pack straight from the input
+        count_shortkey_word(freqs, pack_shortkey_lowercase(src, len));
+        return;
+    }
+
+    if (len <= 8) { // too close to the block end for an 8-byte load: stage in wbuf
+        for (int i = 0; i < len; i++)
+            wbuf[i] = byte(src[i] | 32);
+        *(uint *)(wbuf + len) = 0;
+        count_shortkey_word(freqs, pack_shortkey_prelowered(wbuf, len));
+        return;
+    }
+
+    uint h = 0;
+    for (int i = 0; i < len; i++) { // lowercase, copy, and hash in one pass
+        const byte c = byte(src[i] | 32);
+        wbuf[i] = c;
+        h = (h ^ c) * FNV_PRIME;
+    }
+    *(uint *)(wbuf + len) = 0;
+    count_buffered_word(freqs, wbuf, len, h);
+}
+
+static __forceinline ullong letter_mask64(const byte *src) { // bit i set: src[i] is a-z or A-Z
+#if FREQ_USE_NEON
+    const uint8x16_t bit_mask = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};
+    const uint8x16_t lower_a = vdupq_n_u8('a');
+    const uint8x16_t z_delta = vdupq_n_u8(25);
+    const uint8x16_t lower_bit = vdupq_n_u8(0x20);
+
+    const uint8x16_t d0 = vld1q_u8(src);
+    const uint8x16_t d1 = vld1q_u8(src + 16);
+    const uint8x16_t d2 = vld1q_u8(src + 32);
+    const uint8x16_t d3 = vld1q_u8(src + 48);
+
+    const uint8x16_t m0 =
+        vcleq_u8(vsubq_u8(vorrq_u8(d0, lower_bit), lower_a), z_delta);
+    const uint8x16_t m1 =
+        vcleq_u8(vsubq_u8(vorrq_u8(d1, lower_bit), lower_a), z_delta);
+    const uint8x16_t m2 =
+        vcleq_u8(vsubq_u8(vorrq_u8(d2, lower_bit), lower_a), z_delta);
+    const uint8x16_t m3 =
+        vcleq_u8(vsubq_u8(vorrq_u8(d3, lower_bit), lower_a), z_delta);
+
+    uint8x16_t sum0 = vpaddq_u8(vandq_u8(m0, bit_mask), vandq_u8(m1, bit_mask));
+    uint8x16_t sum1 = vpaddq_u8(vandq_u8(m2, bit_mask), vandq_u8(m3, bit_mask));
+    sum0 = vpaddq_u8(sum0, sum1);
+    sum0 = vpaddq_u8(sum0, sum0);
+    return vgetq_lane_u64(vreinterpretq_u64_u8(sum0), 0);
+#elif FREQ_USE_SSE2
+    const __m128i lower_a = _mm_set1_epi8('a');
+    const __m128i z_delta = _mm_set1_epi8(25);
+    const __m128i lower_bit = _mm_set1_epi8(0x20);
+    const __m128i zero = _mm_setzero_si128();
+
+    const __m128i d0 = _mm_loadu_si128((const __m128i *)(src));
+    const __m128i d1 = _mm_loadu_si128((const __m128i *)(src + 16));
+    const __m128i d2 = _mm_loadu_si128((const __m128i *)(src + 32));
+    const __m128i d3 = _mm_loadu_si128((const __m128i *)(src + 48));
+
+    const __m128i m0 =
+        _mm_cmpeq_epi8(_mm_subs_epu8(_mm_sub_epi8(_mm_or_si128(d0, lower_bit), lower_a),
+                                     z_delta),
+                       zero);
+    const __m128i m1 =
+        _mm_cmpeq_epi8(_mm_subs_epu8(_mm_sub_epi8(_mm_or_si128(d1, lower_bit), lower_a),
+                                     z_delta),
+                       zero);
+    const __m128i m2 =
+        _mm_cmpeq_epi8(_mm_subs_epu8(_mm_sub_epi8(_mm_or_si128(d2, lower_bit), lower_a),
+                                     z_delta),
+                       zero);
+    const __m128i m3 =
+        _mm_cmpeq_epi8(_mm_subs_epu8(_mm_sub_epi8(_mm_or_si128(d3, lower_bit), lower_a),
+                                     z_delta),
+                       zero);
+
+    return ullong(uint(_mm_movemask_epi8(m0))) |
+           (ullong(uint(_mm_movemask_epi8(m1))) << 16) |
+           (ullong(uint(_mm_movemask_epi8(m2))) << 32) |
+           (ullong(uint(_mm_movemask_epi8(m3))) << 48);
+#endif
+}
+
+static __forceinline void process_letter_mask64(strhash &freqs, const byte *src, ullong mask,
+                                                byte *wbuf, byte *&wcur, byte *wmax, uint &h,
+                                                int &small_code, int *small_counts,
+                                                std::vector<int> &small_touched) {
+    int pos = 0;
+    while (pos < 64) {
+        const ullong bits = mask >> pos;
+        if ((bits & 1) == 0) { // separator: close any pending word, skip to the next letter
+            finish_pending_word(freqs, wbuf, wcur, h, small_code, small_counts,
+                                small_touched);
+            if (!bits)
+                break;
+            pos += ctz64(bits);
+            continue;
+        }
+
+        int ones = (bits == ~0ull) ? 64 : ctz64(~bits); // length of the letter run at pos
+        const int remaining = 64 - pos;
+        if (ones > remaining)
+            ones = remaining;
+
+        const bool complete_in_block = (pos + ones) < 64;
+        if (wcur == wbuf && complete_in_block) {
+            count_complete_word(freqs, src + pos, ones, wbuf, small_counts,
+                                small_touched, pos + 8 <= 64);
+        } else { // continues a pending word, or may continue into the next block
+            append_letter_run(src + pos, ones, wbuf, wcur, wmax, h, small_code);
+            if (complete_in_block)
+                finish_pending_word(freqs, wbuf, wcur, h, small_code, small_counts,
+                                    small_touched);
+        }
+        pos += ones;
+    }
+}
+#endif
+
+int main(int argc, char **argv) {
+    if (argc != 3) {
+        printf("usage: freq <input> <output>\n");
+        return 0;
+    }
+
+    int fd = open(argv[1], O_RDONLY | O_BINARY);
+    if (fd < 0) {
+        printf("FATAL: failed to read %s", argv[1]);
+        return 1;
+    }
+    FILE *fp2 = fopen(argv[2], "wb+");
+    if (!fp2) {
+        printf("FATAL: failed to write %s", argv[2]);
+        close(fd);
+        return 1;
+    }
+
+    ullong fsz = 0;
+    if (!get_file_size(fd, fsz)) {
+        printf("FATAL: failed to size %s", argv[1]);
+        fclose(fp2);
+        close(fd);
+        return 1;
+    }
+
+    void *mapped = map_readonly_file(fd, fsz);
+    if (fsz && !mapped) { // an empty input legitimately maps to NULL
+        printf("FATAL: failed to map %s", argv[1]);
+        fclose(fp2);
+        close(fd);
+        return 1;
+    }
+    const byte *fbegin = (const byte *)mapped;
+
+    strhash freqs;
+    // Direct-address table for 1..5 letter words. This trades ~132 MiB of
+    // zeroed storage for collision-free counting of the shortest tokens.
+    int *small_counts = new int[SMALL_TOTAL]();
+    std::vector<int> small_touched;
+    small_touched.reserve(65536);
+
+    const int WORDBUF = 256;
+    byte wbuf[WORDBUF + 8] = {}; // 8 spare bytes for zero padding and 8-byte key loads
+    byte *wmax = wbuf + WORDBUF;
+    byte *wcur = wbuf;
+
+    uint h = 0;
+    int small_code = 0;
+    if (fsz) {
+        const byte *rcur = fbegin;
+        const byte *rmax = fbegin + fsz;
+#if FREQ_USE_SIMD
+        while (rcur + 64 <= rmax) { // whole 64-byte blocks
+            const ullong mask = letter_mask64(rcur);
+            process_letter_mask64(freqs, rcur, mask, wbuf, wcur, wmax, h, small_code,
+                                  small_counts, small_touched);
+            rcur += 64;
+        }
+#endif
+        while (rcur < rmax) { // scalar tail, or the whole input without SIMD
+            byte c = byte(*rcur++ | 32);
+            if ((uint)(c - 'a') > 25u) {
+                finish_pending_word(freqs, wbuf, wcur, h, small_code, small_counts,
+                                    small_touched);
+            } else {
+                append_letter_run(&c, 1, wbuf, wcur, wmax, h, small_code);
+            }
+        }
+    }
+
+    finish_pending_word(freqs, wbuf, wcur, h, small_code, small_counts, small_touched);
+
+    freqs.dump(fp2, small_counts, small_touched);
+    delete[] small_counts;
+    unmap_readonly_file(mapped, fsz);
+    fclose(fp2);
+
+    close(fd);
+    return 0;
+}