diff --git a/recovery.go b/recovery.go index d790cad..95784f7 100644 --- a/recovery.go +++ b/recovery.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "net/http" + "net/http/httptest" "os" "runtime" ) @@ -27,6 +28,9 @@ func NewRecovery() *Recovery { } func (rec *Recovery) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + nr := httptest.NewRecorder() + nrw := NewResponseWriter(nr) + defer func() { if err := recover(); err != nil { rw.WriteHeader(http.StatusInternalServerError) @@ -39,8 +43,14 @@ func (rec *Recovery) ServeHTTP(rw http.ResponseWriter, r *http.Request, next htt if rec.PrintStack { fmt.Fprintf(rw, f, err, stack) } + } else { + for k, v := range nrw.Header() { + rw.Header()[k] = v + } + rw.WriteHeader(nr.Code) + rw.Write(nr.Body.Bytes()) } }() - next(rw, r) + next(nrw, r) } diff --git a/recovery_test.go b/recovery_test.go index 3fa264a..d213af7 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -8,6 +8,25 @@ import ( "testing" ) +func TestNoRecovery(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + + rec := NewRecovery() + rec.Logger = log.New(buff, "[negroni] ", 0) + + n := New() + // replace log for testing + n.Use(rec) + n.UseHandler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + fmt.Fprint(res, "writing") + })) + n.ServeHTTP(recorder, (*http.Request)(nil)) + expect(t, recorder.Code, http.StatusOK) + refute(t, recorder.Body.Len(), 0) + expect(t, len(buff.String()), 0) +} + func TestRecovery(t *testing.T) { buff := bytes.NewBufferString("") recorder := httptest.NewRecorder() @@ -26,3 +45,23 @@ func TestRecovery(t *testing.T) { refute(t, recorder.Body.Len(), 0) refute(t, len(buff.String()), 0) } + +func TestRecoveryAfterWriting(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + + rec := NewRecovery() + rec.Logger = log.New(buff, "[negroni] ", 0) + + n := New() + // replace log for testing + n.Use(rec) + n.UseHandler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + fmt.Fprint(res, "writing") + panic("here is a panic!") + })) + n.ServeHTTP(recorder, (*http.Request)(nil)) + expect(t, recorder.Code, http.StatusInternalServerError) + refute(t, recorder.Body.Len(), 0) + refute(t, len(buff.String()), 0) +}