-
Notifications
You must be signed in to change notification settings - Fork 441
Expand file tree
/
Copy pathvecenv.h
More file actions
133 lines (112 loc) · 3.4 KB
/
vecenv.h
File metadata and controls
133 lines (112 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#ifndef PUFFERLIB_VECENV_H
#define PUFFERLIB_VECENV_H
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <pthread.h>
#ifndef __cplusplus
#include <stdatomic.h>
#include <stdbool.h>
#endif
#include <cuda_runtime.h>
// Integer tags naming four element types. Nothing else in this header
// references them.
// NOTE(review): presumably used by callers to describe buffer dtypes - confirm.
#define FLOAT 1
#define INT 2
#define UNSIGNED_CHAR 3
#define DOUBLE 4
// One dictionary entry: a string key carrying both a numeric slot and a
// pointer slot. dict_set writes `value`; dict_set_ptr writes `ptr`.
typedef struct {
// Entry name. Stored as the caller's pointer (not copied) and compared
// with strcmp during lookup.
const char* key;
// Numeric payload.
double value;
// Pointer payload.
void* ptr;
} DictItem;
// Fixed-capacity dictionary backed by a flat array that is searched linearly.
typedef struct {
// Entry storage; create_dict allocates `capacity` zeroed slots.
DictItem* items;
// Number of slots currently in use (entries live at indices [0, size)).
int size;
// Total number of slots available in `items`.
int capacity;
} Dict;
// Env and Threading are only forward-declared here; their definitions live
// outside this header, so this file handles them through pointers only.
typedef struct Env Env;
typedef struct Threading Threading;
// Handle for a group of environments together with their data buffers.
// NOTE(review): the field descriptions below are inferred from names alone;
// confirm them against the code that creates and fills this struct.
typedef struct {
// Environment instance(s); presumably an array of `size` elements - confirm.
Env* envs;
int size;
// Data buffers. Note that actions are double while the rest are float.
// Presumably host-side storage - confirm.
float* observations;
double* actions;
float* rewards;
float* terminals;
// Same four buffers with a gpu_ prefix; presumably device-side copies - confirm.
float* gpu_observations;
double* gpu_actions;
float* gpu_rewards;
float* gpu_terminals;
// Opaque threading state (type defined elsewhere).
Threading* threading;
// CUDA stream handle(s); presumably one per buffer - confirm.
cudaStream_t* streams;
// Buffer count; presumably the number of independent send/recv buffers - confirm.
int buffers;
} VecEnv;
// Allocate an empty dictionary with room for `capacity` entries.
//
// Both the Dict and its item array are zero-initialised, so every slot starts
// with a NULL key, a value of 0 and a NULL ptr.
//
// Returns the new dictionary, or NULL if `capacity` is negative or an
// allocation fails. Nothing is leaked on failure.
Dict* create_dict(int capacity) {
    // A negative count would be converted to a huge size_t by calloc.
    if (capacity < 0) {
        return NULL;
    }

    Dict* dict = (Dict*)calloc(1, sizeof(Dict));
    if (dict == NULL) {
        return NULL;
    }

    dict->items = (DictItem*)calloc((size_t)capacity, sizeof(DictItem));
    // calloc may legitimately return NULL for a zero-sized request, so only
    // treat NULL as a failure when slots were actually requested.
    if (dict->items == NULL && capacity > 0) {
        free(dict);
        return NULL;
    }

    dict->capacity = capacity;
    return dict;
}
// Linear search for the entry whose key equals `key` (compared with strcmp).
//
// Returns a pointer to the matching entry inside dict->items, or NULL when no
// entry matches. "Unsafe" in that a miss is reported as NULL rather than
// asserted; see dict_get for the asserting variant.
DictItem* dict_get_unsafe(Dict* dict, const char* key) {
    int idx = 0;
    while (idx < dict->size) {
        DictItem* candidate = &dict->items[idx];
        if (!strcmp(candidate->key, key)) {
            return candidate;
        }
        idx++;
    }
    return NULL;
}
// Look up `key` and return its entry, asserting that the key exists.
// When assertions are compiled out (NDEBUG), a missing key yields NULL
// instead, so callers in such builds must not assume a non-NULL result.
DictItem* dict_get(Dict* dict, const char* key) {
DictItem* item = dict_get_unsafe(dict, key);
// The string literal only serves to make the assertion message readable.
assert(item != NULL && "dict_get failed to find key");
return item;
}
// Store a numeric value under `key`.
//
// If `key` is already present its value is overwritten in place; otherwise a
// new entry is appended. The key pointer itself is stored (the string is not
// copied), so it must outlive the dictionary.
//
// Capacity is only checked when a new slot is actually needed, so overwriting
// an existing key works even when the dictionary is full. Appending to a full
// dictionary trips the assert in debug builds and aborts when assertions are
// compiled out, rather than writing past the end of dict->items.
void dict_set(Dict* dict, const char* key, double value) {
    DictItem* existing = dict_get_unsafe(dict, key);
    if (existing != NULL) {
        existing->value = value;
        return;
    }

    assert(dict->size < dict->capacity && "dict_set: dictionary is full");
    if (dict->size >= dict->capacity) {
        // Reached only when NDEBUG removed the assert above.
        abort();
    }

    DictItem* slot = &dict->items[dict->size];
    slot->key = key;
    slot->value = value;
    dict->size++;
}
// Convenience wrapper: store an int under `key` by converting it to double
// and delegating to dict_set.
void dict_set_int(Dict* dict, const char* key, int value) {
    const double as_double = (double)value;
    dict_set(dict, key, as_double);
}
// Store a pointer under `key`.
//
// If `key` is already present its ptr is overwritten in place; otherwise a
// new entry is appended. The key pointer itself is stored (the string is not
// copied), and the dictionary does not take ownership of `ptr`.
//
// Capacity is only checked when a new slot is actually needed, so overwriting
// an existing key works even when the dictionary is full. Appending to a full
// dictionary trips the assert in debug builds and aborts when assertions are
// compiled out, rather than writing past the end of dict->items.
void dict_set_ptr(Dict* dict, const char* key, void* ptr) {
    DictItem* existing = dict_get_unsafe(dict, key);
    if (existing != NULL) {
        existing->ptr = ptr;
        return;
    }

    assert(dict->size < dict->capacity && "dict_set_ptr: dictionary is full");
    if (dict->size >= dict->capacity) {
        // Reached only when NDEBUG removed the assert above.
        abort();
    }

    DictItem* slot = &dict->items[dict->size];
    slot->key = key;
    slot->ptr = ptr;
    dict->size++;
}
// Hook functions that are declared here but not defined in this header.
// NOTE(review): presumably each environment implementation supplies these;
// their exact contracts are not visible from this file - confirm.
void* my_shared(Env* env, Dict* kwargs);
void my_shared_close(Env* env);
void* my_get(Env* env, Dict* out);
int my_put(Env* env, Dict* kwargs);
// Log is opaque here: forward-declared only, defined elsewhere.
typedef struct Log Log;
void my_log(Log* log, Dict* out);
// Sharp bit (puffers have spikes)
// Function-pointer types matching the entry points exported by the shared
// library. They are optional, but without them callers have to hand-write
// some really gross casts on the raw symbol addresses after loading the
// library.
// NOTE: `bool` in create_environments_fn comes from <stdbool.h> when this
// header is compiled as C.

// Vectorized-environment entry points (operate on a VecEnv).
typedef VecEnv* (*create_environments_fn)(int num_envs, int buffers, bool use_gpu, int test_idx, Dict* kwargs);
// Per-environment constructor taking the four data buffers, a seed and kwargs.
typedef Env* (*env_init_fn)(float* observations, double* actions, float* rewards,
float* terminals, int seed, Dict* kwargs);
typedef void (*create_threads_fn)(VecEnv* vec, int threads, int block_size);
typedef void (*vec_reset_fn)(VecEnv* vec);
typedef void (*vec_step_fn)(VecEnv* vec);
typedef void (*vec_recv_fn)(VecEnv* vec, int buffer);
typedef void (*vec_send_fn)(VecEnv* vec, int buffer);
typedef void (*env_close_fn)(Env* env);
typedef void (*vec_close_fn)(VecEnv* vec);
typedef void (*vec_render_fn)(VecEnv* vec, int env_idx);
typedef void (*vec_log_fn)(VecEnv* vec, Dict* out);
// Single-environment entry points (operate on one Env).
typedef void (*c_reset_fn)(Env* env);
typedef void (*c_step_fn)(Env* env);
typedef void (*c_close_fn)(Env* env);
typedef void (*c_render_fn)(Env* env);
#endif // PUFFERLIB_VECENV_H