Files
greq/internal/api/lua_test_runner.go
Mester Gábor 2dd6519168 Initial commit
2026-03-19 07:12:03 +01:00

168 lines
4.6 KiB
Go

package api
import (
"encoding/json"
"fmt"
"strings"
lua "github.com/yuin/gopher-lua"
)
// TestResult is a single entry in the outcome of a test script: either an
// assertion recorded by one of the script helpers or a reported script error.
type TestResult struct {
	// Pass reports whether the entry represents a success.
	Pass bool
	// Message is the human-readable text describing the entry.
	Message string
}
// RunLuaTestScript executes a Lua test script against a Response and returns
// the results the script recorded, in order.
//
// An empty or whitespace-only script is a no-op and returns (nil, nil).
//
// Globals available in the script:
//
//	status       (number) HTTP status code, e.g. 200
//	status_text  (string) Full status string, e.g. "200 OK"
//	body         (string) Raw response body
//	headers      (table)  Response headers, lower-cased keys, first value only
//	json_body    (value)  Parsed JSON body (a table for objects/arrays), or nil
//	                      if the response is not JSON or fails to parse
//	duration_ms  (number) Round-trip duration in whole milliseconds (truncated)
//	size         (number) Response body size in bytes
//
// Helper functions:
//
//	pass(message)         Record a passing assertion
//	fail(message)         Record a failing assertion (does NOT abort the script)
//	assert_eq(a, b, msg)  fail(msg) if a ~= b, else pass(msg); values whose
//	                      string forms match (e.g. 200 and "200") are also
//	                      treated as equal
//	assert_status(code)   assert_eq(status, code, "status == "..code)
//	json_decode(str)      Parse a JSON string into a Lua value; on a parse
//	                      error returns nil plus the error message
//
// A Lua runtime or syntax error is not returned as a Go error. Instead it is
// appended as a failing TestResult, after any results recorded before the
// error, so callers can display it alongside the other assertions. The returned
// error is non-nil only when resp is nil.
func RunLuaTestScript(script string, resp *Response) ([]TestResult, error) {
	if strings.TrimSpace(script) == "" {
		return nil, nil
	}
	if resp == nil {
		return nil, fmt.Errorf("lua test script: nil response")
	}

	// NOTE(review): SkipOpenLibs=false opens every standard library, including
	// os and io, so a script can run commands and access files. That is fine for
	// user-authored scripts; revisit if scripts can come from untrusted sources.
	L := lua.NewState(lua.Options{SkipOpenLibs: false})
	defer L.Close()

	var results []TestResult
	record := func(pass bool, msg string) {
		results = append(results, TestResult{Pass: pass, Message: msg})
	}

	// Assertion helpers.
	L.SetGlobal("pass", L.NewFunction(func(L *lua.LState) int {
		record(true, L.OptString(1, ""))
		return 0
	}))
	L.SetGlobal("fail", L.NewFunction(func(L *lua.LState) int {
		record(false, L.OptString(1, ""))
		return 0
	}))
	L.SetGlobal("assert_eq", L.NewFunction(func(L *lua.LState) int {
		a, b := L.Get(1), L.Get(2)
		msg := L.OptString(3, fmt.Sprintf("%v == %v", a, b))
		// L.Equal applies Lua's == semantics, so nil, booleans and table
		// identity compare correctly. The string-form fallback keeps the
		// lenient behaviour where e.g. 200 and "200" count as equal.
		equal := L.Equal(a, b) ||
			(lua.LVCanConvToString(a) && lua.LVCanConvToString(b) &&
				lua.LVAsString(a) == lua.LVAsString(b))
		if equal {
			record(true, msg)
		} else {
			record(false, fmt.Sprintf("%s (got %v, expected %v)", msg, a, b))
		}
		return 0
	}))
	L.SetGlobal("assert_status", L.NewFunction(func(L *lua.LState) int {
		expected := L.CheckInt(1)
		msg := fmt.Sprintf("status == %d", expected)
		if resp.StatusCode == expected {
			record(true, msg)
		} else {
			record(false, fmt.Sprintf("%s (got %d)", msg, resp.StatusCode))
		}
		return 0
	}))

	// json_decode(str) -> value | nil, errmsg
	L.SetGlobal("json_decode", L.NewFunction(func(L *lua.LState) int {
		s := L.CheckString(1)
		var v interface{}
		if err := json.Unmarshal([]byte(s), &v); err != nil {
			L.Push(lua.LNil)
			L.Push(lua.LString(err.Error()))
			return 2
		}
		L.Push(goToLua(L, v))
		return 1
	}))

	// Response globals.
	L.SetGlobal("status", lua.LNumber(resp.StatusCode))
	L.SetGlobal("status_text", lua.LString(resp.Status))
	L.SetGlobal("body", lua.LString(resp.Body))
	L.SetGlobal("duration_ms", lua.LNumber(resp.Duration.Milliseconds()))
	L.SetGlobal("size", lua.LNumber(resp.Size))

	// Headers: lower-cased keys, only the first value of each header is kept.
	headersTbl := L.NewTable()
	for k, vals := range resp.Headers {
		if len(vals) > 0 {
			L.SetField(headersTbl, strings.ToLower(k), lua.LString(vals[0]))
		}
	}
	L.SetGlobal("headers", headersTbl)

	// json_body stays nil unless the response is flagged as JSON and parses.
	jsonBody := lua.LValue(lua.LNil)
	if resp.IsJSON {
		var v interface{}
		if err := json.Unmarshal([]byte(resp.Body), &v); err == nil {
			jsonBody = goToLua(L, v)
		}
	}
	L.SetGlobal("json_body", jsonBody)

	if err := L.DoString(script); err != nil {
		// Report the error as a failing result instead of a Go error, so the UI
		// can show it together with the assertions recorded before it.
		record(false, "Lua error: "+err.Error())
	}
	return results, nil
}
// goToLua converts a value produced by json.Unmarshal into an interface{} into
// the equivalent lua.LValue:
//
//	nil                     -> nil
//	bool                    -> boolean
//	float64, json.Number    -> number (a json.Number that does not parse as a
//	                           float64 is kept as a string)
//	string                  -> string
//	[]interface{}           -> table with 1-based integer keys
//	map[string]interface{}  -> table keyed by the object's field names
//
// Any other type is formatted with %v and returned as a string.
//
// Nested JSON nulls become Lua nil. Lua tables cannot hold nil as an ordinary
// value, so expect holes in arrays and missing keys in objects.
func goToLua(L *lua.LState, v interface{}) lua.LValue {
	switch val := v.(type) {
	case nil:
		return lua.LNil
	case bool:
		return lua.LBool(val)
	case float64:
		return lua.LNumber(val)
	case json.Number:
		// Produced when the decoder uses UseNumber.
		if f, err := val.Float64(); err == nil {
			return lua.LNumber(f)
		}
		return lua.LString(val.String())
	case string:
		return lua.LString(val)
	case []interface{}:
		// Pre-size the array part; the element count is known.
		tbl := L.CreateTable(len(val), 0)
		for i, item := range val {
			tbl.RawSetInt(i+1, goToLua(L, item))
		}
		return tbl
	case map[string]interface{}:
		// The table is freshly created (no metatable), so a raw set is
		// equivalent to SetField and skips the metamethod lookup.
		tbl := L.CreateTable(0, len(val))
		for k, item := range val {
			tbl.RawSetString(k, goToLua(L, item))
		}
		return tbl
	default:
		return lua.LString(fmt.Sprintf("%v", val))
	}
}